monai-weekly 1.5.dev2505__py3-none-any.whl → 1.5.dev2507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/transforms.py +1 -4
- monai/data/meta_tensor.py +5 -0
- monai/data/utils.py +6 -13
- monai/inferers/utils.py +1 -2
- monai/losses/dice.py +2 -14
- monai/losses/ds_loss.py +1 -3
- monai/networks/layers/simplelayers.py +2 -14
- monai/networks/utils.py +4 -16
- monai/transforms/compose.py +28 -11
- monai/transforms/croppad/array.py +1 -6
- monai/transforms/io/array.py +0 -1
- monai/transforms/transform.py +15 -6
- monai/transforms/utils.py +1 -2
- monai/utils/jupyter_utils.py +1 -1
- monai/utils/tf32.py +0 -10
- monai/visualize/class_activation_maps.py +5 -8
- monai/visualize/img2tensorboard.py +2 -2
- {monai_weekly-1.5.dev2505.dist-info → monai_weekly-1.5.dev2507.dist-info}/METADATA +2 -2
- monai_weekly-1.5.dev2507.dist-info/RECORD +1181 -0
- {monai_weekly-1.5.dev2505.dist-info → monai_weekly-1.5.dev2507.dist-info}/top_level.txt +1 -0
- tests/apps/__init__.py +10 -0
- tests/apps/deepedit/__init__.py +10 -0
- tests/apps/deepedit/test_deepedit_transforms.py +314 -0
- tests/apps/deepgrow/__init__.py +10 -0
- tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
- tests/apps/deepgrow/transforms/__init__.py +10 -0
- tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
- tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
- tests/apps/detection/__init__.py +10 -0
- tests/apps/detection/metrics/__init__.py +10 -0
- tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
- tests/apps/detection/networks/__init__.py +10 -0
- tests/apps/detection/networks/test_retinanet.py +210 -0
- tests/apps/detection/networks/test_retinanet_detector.py +203 -0
- tests/apps/detection/test_box_transform.py +370 -0
- tests/apps/detection/utils/__init__.py +10 -0
- tests/apps/detection/utils/test_anchor_box.py +88 -0
- tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
- tests/apps/detection/utils/test_box_coder.py +43 -0
- tests/apps/detection/utils/test_detector_boxselector.py +67 -0
- tests/apps/detection/utils/test_detector_utils.py +96 -0
- tests/apps/detection/utils/test_hardnegsampler.py +54 -0
- tests/apps/nuclick/__init__.py +10 -0
- tests/apps/nuclick/test_nuclick_transforms.py +259 -0
- tests/apps/pathology/__init__.py +10 -0
- tests/apps/pathology/handlers/__init__.py +10 -0
- tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
- tests/apps/pathology/test_lesion_froc.py +333 -0
- tests/apps/pathology/test_pathology_prob_nms.py +55 -0
- tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
- tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
- tests/apps/pathology/transforms/__init__.py +10 -0
- tests/apps/pathology/transforms/post/__init__.py +10 -0
- tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
- tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
- tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
- tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
- tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
- tests/apps/pathology/transforms/post/test_watershed.py +60 -0
- tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
- tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
- tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
- tests/apps/reconstruction/__init__.py +10 -0
- tests/apps/reconstruction/nets/__init__.py +10 -0
- tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
- tests/apps/reconstruction/test_complex_utils.py +77 -0
- tests/apps/reconstruction/test_fastmri_reader.py +82 -0
- tests/apps/reconstruction/test_mri_utils.py +37 -0
- tests/apps/reconstruction/transforms/__init__.py +10 -0
- tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
- tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
- tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
- tests/apps/test_auto3dseg_bundlegen.py +156 -0
- tests/apps/test_check_hash.py +53 -0
- tests/apps/test_cross_validation.py +74 -0
- tests/apps/test_decathlondataset.py +93 -0
- tests/apps/test_download_and_extract.py +70 -0
- tests/apps/test_download_url_yandex.py +45 -0
- tests/apps/test_mednistdataset.py +72 -0
- tests/apps/test_mmar_download.py +154 -0
- tests/apps/test_tciadataset.py +123 -0
- tests/apps/vista3d/__init__.py +10 -0
- tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
- tests/apps/vista3d/test_vista3d_sampler.py +100 -0
- tests/apps/vista3d/test_vista3d_transforms.py +94 -0
- tests/bundle/__init__.py +10 -0
- tests/bundle/test_bundle_ckpt_export.py +107 -0
- tests/bundle/test_bundle_download.py +435 -0
- tests/bundle/test_bundle_get_data.py +94 -0
- tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
- tests/bundle/test_bundle_trt_export.py +147 -0
- tests/bundle/test_bundle_utils.py +149 -0
- tests/bundle/test_bundle_verify_metadata.py +66 -0
- tests/bundle/test_bundle_verify_net.py +76 -0
- tests/bundle/test_bundle_workflow.py +272 -0
- tests/bundle/test_component_locator.py +38 -0
- tests/bundle/test_config_item.py +138 -0
- tests/bundle/test_config_parser.py +392 -0
- tests/bundle/test_reference_resolver.py +114 -0
- tests/config/__init__.py +10 -0
- tests/config/test_cv2_dist.py +53 -0
- tests/engines/__init__.py +10 -0
- tests/engines/test_ensemble_evaluator.py +94 -0
- tests/engines/test_prepare_batch_default.py +76 -0
- tests/engines/test_prepare_batch_default_dist.py +76 -0
- tests/engines/test_prepare_batch_diffusion.py +104 -0
- tests/engines/test_prepare_batch_extra_input.py +80 -0
- tests/fl/__init__.py +10 -0
- tests/fl/monai_algo/__init__.py +10 -0
- tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
- tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
- tests/fl/test_fl_monai_algo_stats.py +81 -0
- tests/fl/utils/__init__.py +10 -0
- tests/fl/utils/test_fl_exchange_object.py +63 -0
- tests/handlers/__init__.py +10 -0
- tests/handlers/test_handler_checkpoint_loader.py +182 -0
- tests/handlers/test_handler_checkpoint_saver.py +233 -0
- tests/handlers/test_handler_classification_saver.py +64 -0
- tests/handlers/test_handler_classification_saver_dist.py +77 -0
- tests/handlers/test_handler_clearml_image.py +65 -0
- tests/handlers/test_handler_clearml_stats.py +65 -0
- tests/handlers/test_handler_confusion_matrix.py +104 -0
- tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
- tests/handlers/test_handler_decollate_batch.py +66 -0
- tests/handlers/test_handler_early_stop.py +68 -0
- tests/handlers/test_handler_garbage_collector.py +73 -0
- tests/handlers/test_handler_hausdorff_distance.py +111 -0
- tests/handlers/test_handler_ignite_metric.py +191 -0
- tests/handlers/test_handler_lr_scheduler.py +94 -0
- tests/handlers/test_handler_mean_dice.py +98 -0
- tests/handlers/test_handler_mean_iou.py +76 -0
- tests/handlers/test_handler_metrics_reloaded.py +149 -0
- tests/handlers/test_handler_metrics_saver.py +89 -0
- tests/handlers/test_handler_metrics_saver_dist.py +120 -0
- tests/handlers/test_handler_mlflow.py +296 -0
- tests/handlers/test_handler_nvtx.py +93 -0
- tests/handlers/test_handler_panoptic_quality.py +89 -0
- tests/handlers/test_handler_parameter_scheduler.py +136 -0
- tests/handlers/test_handler_post_processing.py +74 -0
- tests/handlers/test_handler_prob_map_producer.py +111 -0
- tests/handlers/test_handler_regression_metrics.py +160 -0
- tests/handlers/test_handler_regression_metrics_dist.py +245 -0
- tests/handlers/test_handler_rocauc.py +48 -0
- tests/handlers/test_handler_rocauc_dist.py +54 -0
- tests/handlers/test_handler_stats.py +281 -0
- tests/handlers/test_handler_surface_distance.py +113 -0
- tests/handlers/test_handler_tb_image.py +61 -0
- tests/handlers/test_handler_tb_stats.py +166 -0
- tests/handlers/test_handler_validation.py +59 -0
- tests/handlers/test_trt_compile.py +145 -0
- tests/handlers/test_write_metrics_reports.py +68 -0
- tests/inferers/__init__.py +10 -0
- tests/inferers/test_avg_merger.py +179 -0
- tests/inferers/test_controlnet_inferers.py +1310 -0
- tests/inferers/test_diffusion_inferer.py +236 -0
- tests/inferers/test_latent_diffusion_inferer.py +824 -0
- tests/inferers/test_patch_inferer.py +309 -0
- tests/inferers/test_saliency_inferer.py +55 -0
- tests/inferers/test_slice_inferer.py +57 -0
- tests/inferers/test_sliding_window_inference.py +377 -0
- tests/inferers/test_sliding_window_splitter.py +284 -0
- tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
- tests/inferers/test_zarr_avg_merger.py +326 -0
- tests/integration/__init__.py +10 -0
- tests/integration/test_auto3dseg_ensemble.py +211 -0
- tests/integration/test_auto3dseg_hpo.py +189 -0
- tests/integration/test_deepedit_interaction.py +122 -0
- tests/integration/test_downsample_block.py +50 -0
- tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
- tests/integration/test_integration_autorunner.py +201 -0
- tests/integration/test_integration_bundle_run.py +240 -0
- tests/integration/test_integration_classification_2d.py +282 -0
- tests/integration/test_integration_determinism.py +95 -0
- tests/integration/test_integration_fast_train.py +231 -0
- tests/integration/test_integration_gpu_customization.py +159 -0
- tests/integration/test_integration_lazy_samples.py +219 -0
- tests/integration/test_integration_nnunetv2_runner.py +96 -0
- tests/integration/test_integration_segmentation_3d.py +304 -0
- tests/integration/test_integration_sliding_window.py +100 -0
- tests/integration/test_integration_stn.py +133 -0
- tests/integration/test_integration_unet_2d.py +67 -0
- tests/integration/test_integration_workers.py +61 -0
- tests/integration/test_integration_workflows.py +365 -0
- tests/integration/test_integration_workflows_adversarial.py +173 -0
- tests/integration/test_integration_workflows_gan.py +158 -0
- tests/integration/test_loader_semaphore.py +48 -0
- tests/integration/test_mapping_filed.py +122 -0
- tests/integration/test_meta_affine.py +183 -0
- tests/integration/test_metatensor_integration.py +114 -0
- tests/integration/test_module_list.py +76 -0
- tests/integration/test_one_of.py +283 -0
- tests/integration/test_pad_collation.py +124 -0
- tests/integration/test_reg_loss_integration.py +107 -0
- tests/integration/test_retinanet_predict_utils.py +154 -0
- tests/integration/test_seg_loss_integration.py +159 -0
- tests/integration/test_spatial_combine_transforms.py +185 -0
- tests/integration/test_testtimeaugmentation.py +186 -0
- tests/integration/test_vis_gradbased.py +69 -0
- tests/integration/test_vista3d_utils.py +159 -0
- tests/losses/__init__.py +10 -0
- tests/losses/deform/__init__.py +10 -0
- tests/losses/deform/test_bending_energy.py +88 -0
- tests/losses/deform/test_diffusion_loss.py +117 -0
- tests/losses/image_dissimilarity/__init__.py +10 -0
- tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
- tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
- tests/losses/test_adversarial_loss.py +94 -0
- tests/losses/test_barlow_twins_loss.py +109 -0
- tests/losses/test_cldice_loss.py +51 -0
- tests/losses/test_contrastive_loss.py +86 -0
- tests/losses/test_dice_ce_loss.py +123 -0
- tests/losses/test_dice_focal_loss.py +124 -0
- tests/losses/test_dice_loss.py +227 -0
- tests/losses/test_ds_loss.py +189 -0
- tests/losses/test_focal_loss.py +379 -0
- tests/losses/test_generalized_dice_focal_loss.py +85 -0
- tests/losses/test_generalized_dice_loss.py +221 -0
- tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
- tests/losses/test_giou_loss.py +62 -0
- tests/losses/test_hausdorff_loss.py +264 -0
- tests/losses/test_masked_dice_loss.py +152 -0
- tests/losses/test_masked_loss.py +87 -0
- tests/losses/test_multi_scale.py +86 -0
- tests/losses/test_nacl_loss.py +167 -0
- tests/losses/test_perceptual_loss.py +122 -0
- tests/losses/test_spectral_loss.py +86 -0
- tests/losses/test_ssim_loss.py +59 -0
- tests/losses/test_sure_loss.py +72 -0
- tests/losses/test_tversky_loss.py +198 -0
- tests/losses/test_unified_focal_loss.py +66 -0
- tests/metrics/__init__.py +10 -0
- tests/metrics/test_compute_confusion_matrix.py +294 -0
- tests/metrics/test_compute_f_beta.py +80 -0
- tests/metrics/test_compute_fid_metric.py +40 -0
- tests/metrics/test_compute_froc.py +143 -0
- tests/metrics/test_compute_generalized_dice.py +240 -0
- tests/metrics/test_compute_meandice.py +306 -0
- tests/metrics/test_compute_meaniou.py +223 -0
- tests/metrics/test_compute_mmd_metric.py +56 -0
- tests/metrics/test_compute_multiscalessim_metric.py +83 -0
- tests/metrics/test_compute_panoptic_quality.py +113 -0
- tests/metrics/test_compute_regression_metrics.py +196 -0
- tests/metrics/test_compute_roc_auc.py +155 -0
- tests/metrics/test_compute_variance.py +147 -0
- tests/metrics/test_cumulative.py +63 -0
- tests/metrics/test_cumulative_average.py +74 -0
- tests/metrics/test_cumulative_average_dist.py +48 -0
- tests/metrics/test_hausdorff_distance.py +209 -0
- tests/metrics/test_label_quality_score.py +134 -0
- tests/metrics/test_loss_metric.py +57 -0
- tests/metrics/test_metrics_reloaded.py +96 -0
- tests/metrics/test_ssim_metric.py +78 -0
- tests/metrics/test_surface_dice.py +416 -0
- tests/metrics/test_surface_distance.py +186 -0
- tests/networks/__init__.py +10 -0
- tests/networks/blocks/__init__.py +10 -0
- tests/networks/blocks/dints_block/__init__.py +10 -0
- tests/networks/blocks/dints_block/test_acn_block.py +41 -0
- tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
- tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
- tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
- tests/networks/blocks/test_adn.py +86 -0
- tests/networks/blocks/test_convolutions.py +156 -0
- tests/networks/blocks/test_crf_cpu.py +513 -0
- tests/networks/blocks/test_crf_cuda.py +528 -0
- tests/networks/blocks/test_crossattention.py +185 -0
- tests/networks/blocks/test_denseblock.py +105 -0
- tests/networks/blocks/test_dynunet_block.py +116 -0
- tests/networks/blocks/test_fpn_block.py +88 -0
- tests/networks/blocks/test_localnet_block.py +121 -0
- tests/networks/blocks/test_mlp.py +78 -0
- tests/networks/blocks/test_patchembedding.py +212 -0
- tests/networks/blocks/test_regunet_block.py +103 -0
- tests/networks/blocks/test_se_block.py +85 -0
- tests/networks/blocks/test_se_blocks.py +78 -0
- tests/networks/blocks/test_segresnet_block.py +57 -0
- tests/networks/blocks/test_selfattention.py +232 -0
- tests/networks/blocks/test_simple_aspp.py +87 -0
- tests/networks/blocks/test_spatialattention.py +55 -0
- tests/networks/blocks/test_subpixel_upsample.py +87 -0
- tests/networks/blocks/test_text_encoding.py +49 -0
- tests/networks/blocks/test_transformerblock.py +90 -0
- tests/networks/blocks/test_unetr_block.py +158 -0
- tests/networks/blocks/test_upsample_block.py +134 -0
- tests/networks/blocks/warp/__init__.py +10 -0
- tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
- tests/networks/blocks/warp/test_warp.py +250 -0
- tests/networks/layers/__init__.py +10 -0
- tests/networks/layers/filtering/__init__.py +10 -0
- tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
- tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
- tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
- tests/networks/layers/filtering/test_phl_cpu.py +259 -0
- tests/networks/layers/filtering/test_phl_cuda.py +167 -0
- tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
- tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
- tests/networks/layers/test_affine_transform.py +385 -0
- tests/networks/layers/test_apply_filter.py +89 -0
- tests/networks/layers/test_channel_pad.py +51 -0
- tests/networks/layers/test_conjugate_gradient.py +56 -0
- tests/networks/layers/test_drop_path.py +46 -0
- tests/networks/layers/test_gaussian.py +317 -0
- tests/networks/layers/test_gaussian_filter.py +206 -0
- tests/networks/layers/test_get_layers.py +65 -0
- tests/networks/layers/test_gmm.py +314 -0
- tests/networks/layers/test_grid_pull.py +93 -0
- tests/networks/layers/test_hilbert_transform.py +131 -0
- tests/networks/layers/test_lltm.py +62 -0
- tests/networks/layers/test_median_filter.py +52 -0
- tests/networks/layers/test_polyval.py +55 -0
- tests/networks/layers/test_preset_filters.py +136 -0
- tests/networks/layers/test_savitzky_golay_filter.py +141 -0
- tests/networks/layers/test_separable_filter.py +87 -0
- tests/networks/layers/test_skip_connection.py +48 -0
- tests/networks/layers/test_vector_quantizer.py +89 -0
- tests/networks/layers/test_weight_init.py +50 -0
- tests/networks/nets/__init__.py +10 -0
- tests/networks/nets/dints/__init__.py +10 -0
- tests/networks/nets/dints/test_dints_cell.py +110 -0
- tests/networks/nets/dints/test_dints_mixop.py +84 -0
- tests/networks/nets/regunet/__init__.py +10 -0
- tests/networks/nets/regunet/test_localnet.py +86 -0
- tests/networks/nets/regunet/test_regunet.py +88 -0
- tests/networks/nets/test_ahnet.py +224 -0
- tests/networks/nets/test_attentionunet.py +88 -0
- tests/networks/nets/test_autoencoder.py +95 -0
- tests/networks/nets/test_autoencoderkl.py +337 -0
- tests/networks/nets/test_basic_unet.py +102 -0
- tests/networks/nets/test_basic_unetplusplus.py +109 -0
- tests/networks/nets/test_bundle_init_bundle.py +55 -0
- tests/networks/nets/test_cell_sam_wrapper.py +58 -0
- tests/networks/nets/test_controlnet.py +215 -0
- tests/networks/nets/test_daf3d.py +62 -0
- tests/networks/nets/test_densenet.py +121 -0
- tests/networks/nets/test_diffusion_model_unet.py +585 -0
- tests/networks/nets/test_dints_network.py +168 -0
- tests/networks/nets/test_discriminator.py +59 -0
- tests/networks/nets/test_dynunet.py +181 -0
- tests/networks/nets/test_efficientnet.py +400 -0
- tests/networks/nets/test_flexible_unet.py +341 -0
- tests/networks/nets/test_fullyconnectednet.py +69 -0
- tests/networks/nets/test_generator.py +59 -0
- tests/networks/nets/test_globalnet.py +103 -0
- tests/networks/nets/test_highresnet.py +67 -0
- tests/networks/nets/test_hovernet.py +218 -0
- tests/networks/nets/test_mednext.py +122 -0
- tests/networks/nets/test_milmodel.py +92 -0
- tests/networks/nets/test_net_adapter.py +68 -0
- tests/networks/nets/test_network_consistency.py +86 -0
- tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
- tests/networks/nets/test_quicknat.py +57 -0
- tests/networks/nets/test_resnet.py +340 -0
- tests/networks/nets/test_segresnet.py +120 -0
- tests/networks/nets/test_segresnet_ds.py +156 -0
- tests/networks/nets/test_senet.py +151 -0
- tests/networks/nets/test_spade_autoencoderkl.py +295 -0
- tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
- tests/networks/nets/test_spade_vaegan.py +140 -0
- tests/networks/nets/test_swin_unetr.py +139 -0
- tests/networks/nets/test_torchvision_fc_model.py +201 -0
- tests/networks/nets/test_transchex.py +84 -0
- tests/networks/nets/test_transformer.py +108 -0
- tests/networks/nets/test_unet.py +208 -0
- tests/networks/nets/test_unetr.py +137 -0
- tests/networks/nets/test_varautoencoder.py +127 -0
- tests/networks/nets/test_vista3d.py +84 -0
- tests/networks/nets/test_vit.py +139 -0
- tests/networks/nets/test_vitautoenc.py +112 -0
- tests/networks/nets/test_vnet.py +81 -0
- tests/networks/nets/test_voxelmorph.py +280 -0
- tests/networks/nets/test_vqvae.py +274 -0
- tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
- tests/networks/schedulers/__init__.py +10 -0
- tests/networks/schedulers/test_scheduler_ddim.py +83 -0
- tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
- tests/networks/schedulers/test_scheduler_pndm.py +108 -0
- tests/networks/test_bundle_onnx_export.py +71 -0
- tests/networks/test_convert_to_onnx.py +106 -0
- tests/networks/test_convert_to_torchscript.py +46 -0
- tests/networks/test_convert_to_trt.py +79 -0
- tests/networks/test_save_state.py +73 -0
- tests/networks/test_to_onehot.py +63 -0
- tests/networks/test_varnet.py +63 -0
- tests/networks/utils/__init__.py +10 -0
- tests/networks/utils/test_copy_model_state.py +187 -0
- tests/networks/utils/test_eval_mode.py +34 -0
- tests/networks/utils/test_freeze_layers.py +61 -0
- tests/networks/utils/test_replace_module.py +98 -0
- tests/networks/utils/test_train_mode.py +34 -0
- tests/optimizers/__init__.py +10 -0
- tests/optimizers/test_generate_param_groups.py +105 -0
- tests/optimizers/test_lr_finder.py +108 -0
- tests/optimizers/test_lr_scheduler.py +71 -0
- tests/optimizers/test_optim_novograd.py +100 -0
- tests/profile_subclass/__init__.py +10 -0
- tests/profile_subclass/cprofile_profiling.py +29 -0
- tests/profile_subclass/min_classes.py +30 -0
- tests/profile_subclass/profiling.py +73 -0
- tests/profile_subclass/pyspy_profiling.py +41 -0
- tests/transforms/__init__.py +10 -0
- tests/transforms/compose/__init__.py +10 -0
- tests/transforms/compose/test_compose.py +758 -0
- tests/transforms/compose/test_some_of.py +258 -0
- tests/transforms/croppad/__init__.py +10 -0
- tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
- tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
- tests/transforms/functional/__init__.py +10 -0
- tests/transforms/functional/test_apply.py +75 -0
- tests/transforms/functional/test_resample.py +50 -0
- tests/transforms/intensity/__init__.py +10 -0
- tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
- tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
- tests/transforms/intensity/test_foreground_mask.py +98 -0
- tests/transforms/intensity/test_foreground_maskd.py +106 -0
- tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
- tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
- tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
- tests/transforms/inverse/__init__.py +10 -0
- tests/transforms/inverse/test_inverse_array.py +76 -0
- tests/transforms/inverse/test_traceable_transform.py +59 -0
- tests/transforms/post/__init__.py +10 -0
- tests/transforms/post/test_label_filterd.py +78 -0
- tests/transforms/post/test_probnms.py +72 -0
- tests/transforms/post/test_probnmsd.py +79 -0
- tests/transforms/post/test_remove_small_objects.py +102 -0
- tests/transforms/spatial/__init__.py +10 -0
- tests/transforms/spatial/test_convert_box_points.py +119 -0
- tests/transforms/spatial/test_grid_patch.py +134 -0
- tests/transforms/spatial/test_grid_patchd.py +102 -0
- tests/transforms/spatial/test_rand_grid_patch.py +150 -0
- tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
- tests/transforms/spatial/test_spatial_resampled.py +124 -0
- tests/transforms/test_activations.py +120 -0
- tests/transforms/test_activationsd.py +64 -0
- tests/transforms/test_adaptors.py +160 -0
- tests/transforms/test_add_coordinate_channels.py +53 -0
- tests/transforms/test_add_coordinate_channelsd.py +67 -0
- tests/transforms/test_add_extreme_points_channel.py +80 -0
- tests/transforms/test_add_extreme_points_channeld.py +77 -0
- tests/transforms/test_adjust_contrast.py +70 -0
- tests/transforms/test_adjust_contrastd.py +64 -0
- tests/transforms/test_affine.py +245 -0
- tests/transforms/test_affine_grid.py +152 -0
- tests/transforms/test_affined.py +190 -0
- tests/transforms/test_as_channel_last.py +38 -0
- tests/transforms/test_as_channel_lastd.py +44 -0
- tests/transforms/test_as_discrete.py +81 -0
- tests/transforms/test_as_discreted.py +82 -0
- tests/transforms/test_border_pad.py +49 -0
- tests/transforms/test_border_padd.py +45 -0
- tests/transforms/test_bounding_rect.py +54 -0
- tests/transforms/test_bounding_rectd.py +53 -0
- tests/transforms/test_cast_to_type.py +63 -0
- tests/transforms/test_cast_to_typed.py +74 -0
- tests/transforms/test_center_scale_crop.py +55 -0
- tests/transforms/test_center_scale_cropd.py +56 -0
- tests/transforms/test_center_spatial_crop.py +56 -0
- tests/transforms/test_center_spatial_cropd.py +63 -0
- tests/transforms/test_classes_to_indices.py +93 -0
- tests/transforms/test_classes_to_indicesd.py +110 -0
- tests/transforms/test_clip_intensity_percentiles.py +196 -0
- tests/transforms/test_clip_intensity_percentilesd.py +193 -0
- tests/transforms/test_compose_get_number_conversions.py +127 -0
- tests/transforms/test_concat_itemsd.py +82 -0
- tests/transforms/test_convert_to_multi_channel.py +59 -0
- tests/transforms/test_convert_to_multi_channeld.py +37 -0
- tests/transforms/test_copy_itemsd.py +86 -0
- tests/transforms/test_create_grid_and_affine.py +274 -0
- tests/transforms/test_crop_foreground.py +164 -0
- tests/transforms/test_crop_foregroundd.py +205 -0
- tests/transforms/test_cucim_dict_transform.py +142 -0
- tests/transforms/test_cucim_transform.py +141 -0
- tests/transforms/test_data_stats.py +221 -0
- tests/transforms/test_data_statsd.py +249 -0
- tests/transforms/test_delete_itemsd.py +58 -0
- tests/transforms/test_detect_envelope.py +159 -0
- tests/transforms/test_distance_transform_edt.py +202 -0
- tests/transforms/test_divisible_pad.py +49 -0
- tests/transforms/test_divisible_padd.py +42 -0
- tests/transforms/test_ensure_channel_first.py +113 -0
- tests/transforms/test_ensure_channel_firstd.py +85 -0
- tests/transforms/test_ensure_type.py +94 -0
- tests/transforms/test_ensure_typed.py +110 -0
- tests/transforms/test_fg_bg_to_indices.py +83 -0
- tests/transforms/test_fg_bg_to_indicesd.py +78 -0
- tests/transforms/test_fill_holes.py +207 -0
- tests/transforms/test_fill_holesd.py +209 -0
- tests/transforms/test_flatten_sub_keysd.py +64 -0
- tests/transforms/test_flip.py +83 -0
- tests/transforms/test_flipd.py +90 -0
- tests/transforms/test_fourier.py +70 -0
- tests/transforms/test_gaussian_sharpen.py +92 -0
- tests/transforms/test_gaussian_sharpend.py +92 -0
- tests/transforms/test_gaussian_smooth.py +96 -0
- tests/transforms/test_gaussian_smoothd.py +96 -0
- tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
- tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
- tests/transforms/test_generate_spatial_bounding_box.py +114 -0
- tests/transforms/test_get_extreme_points.py +57 -0
- tests/transforms/test_gibbs_noise.py +75 -0
- tests/transforms/test_gibbs_noised.py +88 -0
- tests/transforms/test_grid_distortion.py +113 -0
- tests/transforms/test_grid_distortiond.py +87 -0
- tests/transforms/test_grid_split.py +88 -0
- tests/transforms/test_grid_splitd.py +96 -0
- tests/transforms/test_histogram_normalize.py +59 -0
- tests/transforms/test_histogram_normalized.py +59 -0
- tests/transforms/test_image_filter.py +259 -0
- tests/transforms/test_intensity_stats.py +73 -0
- tests/transforms/test_intensity_statsd.py +90 -0
- tests/transforms/test_inverse.py +521 -0
- tests/transforms/test_inverse_collation.py +147 -0
- tests/transforms/test_invert.py +105 -0
- tests/transforms/test_invertd.py +142 -0
- tests/transforms/test_k_space_spike_noise.py +81 -0
- tests/transforms/test_k_space_spike_noised.py +98 -0
- tests/transforms/test_keep_largest_connected_component.py +419 -0
- tests/transforms/test_keep_largest_connected_componentd.py +348 -0
- tests/transforms/test_label_filter.py +78 -0
- tests/transforms/test_label_to_contour.py +179 -0
- tests/transforms/test_label_to_contourd.py +182 -0
- tests/transforms/test_label_to_mask.py +69 -0
- tests/transforms/test_label_to_maskd.py +70 -0
- tests/transforms/test_load_image.py +502 -0
- tests/transforms/test_load_imaged.py +198 -0
- tests/transforms/test_load_spacing_orientation.py +149 -0
- tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
- tests/transforms/test_map_binary_to_indices.py +75 -0
- tests/transforms/test_map_classes_to_indices.py +135 -0
- tests/transforms/test_map_label_value.py +89 -0
- tests/transforms/test_map_label_valued.py +85 -0
- tests/transforms/test_map_transform.py +45 -0
- tests/transforms/test_mask_intensity.py +74 -0
- tests/transforms/test_mask_intensityd.py +68 -0
- tests/transforms/test_mean_ensemble.py +77 -0
- tests/transforms/test_mean_ensembled.py +91 -0
- tests/transforms/test_median_smooth.py +41 -0
- tests/transforms/test_median_smoothd.py +65 -0
- tests/transforms/test_morphological_ops.py +101 -0
- tests/transforms/test_nifti_endianness.py +107 -0
- tests/transforms/test_normalize_intensity.py +143 -0
- tests/transforms/test_normalize_intensityd.py +81 -0
- tests/transforms/test_nvtx_decorator.py +289 -0
- tests/transforms/test_nvtx_transform.py +143 -0
- tests/transforms/test_orientation.py +247 -0
- tests/transforms/test_orientationd.py +112 -0
- tests/transforms/test_rand_adjust_contrast.py +45 -0
- tests/transforms/test_rand_adjust_contrastd.py +44 -0
- tests/transforms/test_rand_affine.py +201 -0
- tests/transforms/test_rand_affine_grid.py +212 -0
- tests/transforms/test_rand_affined.py +281 -0
- tests/transforms/test_rand_axis_flip.py +50 -0
- tests/transforms/test_rand_axis_flipd.py +50 -0
- tests/transforms/test_rand_bias_field.py +69 -0
- tests/transforms/test_rand_bias_fieldd.py +65 -0
- tests/transforms/test_rand_coarse_dropout.py +110 -0
- tests/transforms/test_rand_coarse_dropoutd.py +107 -0
- tests/transforms/test_rand_coarse_shuffle.py +65 -0
- tests/transforms/test_rand_coarse_shuffled.py +59 -0
- tests/transforms/test_rand_crop_by_label_classes.py +170 -0
- tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
- tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
- tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
- tests/transforms/test_rand_cucim_dict_transform.py +162 -0
- tests/transforms/test_rand_cucim_transform.py +162 -0
- tests/transforms/test_rand_deform_grid.py +138 -0
- tests/transforms/test_rand_elastic_2d.py +127 -0
- tests/transforms/test_rand_elastic_3d.py +104 -0
- tests/transforms/test_rand_elasticd_2d.py +177 -0
- tests/transforms/test_rand_elasticd_3d.py +156 -0
- tests/transforms/test_rand_flip.py +60 -0
- tests/transforms/test_rand_flipd.py +55 -0
- tests/transforms/test_rand_gaussian_noise.py +48 -0
- tests/transforms/test_rand_gaussian_noised.py +54 -0
- tests/transforms/test_rand_gaussian_sharpen.py +140 -0
- tests/transforms/test_rand_gaussian_sharpend.py +143 -0
- tests/transforms/test_rand_gaussian_smooth.py +98 -0
- tests/transforms/test_rand_gaussian_smoothd.py +98 -0
- tests/transforms/test_rand_gibbs_noise.py +103 -0
- tests/transforms/test_rand_gibbs_noised.py +117 -0
- tests/transforms/test_rand_grid_distortion.py +99 -0
- tests/transforms/test_rand_grid_distortiond.py +90 -0
- tests/transforms/test_rand_histogram_shift.py +92 -0
- tests/transforms/test_rand_k_space_spike_noise.py +92 -0
- tests/transforms/test_rand_k_space_spike_noised.py +76 -0
- tests/transforms/test_rand_rician_noise.py +52 -0
- tests/transforms/test_rand_rician_noised.py +52 -0
- tests/transforms/test_rand_rotate.py +166 -0
- tests/transforms/test_rand_rotate90.py +100 -0
- tests/transforms/test_rand_rotate90d.py +112 -0
- tests/transforms/test_rand_rotated.py +187 -0
- tests/transforms/test_rand_scale_crop.py +78 -0
- tests/transforms/test_rand_scale_cropd.py +98 -0
- tests/transforms/test_rand_scale_intensity.py +54 -0
- tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
- tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
- tests/transforms/test_rand_scale_intensityd.py +53 -0
- tests/transforms/test_rand_shift_intensity.py +52 -0
- tests/transforms/test_rand_shift_intensityd.py +67 -0
- tests/transforms/test_rand_simulate_low_resolution.py +83 -0
- tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
- tests/transforms/test_rand_spatial_crop.py +107 -0
- tests/transforms/test_rand_spatial_crop_samples.py +128 -0
- tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
- tests/transforms/test_rand_spatial_cropd.py +112 -0
- tests/transforms/test_rand_std_shift_intensity.py +43 -0
- tests/transforms/test_rand_std_shift_intensityd.py +38 -0
- tests/transforms/test_rand_zoom.py +105 -0
- tests/transforms/test_rand_zoomd.py +108 -0
- tests/transforms/test_randidentity.py +49 -0
- tests/transforms/test_random_order.py +144 -0
- tests/transforms/test_randtorchvisiond.py +65 -0
- tests/transforms/test_regularization.py +139 -0
- tests/transforms/test_remove_repeated_channel.py +34 -0
- tests/transforms/test_remove_repeated_channeld.py +44 -0
- tests/transforms/test_repeat_channel.py +34 -0
- tests/transforms/test_repeat_channeld.py +41 -0
- tests/transforms/test_resample_backends.py +65 -0
- tests/transforms/test_resample_to_match.py +110 -0
- tests/transforms/test_resample_to_matchd.py +93 -0
- tests/transforms/test_resampler.py +165 -0
- tests/transforms/test_resize.py +140 -0
- tests/transforms/test_resize_with_pad_or_crop.py +91 -0
- tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
- tests/transforms/test_resized.py +163 -0
- tests/transforms/test_rotate.py +160 -0
- tests/transforms/test_rotate90.py +212 -0
- tests/transforms/test_rotate90d.py +106 -0
- tests/transforms/test_rotated.py +179 -0
- tests/transforms/test_save_classificationd.py +109 -0
- tests/transforms/test_save_image.py +80 -0
- tests/transforms/test_save_imaged.py +130 -0
- tests/transforms/test_savitzky_golay_smooth.py +73 -0
- tests/transforms/test_savitzky_golay_smoothd.py +73 -0
- tests/transforms/test_scale_intensity.py +76 -0
- tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
- tests/transforms/test_scale_intensity_range.py +41 -0
- tests/transforms/test_scale_intensity_ranged.py +40 -0
- tests/transforms/test_scale_intensityd.py +57 -0
- tests/transforms/test_select_itemsd.py +41 -0
- tests/transforms/test_shift_intensity.py +31 -0
- tests/transforms/test_shift_intensityd.py +44 -0
- tests/transforms/test_signal_continuouswavelet.py +44 -0
- tests/transforms/test_signal_fillempty.py +52 -0
- tests/transforms/test_signal_fillemptyd.py +60 -0
- tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
- tests/transforms/test_signal_rand_add_sine.py +52 -0
- tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
- tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
- tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
- tests/transforms/test_signal_rand_drop.py +50 -0
- tests/transforms/test_signal_rand_scale.py +52 -0
- tests/transforms/test_signal_rand_shift.py +55 -0
- tests/transforms/test_signal_remove_frequency.py +71 -0
- tests/transforms/test_smooth_field.py +177 -0
- tests/transforms/test_sobel_gradient.py +189 -0
- tests/transforms/test_sobel_gradientd.py +212 -0
- tests/transforms/test_spacing.py +381 -0
- tests/transforms/test_spacingd.py +178 -0
- tests/transforms/test_spatial_crop.py +82 -0
- tests/transforms/test_spatial_cropd.py +74 -0
- tests/transforms/test_spatial_pad.py +57 -0
- tests/transforms/test_spatial_padd.py +43 -0
- tests/transforms/test_spatial_resample.py +235 -0
- tests/transforms/test_squeezedim.py +62 -0
- tests/transforms/test_squeezedimd.py +98 -0
- tests/transforms/test_std_shift_intensity.py +76 -0
- tests/transforms/test_std_shift_intensityd.py +74 -0
- tests/transforms/test_threshold_intensity.py +38 -0
- tests/transforms/test_threshold_intensityd.py +58 -0
- tests/transforms/test_to_contiguous.py +47 -0
- tests/transforms/test_to_cupy.py +112 -0
- tests/transforms/test_to_cupyd.py +76 -0
- tests/transforms/test_to_device.py +42 -0
- tests/transforms/test_to_deviced.py +37 -0
- tests/transforms/test_to_numpy.py +85 -0
- tests/transforms/test_to_numpyd.py +68 -0
- tests/transforms/test_to_pil.py +52 -0
- tests/transforms/test_to_pild.py +50 -0
- tests/transforms/test_to_tensor.py +60 -0
- tests/transforms/test_to_tensord.py +71 -0
- tests/transforms/test_torchvision.py +66 -0
- tests/transforms/test_torchvisiond.py +63 -0
- tests/transforms/test_transform.py +62 -0
- tests/transforms/test_transpose.py +41 -0
- tests/transforms/test_transposed.py +52 -0
- tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
- tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
- tests/transforms/test_vote_ensemble.py +84 -0
- tests/transforms/test_vote_ensembled.py +107 -0
- tests/transforms/test_with_allow_missing_keys.py +76 -0
- tests/transforms/test_zoom.py +120 -0
- tests/transforms/test_zoomd.py +94 -0
- tests/transforms/transform/__init__.py +10 -0
- tests/transforms/transform/test_randomizable.py +52 -0
- tests/transforms/transform/test_randomizable_transform_type.py +37 -0
- tests/transforms/utility/__init__.py +10 -0
- tests/transforms/utility/test_apply_transform_to_points.py +81 -0
- tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
- tests/transforms/utility/test_identity.py +29 -0
- tests/transforms/utility/test_identityd.py +30 -0
- tests/transforms/utility/test_lambda.py +71 -0
- tests/transforms/utility/test_lambdad.py +83 -0
- tests/transforms/utility/test_rand_lambda.py +87 -0
- tests/transforms/utility/test_rand_lambdad.py +77 -0
- tests/transforms/utility/test_simulatedelay.py +36 -0
- tests/transforms/utility/test_simulatedelayd.py +36 -0
- tests/transforms/utility/test_splitdim.py +52 -0
- tests/transforms/utility/test_splitdimd.py +96 -0
- tests/transforms/utils/__init__.py +10 -0
- tests/transforms/utils/test_correct_crop_centers.py +36 -0
- tests/transforms/utils/test_get_unique_labels.py +45 -0
- tests/transforms/utils/test_print_transform_backends.py +29 -0
- tests/transforms/utils/test_soft_clip.py +125 -0
- tests/utils/__init__.py +10 -0
- tests/utils/enums/__init__.py +10 -0
- tests/utils/enums/test_hovernet_loss.py +190 -0
- tests/utils/enums/test_ordering.py +289 -0
- tests/utils/enums/test_wsireader.py +663 -0
- tests/utils/misc/__init__.py +10 -0
- tests/utils/misc/test_ensure_tuple.py +53 -0
- tests/utils/misc/test_monai_env_vars.py +44 -0
- tests/utils/misc/test_monai_utils_misc.py +103 -0
- tests/utils/misc/test_str2bool.py +34 -0
- tests/utils/misc/test_str2list.py +33 -0
- tests/utils/test_alias.py +44 -0
- tests/utils/test_component_store.py +73 -0
- tests/utils/test_deprecated.py +455 -0
- tests/utils/test_enum_bound_interp.py +75 -0
- tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
- tests/utils/test_get_package_version.py +34 -0
- tests/utils/test_handler_logfile.py +84 -0
- tests/utils/test_handler_metric_logger.py +62 -0
- tests/utils/test_list_to_dict.py +43 -0
- tests/utils/test_look_up_option.py +87 -0
- tests/utils/test_optional_import.py +80 -0
- tests/utils/test_pad_mode.py +39 -0
- tests/utils/test_profiling.py +208 -0
- tests/utils/test_rankfilter_dist.py +77 -0
- tests/utils/test_require_pkg.py +83 -0
- tests/utils/test_sample_slices.py +43 -0
- tests/utils/test_set_determinism.py +74 -0
- tests/utils/test_squeeze_unsqueeze.py +71 -0
- tests/utils/test_state_cacher.py +67 -0
- tests/utils/test_torchscript_utils.py +113 -0
- tests/utils/test_version.py +91 -0
- tests/utils/test_version_after.py +65 -0
- tests/utils/type_conversion/__init__.py +10 -0
- tests/utils/type_conversion/test_convert_data_type.py +152 -0
- tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
- tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
- tests/visualize/__init__.py +10 -0
- tests/visualize/test_img2tensorboard.py +46 -0
- tests/visualize/test_occlusion_sensitivity.py +128 -0
- tests/visualize/test_plot_2d_or_3d_image.py +74 -0
- tests/visualize/test_vis_cam.py +98 -0
- tests/visualize/test_vis_gradcam.py +211 -0
- tests/visualize/utils/__init__.py +10 -0
- tests/visualize/utils/test_blend_images.py +63 -0
- tests/visualize/utils/test_matshow3d.py +133 -0
- monai_weekly-1.5.dev2505.dist-info/RECORD +0 -427
- {monai_weekly-1.5.dev2505.dist-info → monai_weekly-1.5.dev2507.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2505.dist-info → monai_weekly-1.5.dev2507.dist-info}/WHEEL +0 -0
@@ -0,0 +1,62 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
|
18
|
+
from monai.utils import optional_import
|
19
|
+
from tests.test_utils import SkipIfNoModule
|
20
|
+
|
21
|
+
try:
|
22
|
+
_, has_ignite = optional_import("ignite")
|
23
|
+
from ignite.engine import Engine, Events
|
24
|
+
|
25
|
+
from monai.handlers import MetricLogger
|
26
|
+
except ImportError:
|
27
|
+
has_ignite = False
|
28
|
+
|
29
|
+
|
30
|
+
class TestHandlerMetricLogger(unittest.TestCase):
    """Check that ``MetricLogger`` records the per-iteration loss and the per-iteration metric values."""

    @SkipIfNoModule("ignite")
    def test_metric_logging(self):
        metric_key = "dummy"

        def _train_func(engine, batch):
            # every iteration reports a constant zero loss
            return torch.tensor(0.0)

        engine = Engine(_train_func)

        @engine.on(Events.EPOCH_COMPLETED)
        def _update_metric(engine):
            # the metric is only written at the end of an epoch, so (per the expectations
            # below) it is first seen by the logger during the second epoch
            engine.state.metrics[metric_key] = 1

        # ``.item()`` turns the scalar tensor returned by ``_train_func`` into a float
        handler = MetricLogger(loss_transform=lambda output: output.item())
        handler.attach(engine)

        # 3 iterations per epoch x 2 epochs -> global iterations 1..6
        engine.run(range(3), max_epochs=2)

        expected_loss = [(iteration, 0.0) for iteration in range(1, 7)]
        expected_metric = [(iteration, 1) for iteration in range(4, 7)]

        self.assertSetEqual(set(handler.metrics), {metric_key})
        self.assertListEqual(handler.loss, expected_loss)
        self.assertListEqual(handler.metrics[metric_key], expected_metric)


if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,43 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
from parameterized import parameterized
|
17
|
+
|
18
|
+
from monai.utils import list_to_dict
|
19
|
+
|
20
|
+
# Each case is [list of "key=value" strings, expected parsed dict].
TEST_CASE_1 = [["a=1", "b=2", "c=3", "d=4"], dict(a=1, b=2, c=3, d=4)]

TEST_CASE_2 = [["a=a", "b=b", "c=c", "d=d"], dict(a="a", b="b", c="c", d="d")]

TEST_CASE_3 = [["a=0.1", "b=0.2", "c=0.3", "d=0.4"], dict(a=0.1, b=0.2, c=0.3, d=0.4)]

TEST_CASE_4 = [["a=True", "b=TRUE", "c=false", "d=FALSE"], dict(a=True, b=True, c=False, d=False)]

# quoting, surrounding whitespace, a bare key ("f") and an explicit "None" value
TEST_CASE_5 = [
    ["a='1'", "b=2 ", " c = 3", "d='test'", "'e'=0", "f", "g=None"],
    dict(a=1, b=2, c=3, d="test", e=0, f=None, g=None),
]


class TestListToDict(unittest.TestCase):
    """Check that ``list_to_dict`` parses ``key=value`` strings into typed dict entries."""

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
    def test_value_shape(self, input, output):
        self.assertDictEqual(list_to_dict(input), output)


if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,87 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
from enum import Enum
|
16
|
+
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.utils import StrEnum, look_up_option
|
20
|
+
|
21
|
+
|
22
|
+
class _CaseEnum(Enum):
    """Plain ``Enum`` fixture whose members are looked up by member or by value."""

    CONST = "constant"
    EMPTY = "empty"


class _CaseEnum1(Enum):
    """Fixture with the same member names/values as ``_CaseEnum`` but a distinct type.

    Used to check that a member of a different enum class is rejected by ``look_up_option``.
    """

    CONST = "constant"
    EMPTY = "empty"


class _CaseStrEnum(StrEnum):
    """``StrEnum`` fixture; the tests expect its members to compare equal to their string values."""

    MODE_A = "A"
    MODE_B = "B"


# (input, supported options, expected output) triples consumed by ``test_look_up``.
TEST_CASES = (
    ("test", ("test", "test1"), "test"),
    ("test1", {"test1", "test"}, "test1"),
    # dict keys are matched and the corresponding value is returned
    (2, {1: "test", 2: "valid"}, "valid"),
    (_CaseEnum.EMPTY, _CaseEnum, _CaseEnum.EMPTY),
    # a raw value is mapped to its enum member
    ("empty", _CaseEnum, _CaseEnum.EMPTY),
)
|
44
|
+
|
45
|
+
|
46
|
+
class TestLookUpOption(unittest.TestCase):
    """Tests for ``look_up_option`` over sequences, sets, dicts and enum types."""

    @parameterized.expand(TEST_CASES)
    def test_look_up(self, input_str, supported, expected):
        output = look_up_option(input_str, supported)
        self.assertEqual(output, expected)

    def test_default(self):
        # an unsupported value falls back to ``default`` instead of raising
        output = look_up_option("not here", {"a", "b"}, default=None)
        self.assertIsNone(output)

    def test_str_enum(self):
        output = look_up_option("C", {"A", "B"}, default=None)
        self.assertIsNone(output)
        self.assertEqual(list(_CaseStrEnum), ["A", "B"])
        self.assertEqual(_CaseStrEnum.MODE_A, "A")
        self.assertEqual(str(_CaseStrEnum.MODE_A), "A")
        self.assertEqual(look_up_option("A", _CaseStrEnum), "A")

    def test_no_found(self):
        with self.assertRaisesRegex(ValueError, "Unsupported"):
            look_up_option("not here", {"a", "b"})
        with self.assertRaisesRegex(ValueError, "Unsupported"):
            look_up_option("not here", ["a", "b"])
        with self.assertRaisesRegex(ValueError, "Unsupported"):
            look_up_option("not here", {"a": 1, "b": 2})
        # the error message should suggest a close match
        with self.assertRaisesRegex(ValueError, "did you mean"):
            look_up_option(3, {1: "a", 2: "b", "c": 3})
        with self.assertRaisesRegex(ValueError, "did.*empty"):
            look_up_option("empy", _CaseEnum)
        # a same-named member of a different enum class is not accepted
        with self.assertRaisesRegex(ValueError, "Unsupported"):
            look_up_option(_CaseEnum1.EMPTY, _CaseEnum)
        with self.assertRaisesRegex(ValueError, "Unsupported"):
            look_up_option(None, _CaseEnum)
        # no supported options given at all
        with self.assertRaisesRegex(ValueError, "No"):
            look_up_option(None, None)
        with self.assertRaisesRegex(ValueError, "No"):
            look_up_option("test", None)


if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,80 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
from parameterized import parameterized
|
17
|
+
|
18
|
+
from monai.utils import OptionalImportError, exact_version, optional_import
|
19
|
+
|
20
|
+
|
21
|
+
class TestOptionalImport(unittest.TestCase):
    """Tests for ``optional_import``: missing modules, version checks and named attribute imports."""

    @parameterized.expand(["not_a_module", "torch.randint"])
    def test_default(self, import_module):
        # neither string names an importable module, so a stub is returned that raises on access
        my_module, flag = optional_import(import_module)
        self.assertFalse(flag)
        with self.assertRaises(OptionalImportError):
            my_module.test

    def test_import_valid(self):
        my_module, flag = optional_import("torch")
        self.assertTrue(flag)
        # randint with low=1, high=2 (exclusive) can only produce ones
        self.assertEqual(my_module.randint(1, 2, (1, 2)).tolist(), [[1, 1]])

    def test_import_wrong_number(self):
        # version "42" is not satisfied by the installed torch, so the stub raises on any access
        my_module, flag = optional_import("torch", "42")
        with self.assertRaisesRegex(OptionalImportError, "version"):
            my_module.nn
        self.assertFalse(flag)
        with self.assertRaisesRegex(OptionalImportError, "version"):
            my_module.randint(1, 2, (1, 2))
        # A non-numeric version string makes the version check fail with ValueError.
        # Statements after the raising one inside this block never run, so no assertions
        # are placed here.
        with self.assertRaisesRegex(ValueError, "invalid literal"):
            my_module, flag = optional_import("torch", "test")  # version should be number.number
            my_module.nn

    @parameterized.expand(["0", "0.0.0.1", "1.1.0"])
    def test_import_good_number(self, version_number):
        my_module, flag = optional_import("torch", version_number)
        my_module.nn
        self.assertTrue(flag)
        self.assertEqual(my_module.randint(1, 2, (1, 2)).tolist(), [[1, 1]])

    def test_import_exact(self):
        my_module, flag = optional_import("torch", "0", exact_version)
        with self.assertRaisesRegex(OptionalImportError, "exact_version"):
            my_module.nn
        self.assertFalse(flag)
        with self.assertRaisesRegex(OptionalImportError, "exact_version"):
            my_module.randint(1, 2, (1, 2))

    def test_import_method(self):
        # ``name`` imports an attribute of the module (here ``torch.nn``)
        nn, flag = optional_import("torch", "1.1", name="nn")
        self.assertTrue(flag)
        self.assertIsNotNone(nn.functional)

    def test_additional(self):
        test_args = {"a": "test", "b": "test"}

        def versioning(module, ver, a):
            # ``version_args`` must be forwarded unchanged to a custom version checker
            self.assertEqual(a, test_args)
            return True

        nn, flag = optional_import("torch", "1.1", version_checker=versioning, name="nn", version_args=test_args)
        self.assertTrue(flag)


if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,39 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
|
19
|
+
from monai.transforms import CastToType, Pad
|
20
|
+
from monai.utils import NumpyPadMode, PytorchPadMode
|
21
|
+
from tests.test_utils import SkipIfBeforePyTorchVersion
|
22
|
+
|
23
|
+
|
24
|
+
@SkipIfBeforePyTorchVersion((1, 10, 1))
class TestPadMode(unittest.TestCase):
    """Check that every PyTorch and NumPy pad mode pads correctly across dtypes, devices and ranks."""

    def test_pad(self):
        expected_shapes = {3: (1, 15, 10), 4: (1, 10, 6, 7)}
        devices = ("cuda:0", "cpu") if torch.cuda.is_available() else ("cpu",)
        modes = list(PytorchPadMode) + list(NumpyPadMode)
        for t in (float, int, np.uint8, np.int16, np.float32, bool):
            for d in devices:
                for s in ((1, 10, 10), (1, 5, 6, 7)):
                    # pad only dim 1 by (2, 3): 10 -> 15 for the 3D case, 5 -> 10 for the 4D case
                    to_pad = [(0, 0), (2, 3)] if len(s) == 3 else [(0, 0), (2, 3), (0, 0), (0, 0)]
                    # the padded input does not depend on the mode, so build it once per shape
                    img = CastToType(dtype=t)(torch.rand(s)).to(d)
                    for m in modes:
                        # subTest reports exactly which combination failed
                        with self.subTest(dtype=t, device=d, shape=s, mode=m):
                            out = Pad(to_pad=to_pad, mode=m)(img)
                            self.assertEqual(out.shape, expected_shapes[len(s)])


if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,208 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import datetime
|
15
|
+
import os
|
16
|
+
import unittest
|
17
|
+
from io import StringIO
|
18
|
+
|
19
|
+
import torch
|
20
|
+
|
21
|
+
import monai.transforms as mt
|
22
|
+
from monai.data import Dataset, ThreadDataLoader
|
23
|
+
from monai.utils import first, optional_import
|
24
|
+
from monai.utils.enums import CommonKeys
|
25
|
+
from monai.utils.profiling import ProfileHandler, ProfileResult, WorkflowProfiler
|
26
|
+
from tests.test_utils import SkipIfNoModule
|
27
|
+
|
28
|
+
pd, _ = optional_import("pandas")
|
29
|
+
|
30
|
+
|
31
|
+
class TestWorkflowProfiler(unittest.TestCase):
    """Tests for ``WorkflowProfiler`` result collection, summaries and export."""

    def setUp(self):
        super().setUp()

        self.scale = mt.ScaleIntensity()
        self.scale_call_name = "ScaleIntensity.__call__"
        self.compose_call_name = "Compose.__call__"
        self.test_comp = mt.Compose([mt.ScaleIntensity(), mt.RandAxisFlip(0.5)])
        self.test_image = torch.rand(1, 16, 16, 16)
        self.pid = os.getpid()

    def test_empty(self):
        """Test that the profiler correctly produces an empty result when nothing happens in a context."""
        wp = WorkflowProfiler()

        with wp:
            pass

        self.assertEqual(wp.get_results(), {})

    def test_profile_transforms(self):
        """Test basic reporting when invoking a single transform directly."""
        with WorkflowProfiler() as wp:
            self.scale(self.test_image)

        results = wp.get_results()
        self.assertSequenceEqual(list(results), [self.scale_call_name])

        prs = results[self.scale_call_name]

        self.assertEqual(len(prs), 1)

        pr = prs[0]

        self.assertIsInstance(pr, ProfileResult)
        self.assertEqual(pr.name, self.scale_call_name)
        self.assertEqual(pr.pid, self.pid)
        self.assertGreater(pr.time, 0)

        dt = datetime.datetime.fromisoformat(pr.timestamp)

        self.assertIsInstance(dt, datetime.datetime)

    def test_profile_multithread(self):
        """Test results are gathered from multiple threads using ThreadDataLoader."""
        ds = Dataset([self.test_image] * 4, self.scale)
        dl = ThreadDataLoader(ds, batch_size=4, num_workers=4, use_thread_workers=True)

        with WorkflowProfiler() as wp:
            batch = first(dl)

        self.assertSequenceEqual(batch.shape, (4, 1, 16, 16, 16))

        results = wp.get_results()
        self.assertSequenceEqual(list(results), [self.scale_call_name, self.compose_call_name])

        prs = results[self.scale_call_name]

        self.assertEqual(len(prs), 4)

    def test_profile_context(self):
        """Test results from profiling contexts with the same name accumulate correctly."""
        with WorkflowProfiler() as wp:
            with wp.profile_ctx("context"):
                self.scale(self.test_image)

            with wp.profile_ctx("context"):
                self.scale(self.test_image)

        results = wp.get_results()

        # result keys are unordered here, so compare as sets
        self.assertSetEqual(set(results), {self.scale_call_name, "context"})

        prs = results["context"]

        self.assertEqual(len(prs), 2)

    def test_profile_callable(self):
        """Test profiling functions with default or set names."""

        def funca():
            pass

        with WorkflowProfiler() as wp:
            funca = wp.profile_callable()(funca)

            funca()

            @wp.profile_callable("funcb")
            def _func():
                pass

            _func()
            _func()

        results = wp.get_results()
        self.assertSetEqual(set(results), {"funca", "funcb"})

        self.assertEqual(len(results["funca"]), 1)
        self.assertEqual(len(results["funcb"]), 2)

    def test_profile_iteration(self):
        """Test iterables are profiled correctly, producing the right output and number of results."""
        with WorkflowProfiler() as wp:
            range_vals = list(wp.profile_iter("range5", range(5)))

            self.assertSequenceEqual(range_vals, list(range(5)))

        results = wp.get_results()
        self.assertSetEqual(set(results), {"range5"})

        # one result per yielded item
        self.assertEqual(len(results["range5"]), 5)

    def test_times_summary(self):
        """Test generating the summary report dictionary."""
        with WorkflowProfiler() as wp:
            self.scale(self.test_image)

        tsum = wp.get_times_summary()

        self.assertSequenceEqual(list(tsum), [self.scale_call_name])

        times = tsum[self.scale_call_name]

        self.assertEqual(len(times), 6)
        # the first summary entry is expected to be the call count
        self.assertEqual(times[0], 1)

    @SkipIfNoModule("pandas")
    def test_times_summary_pd(self):
        """Test generating the Pandas result works if Pandas is present."""
        with WorkflowProfiler() as wp:
            self.scale(self.test_image)

        df = wp.get_times_summary_pd()

        self.assertIsInstance(df, pd.DataFrame)

    def test_csv_dump(self):
        """Test dumping the results to csv file in a local StringIO object."""
        with WorkflowProfiler() as wp:
            self.scale(self.test_image)

        sio = StringIO()
        wp.dump_csv(sio)
        # something must have been written to the buffer
        self.assertGreater(sio.tell(), 0)

    @SkipIfNoModule("ignite")
    def test_handler(self):
        """Test profiling Engine objects works if Ignite is present."""
        from ignite.engine import Events

        from monai.engines import SupervisedTrainer

        net = torch.nn.Conv2d(1, 1, 3, padding=1)
        im = torch.rand(1, 1, 16, 16)

        with WorkflowProfiler(None) as wp:
            trainer = SupervisedTrainer(
                device=torch.device("cpu"),
                max_epochs=2,
                train_data_loader=[{CommonKeys.IMAGE: im, CommonKeys.LABEL: im}] * 3,
                epoch_length=3,
                network=net,
                optimizer=torch.optim.Adam(net.parameters()),
                loss_function=torch.nn.L1Loss(),
            )

            _ = ProfileHandler("Epoch", wp, Events.EPOCH_STARTED, Events.EPOCH_COMPLETED).attach(trainer)

            trainer.run()

        results = wp.get_results()

        self.assertSetEqual(set(results), {"Epoch"})
        # one profiled interval per epoch (max_epochs=2)
        self.assertEqual(len(results["Epoch"]), 2)


if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,77 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import logging
|
15
|
+
import os
|
16
|
+
import tempfile
|
17
|
+
import unittest
|
18
|
+
|
19
|
+
import torch.distributed as dist
|
20
|
+
|
21
|
+
from monai.utils import RankFilter
|
22
|
+
from tests.test_utils import DistCall, DistTestCase
|
23
|
+
|
24
|
+
|
25
|
+
class DistributedRankFilterTest(DistTestCase):
    """Check that ``RankFilter`` lets exactly one of two ranks write a log record."""

    def setUp(self):
        self.log_dir = tempfile.TemporaryDirectory()

    @DistCall(nnodes=1, nproc_per_node=2)
    def test_rankfilter(self):
        logger = logging.getLogger(__name__)
        log_filename = os.path.join(self.log_dir.name, "records.log")
        h1 = logging.FileHandler(filename=log_filename)
        h1.setLevel(logging.WARNING)
        rank_filter = RankFilter()

        logger.addHandler(h1)
        logger.addFilter(rank_filter)
        # The handler/filter are attached to a shared, long-lived logger: always detach them and
        # close the file so nothing leaks into other tests or blocks temp-dir cleanup.
        # try/finally is used (not addCleanup) because this body runs in spawned worker processes.
        try:
            logger.warning("test_warnings")
            # wait until every rank has logged before anyone inspects the file
            dist.barrier()
        finally:
            logger.removeFilter(rank_filter)
            logger.removeHandler(h1)
            h1.close()

        if dist.get_rank() == 0:
            with open(log_filename) as file:
                lines = [line.rstrip() for line in file]
            log_message = " ".join(lines)
            self.assertEqual(log_message.count("test_warnings"), 1)

    def tearDown(self) -> None:
        self.log_dir.cleanup()
|
50
|
+
|
51
|
+
|
52
|
+
class SingleRankFilterTest(unittest.TestCase):
    """Check that ``RankFilter`` does not drop records in a single-process (non-distributed) run."""

    def setUp(self):
        self.log_dir = tempfile.TemporaryDirectory()

    def tearDown(self) -> None:
        self.log_dir.cleanup()

    def test_rankfilter_single_proc(self):
        logger = logging.getLogger(__name__)
        log_filename = os.path.join(self.log_dir.name, "records_sp.log")
        h1 = logging.FileHandler(filename=log_filename)
        h1.setLevel(logging.WARNING)
        rank_filter = RankFilter()

        logger.addHandler(h1)
        logger.addFilter(rank_filter)
        # Detach the handler AND the filter from the shared logger even if logging fails, and close
        # the file so the temporary directory can be removed in tearDown.
        try:
            logger.warning("test_warnings")
        finally:
            logger.removeFilter(rank_filter)
            logger.removeHandler(h1)
            h1.close()

        with open(log_filename) as file:
            lines = [line.rstrip() for line in file]
        log_message = " ".join(lines)
        self.assertEqual(log_message.count("test_warnings"), 1)


if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,83 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
from monai.utils import OptionalImportError, min_version, require_pkg
|
17
|
+
|
18
|
+
|
19
|
+
class TestRequirePkg(unittest.TestCase):
    """Tests for the ``require_pkg`` decorator applied to classes and functions."""

    def test_class(self):
        # Requirement expected to be satisfied (torch >= 1.4 assumed installed):
        # construction must succeed and yield a normal instance.
        @require_pkg(pkg_name="torch", version="1.4", version_checker=min_version)
        class TestClass:
            pass

        self.assertIsInstance(TestClass(), TestClass)

    def test_function(self):
        @require_pkg(pkg_name="torch", version="1.4", version_checker=min_version)
        def test_func(x):
            return x

        # The decorated function must still run and forward its return value unchanged.
        sentinel = object()
        self.assertIs(test_func(x=sentinel), sentinel)

    def test_warning(self):
        # "test123" is a deliberately non-existent package; with raise_error=False the
        # decorator is expected to warn rather than raise, and still call the function.
        @require_pkg(pkg_name="test123", raise_error=False)
        def test_func(x):
            return x

        sentinel = object()
        with self.assertWarns(Warning):
            result = test_func(x=sentinel)
        self.assertIs(result, sentinel)

    def test_class_exception(self):
        # Missing package with the default raise_error: instantiation must raise.
        with self.assertRaises(OptionalImportError):

            @require_pkg(pkg_name="test123")
            class TestClass:
                pass

            TestClass()

    def test_class_version_exception(self):
        # Installed package but unsatisfiable minimum version: instantiation must raise.
        with self.assertRaises(OptionalImportError):

            @require_pkg(pkg_name="torch", version="10000", version_checker=min_version)
            class TestClass:
                pass

            TestClass()

    def test_func_exception(self):
        # Missing package with the default raise_error: calling the function must raise.
        with self.assertRaises(OptionalImportError):

            @require_pkg(pkg_name="test123")
            def test_func(x):
                return x

            test_func(x=None)

    def test_func_versions_exception(self):
        # Installed package but unsatisfiable minimum version: calling the function must raise.
        with self.assertRaises(OptionalImportError):

            @require_pkg(pkg_name="torch", version="10000", version_checker=min_version)
            def test_func(x):
                return x

            test_func(x=None)


if __name__ == "__main__":
    unittest.main()
|