monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/transforms.py +1 -4
- monai/data/utils.py +6 -13
- monai/inferers/utils.py +1 -2
- monai/losses/dice.py +2 -14
- monai/losses/ds_loss.py +1 -3
- monai/networks/layers/simplelayers.py +2 -14
- monai/networks/utils.py +4 -16
- monai/transforms/compose.py +28 -11
- monai/transforms/croppad/array.py +1 -6
- monai/transforms/io/array.py +0 -1
- monai/transforms/transform.py +15 -6
- monai/transforms/utils.py +1 -2
- monai/utils/tf32.py +0 -10
- monai/visualize/class_activation_maps.py +5 -8
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/METADATA +2 -2
- monai_weekly-1.5.dev2507.dist-info/RECORD +1181 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/top_level.txt +1 -0
- tests/apps/__init__.py +10 -0
- tests/apps/deepedit/__init__.py +10 -0
- tests/apps/deepedit/test_deepedit_transforms.py +314 -0
- tests/apps/deepgrow/__init__.py +10 -0
- tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
- tests/apps/deepgrow/transforms/__init__.py +10 -0
- tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
- tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
- tests/apps/detection/__init__.py +10 -0
- tests/apps/detection/metrics/__init__.py +10 -0
- tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
- tests/apps/detection/networks/__init__.py +10 -0
- tests/apps/detection/networks/test_retinanet.py +210 -0
- tests/apps/detection/networks/test_retinanet_detector.py +203 -0
- tests/apps/detection/test_box_transform.py +370 -0
- tests/apps/detection/utils/__init__.py +10 -0
- tests/apps/detection/utils/test_anchor_box.py +88 -0
- tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
- tests/apps/detection/utils/test_box_coder.py +43 -0
- tests/apps/detection/utils/test_detector_boxselector.py +67 -0
- tests/apps/detection/utils/test_detector_utils.py +96 -0
- tests/apps/detection/utils/test_hardnegsampler.py +54 -0
- tests/apps/nuclick/__init__.py +10 -0
- tests/apps/nuclick/test_nuclick_transforms.py +259 -0
- tests/apps/pathology/__init__.py +10 -0
- tests/apps/pathology/handlers/__init__.py +10 -0
- tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
- tests/apps/pathology/test_lesion_froc.py +333 -0
- tests/apps/pathology/test_pathology_prob_nms.py +55 -0
- tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
- tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
- tests/apps/pathology/transforms/__init__.py +10 -0
- tests/apps/pathology/transforms/post/__init__.py +10 -0
- tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
- tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
- tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
- tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
- tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
- tests/apps/pathology/transforms/post/test_watershed.py +60 -0
- tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
- tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
- tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
- tests/apps/reconstruction/__init__.py +10 -0
- tests/apps/reconstruction/nets/__init__.py +10 -0
- tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
- tests/apps/reconstruction/test_complex_utils.py +77 -0
- tests/apps/reconstruction/test_fastmri_reader.py +82 -0
- tests/apps/reconstruction/test_mri_utils.py +37 -0
- tests/apps/reconstruction/transforms/__init__.py +10 -0
- tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
- tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
- tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
- tests/apps/test_auto3dseg_bundlegen.py +156 -0
- tests/apps/test_check_hash.py +53 -0
- tests/apps/test_cross_validation.py +74 -0
- tests/apps/test_decathlondataset.py +93 -0
- tests/apps/test_download_and_extract.py +70 -0
- tests/apps/test_download_url_yandex.py +45 -0
- tests/apps/test_mednistdataset.py +72 -0
- tests/apps/test_mmar_download.py +154 -0
- tests/apps/test_tciadataset.py +123 -0
- tests/apps/vista3d/__init__.py +10 -0
- tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
- tests/apps/vista3d/test_vista3d_sampler.py +100 -0
- tests/apps/vista3d/test_vista3d_transforms.py +94 -0
- tests/bundle/__init__.py +10 -0
- tests/bundle/test_bundle_ckpt_export.py +107 -0
- tests/bundle/test_bundle_download.py +435 -0
- tests/bundle/test_bundle_get_data.py +94 -0
- tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
- tests/bundle/test_bundle_trt_export.py +147 -0
- tests/bundle/test_bundle_utils.py +149 -0
- tests/bundle/test_bundle_verify_metadata.py +66 -0
- tests/bundle/test_bundle_verify_net.py +76 -0
- tests/bundle/test_bundle_workflow.py +272 -0
- tests/bundle/test_component_locator.py +38 -0
- tests/bundle/test_config_item.py +138 -0
- tests/bundle/test_config_parser.py +392 -0
- tests/bundle/test_reference_resolver.py +114 -0
- tests/config/__init__.py +10 -0
- tests/config/test_cv2_dist.py +53 -0
- tests/engines/__init__.py +10 -0
- tests/engines/test_ensemble_evaluator.py +94 -0
- tests/engines/test_prepare_batch_default.py +76 -0
- tests/engines/test_prepare_batch_default_dist.py +76 -0
- tests/engines/test_prepare_batch_diffusion.py +104 -0
- tests/engines/test_prepare_batch_extra_input.py +80 -0
- tests/fl/__init__.py +10 -0
- tests/fl/monai_algo/__init__.py +10 -0
- tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
- tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
- tests/fl/test_fl_monai_algo_stats.py +81 -0
- tests/fl/utils/__init__.py +10 -0
- tests/fl/utils/test_fl_exchange_object.py +63 -0
- tests/handlers/__init__.py +10 -0
- tests/handlers/test_handler_checkpoint_loader.py +182 -0
- tests/handlers/test_handler_checkpoint_saver.py +233 -0
- tests/handlers/test_handler_classification_saver.py +64 -0
- tests/handlers/test_handler_classification_saver_dist.py +77 -0
- tests/handlers/test_handler_clearml_image.py +65 -0
- tests/handlers/test_handler_clearml_stats.py +65 -0
- tests/handlers/test_handler_confusion_matrix.py +104 -0
- tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
- tests/handlers/test_handler_decollate_batch.py +66 -0
- tests/handlers/test_handler_early_stop.py +68 -0
- tests/handlers/test_handler_garbage_collector.py +73 -0
- tests/handlers/test_handler_hausdorff_distance.py +111 -0
- tests/handlers/test_handler_ignite_metric.py +191 -0
- tests/handlers/test_handler_lr_scheduler.py +94 -0
- tests/handlers/test_handler_mean_dice.py +98 -0
- tests/handlers/test_handler_mean_iou.py +76 -0
- tests/handlers/test_handler_metrics_reloaded.py +149 -0
- tests/handlers/test_handler_metrics_saver.py +89 -0
- tests/handlers/test_handler_metrics_saver_dist.py +120 -0
- tests/handlers/test_handler_mlflow.py +296 -0
- tests/handlers/test_handler_nvtx.py +93 -0
- tests/handlers/test_handler_panoptic_quality.py +89 -0
- tests/handlers/test_handler_parameter_scheduler.py +136 -0
- tests/handlers/test_handler_post_processing.py +74 -0
- tests/handlers/test_handler_prob_map_producer.py +111 -0
- tests/handlers/test_handler_regression_metrics.py +160 -0
- tests/handlers/test_handler_regression_metrics_dist.py +245 -0
- tests/handlers/test_handler_rocauc.py +48 -0
- tests/handlers/test_handler_rocauc_dist.py +54 -0
- tests/handlers/test_handler_stats.py +281 -0
- tests/handlers/test_handler_surface_distance.py +113 -0
- tests/handlers/test_handler_tb_image.py +61 -0
- tests/handlers/test_handler_tb_stats.py +166 -0
- tests/handlers/test_handler_validation.py +59 -0
- tests/handlers/test_trt_compile.py +145 -0
- tests/handlers/test_write_metrics_reports.py +68 -0
- tests/inferers/__init__.py +10 -0
- tests/inferers/test_avg_merger.py +179 -0
- tests/inferers/test_controlnet_inferers.py +1310 -0
- tests/inferers/test_diffusion_inferer.py +236 -0
- tests/inferers/test_latent_diffusion_inferer.py +824 -0
- tests/inferers/test_patch_inferer.py +309 -0
- tests/inferers/test_saliency_inferer.py +55 -0
- tests/inferers/test_slice_inferer.py +57 -0
- tests/inferers/test_sliding_window_inference.py +377 -0
- tests/inferers/test_sliding_window_splitter.py +284 -0
- tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
- tests/inferers/test_zarr_avg_merger.py +326 -0
- tests/integration/__init__.py +10 -0
- tests/integration/test_auto3dseg_ensemble.py +211 -0
- tests/integration/test_auto3dseg_hpo.py +189 -0
- tests/integration/test_deepedit_interaction.py +122 -0
- tests/integration/test_downsample_block.py +50 -0
- tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
- tests/integration/test_integration_autorunner.py +201 -0
- tests/integration/test_integration_bundle_run.py +240 -0
- tests/integration/test_integration_classification_2d.py +282 -0
- tests/integration/test_integration_determinism.py +95 -0
- tests/integration/test_integration_fast_train.py +231 -0
- tests/integration/test_integration_gpu_customization.py +159 -0
- tests/integration/test_integration_lazy_samples.py +219 -0
- tests/integration/test_integration_nnunetv2_runner.py +96 -0
- tests/integration/test_integration_segmentation_3d.py +304 -0
- tests/integration/test_integration_sliding_window.py +100 -0
- tests/integration/test_integration_stn.py +133 -0
- tests/integration/test_integration_unet_2d.py +67 -0
- tests/integration/test_integration_workers.py +61 -0
- tests/integration/test_integration_workflows.py +365 -0
- tests/integration/test_integration_workflows_adversarial.py +173 -0
- tests/integration/test_integration_workflows_gan.py +158 -0
- tests/integration/test_loader_semaphore.py +48 -0
- tests/integration/test_mapping_filed.py +122 -0
- tests/integration/test_meta_affine.py +183 -0
- tests/integration/test_metatensor_integration.py +114 -0
- tests/integration/test_module_list.py +76 -0
- tests/integration/test_one_of.py +283 -0
- tests/integration/test_pad_collation.py +124 -0
- tests/integration/test_reg_loss_integration.py +107 -0
- tests/integration/test_retinanet_predict_utils.py +154 -0
- tests/integration/test_seg_loss_integration.py +159 -0
- tests/integration/test_spatial_combine_transforms.py +185 -0
- tests/integration/test_testtimeaugmentation.py +186 -0
- tests/integration/test_vis_gradbased.py +69 -0
- tests/integration/test_vista3d_utils.py +159 -0
- tests/losses/__init__.py +10 -0
- tests/losses/deform/__init__.py +10 -0
- tests/losses/deform/test_bending_energy.py +88 -0
- tests/losses/deform/test_diffusion_loss.py +117 -0
- tests/losses/image_dissimilarity/__init__.py +10 -0
- tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
- tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
- tests/losses/test_adversarial_loss.py +94 -0
- tests/losses/test_barlow_twins_loss.py +109 -0
- tests/losses/test_cldice_loss.py +51 -0
- tests/losses/test_contrastive_loss.py +86 -0
- tests/losses/test_dice_ce_loss.py +123 -0
- tests/losses/test_dice_focal_loss.py +124 -0
- tests/losses/test_dice_loss.py +227 -0
- tests/losses/test_ds_loss.py +189 -0
- tests/losses/test_focal_loss.py +379 -0
- tests/losses/test_generalized_dice_focal_loss.py +85 -0
- tests/losses/test_generalized_dice_loss.py +221 -0
- tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
- tests/losses/test_giou_loss.py +62 -0
- tests/losses/test_hausdorff_loss.py +264 -0
- tests/losses/test_masked_dice_loss.py +152 -0
- tests/losses/test_masked_loss.py +87 -0
- tests/losses/test_multi_scale.py +86 -0
- tests/losses/test_nacl_loss.py +167 -0
- tests/losses/test_perceptual_loss.py +122 -0
- tests/losses/test_spectral_loss.py +86 -0
- tests/losses/test_ssim_loss.py +59 -0
- tests/losses/test_sure_loss.py +72 -0
- tests/losses/test_tversky_loss.py +198 -0
- tests/losses/test_unified_focal_loss.py +66 -0
- tests/metrics/__init__.py +10 -0
- tests/metrics/test_compute_confusion_matrix.py +294 -0
- tests/metrics/test_compute_f_beta.py +80 -0
- tests/metrics/test_compute_fid_metric.py +40 -0
- tests/metrics/test_compute_froc.py +143 -0
- tests/metrics/test_compute_generalized_dice.py +240 -0
- tests/metrics/test_compute_meandice.py +306 -0
- tests/metrics/test_compute_meaniou.py +223 -0
- tests/metrics/test_compute_mmd_metric.py +56 -0
- tests/metrics/test_compute_multiscalessim_metric.py +83 -0
- tests/metrics/test_compute_panoptic_quality.py +113 -0
- tests/metrics/test_compute_regression_metrics.py +196 -0
- tests/metrics/test_compute_roc_auc.py +155 -0
- tests/metrics/test_compute_variance.py +147 -0
- tests/metrics/test_cumulative.py +63 -0
- tests/metrics/test_cumulative_average.py +74 -0
- tests/metrics/test_cumulative_average_dist.py +48 -0
- tests/metrics/test_hausdorff_distance.py +209 -0
- tests/metrics/test_label_quality_score.py +134 -0
- tests/metrics/test_loss_metric.py +57 -0
- tests/metrics/test_metrics_reloaded.py +96 -0
- tests/metrics/test_ssim_metric.py +78 -0
- tests/metrics/test_surface_dice.py +416 -0
- tests/metrics/test_surface_distance.py +186 -0
- tests/networks/__init__.py +10 -0
- tests/networks/blocks/__init__.py +10 -0
- tests/networks/blocks/dints_block/__init__.py +10 -0
- tests/networks/blocks/dints_block/test_acn_block.py +41 -0
- tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
- tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
- tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
- tests/networks/blocks/test_adn.py +86 -0
- tests/networks/blocks/test_convolutions.py +156 -0
- tests/networks/blocks/test_crf_cpu.py +513 -0
- tests/networks/blocks/test_crf_cuda.py +528 -0
- tests/networks/blocks/test_crossattention.py +185 -0
- tests/networks/blocks/test_denseblock.py +105 -0
- tests/networks/blocks/test_dynunet_block.py +116 -0
- tests/networks/blocks/test_fpn_block.py +88 -0
- tests/networks/blocks/test_localnet_block.py +121 -0
- tests/networks/blocks/test_mlp.py +78 -0
- tests/networks/blocks/test_patchembedding.py +212 -0
- tests/networks/blocks/test_regunet_block.py +103 -0
- tests/networks/blocks/test_se_block.py +85 -0
- tests/networks/blocks/test_se_blocks.py +78 -0
- tests/networks/blocks/test_segresnet_block.py +57 -0
- tests/networks/blocks/test_selfattention.py +232 -0
- tests/networks/blocks/test_simple_aspp.py +87 -0
- tests/networks/blocks/test_spatialattention.py +55 -0
- tests/networks/blocks/test_subpixel_upsample.py +87 -0
- tests/networks/blocks/test_text_encoding.py +49 -0
- tests/networks/blocks/test_transformerblock.py +90 -0
- tests/networks/blocks/test_unetr_block.py +158 -0
- tests/networks/blocks/test_upsample_block.py +134 -0
- tests/networks/blocks/warp/__init__.py +10 -0
- tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
- tests/networks/blocks/warp/test_warp.py +250 -0
- tests/networks/layers/__init__.py +10 -0
- tests/networks/layers/filtering/__init__.py +10 -0
- tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
- tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
- tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
- tests/networks/layers/filtering/test_phl_cpu.py +259 -0
- tests/networks/layers/filtering/test_phl_cuda.py +167 -0
- tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
- tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
- tests/networks/layers/test_affine_transform.py +385 -0
- tests/networks/layers/test_apply_filter.py +89 -0
- tests/networks/layers/test_channel_pad.py +51 -0
- tests/networks/layers/test_conjugate_gradient.py +56 -0
- tests/networks/layers/test_drop_path.py +46 -0
- tests/networks/layers/test_gaussian.py +317 -0
- tests/networks/layers/test_gaussian_filter.py +206 -0
- tests/networks/layers/test_get_layers.py +65 -0
- tests/networks/layers/test_gmm.py +314 -0
- tests/networks/layers/test_grid_pull.py +93 -0
- tests/networks/layers/test_hilbert_transform.py +131 -0
- tests/networks/layers/test_lltm.py +62 -0
- tests/networks/layers/test_median_filter.py +52 -0
- tests/networks/layers/test_polyval.py +55 -0
- tests/networks/layers/test_preset_filters.py +136 -0
- tests/networks/layers/test_savitzky_golay_filter.py +141 -0
- tests/networks/layers/test_separable_filter.py +87 -0
- tests/networks/layers/test_skip_connection.py +48 -0
- tests/networks/layers/test_vector_quantizer.py +89 -0
- tests/networks/layers/test_weight_init.py +50 -0
- tests/networks/nets/__init__.py +10 -0
- tests/networks/nets/dints/__init__.py +10 -0
- tests/networks/nets/dints/test_dints_cell.py +110 -0
- tests/networks/nets/dints/test_dints_mixop.py +84 -0
- tests/networks/nets/regunet/__init__.py +10 -0
- tests/networks/nets/regunet/test_localnet.py +86 -0
- tests/networks/nets/regunet/test_regunet.py +88 -0
- tests/networks/nets/test_ahnet.py +224 -0
- tests/networks/nets/test_attentionunet.py +88 -0
- tests/networks/nets/test_autoencoder.py +95 -0
- tests/networks/nets/test_autoencoderkl.py +337 -0
- tests/networks/nets/test_basic_unet.py +102 -0
- tests/networks/nets/test_basic_unetplusplus.py +109 -0
- tests/networks/nets/test_bundle_init_bundle.py +55 -0
- tests/networks/nets/test_cell_sam_wrapper.py +58 -0
- tests/networks/nets/test_controlnet.py +215 -0
- tests/networks/nets/test_daf3d.py +62 -0
- tests/networks/nets/test_densenet.py +121 -0
- tests/networks/nets/test_diffusion_model_unet.py +585 -0
- tests/networks/nets/test_dints_network.py +168 -0
- tests/networks/nets/test_discriminator.py +59 -0
- tests/networks/nets/test_dynunet.py +181 -0
- tests/networks/nets/test_efficientnet.py +400 -0
- tests/networks/nets/test_flexible_unet.py +341 -0
- tests/networks/nets/test_fullyconnectednet.py +69 -0
- tests/networks/nets/test_generator.py +59 -0
- tests/networks/nets/test_globalnet.py +103 -0
- tests/networks/nets/test_highresnet.py +67 -0
- tests/networks/nets/test_hovernet.py +218 -0
- tests/networks/nets/test_mednext.py +122 -0
- tests/networks/nets/test_milmodel.py +92 -0
- tests/networks/nets/test_net_adapter.py +68 -0
- tests/networks/nets/test_network_consistency.py +86 -0
- tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
- tests/networks/nets/test_quicknat.py +57 -0
- tests/networks/nets/test_resnet.py +340 -0
- tests/networks/nets/test_segresnet.py +120 -0
- tests/networks/nets/test_segresnet_ds.py +156 -0
- tests/networks/nets/test_senet.py +151 -0
- tests/networks/nets/test_spade_autoencoderkl.py +295 -0
- tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
- tests/networks/nets/test_spade_vaegan.py +140 -0
- tests/networks/nets/test_swin_unetr.py +139 -0
- tests/networks/nets/test_torchvision_fc_model.py +201 -0
- tests/networks/nets/test_transchex.py +84 -0
- tests/networks/nets/test_transformer.py +108 -0
- tests/networks/nets/test_unet.py +208 -0
- tests/networks/nets/test_unetr.py +137 -0
- tests/networks/nets/test_varautoencoder.py +127 -0
- tests/networks/nets/test_vista3d.py +84 -0
- tests/networks/nets/test_vit.py +139 -0
- tests/networks/nets/test_vitautoenc.py +112 -0
- tests/networks/nets/test_vnet.py +81 -0
- tests/networks/nets/test_voxelmorph.py +280 -0
- tests/networks/nets/test_vqvae.py +274 -0
- tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
- tests/networks/schedulers/__init__.py +10 -0
- tests/networks/schedulers/test_scheduler_ddim.py +83 -0
- tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
- tests/networks/schedulers/test_scheduler_pndm.py +108 -0
- tests/networks/test_bundle_onnx_export.py +71 -0
- tests/networks/test_convert_to_onnx.py +106 -0
- tests/networks/test_convert_to_torchscript.py +46 -0
- tests/networks/test_convert_to_trt.py +79 -0
- tests/networks/test_save_state.py +73 -0
- tests/networks/test_to_onehot.py +63 -0
- tests/networks/test_varnet.py +63 -0
- tests/networks/utils/__init__.py +10 -0
- tests/networks/utils/test_copy_model_state.py +187 -0
- tests/networks/utils/test_eval_mode.py +34 -0
- tests/networks/utils/test_freeze_layers.py +61 -0
- tests/networks/utils/test_replace_module.py +98 -0
- tests/networks/utils/test_train_mode.py +34 -0
- tests/optimizers/__init__.py +10 -0
- tests/optimizers/test_generate_param_groups.py +105 -0
- tests/optimizers/test_lr_finder.py +108 -0
- tests/optimizers/test_lr_scheduler.py +71 -0
- tests/optimizers/test_optim_novograd.py +100 -0
- tests/profile_subclass/__init__.py +10 -0
- tests/profile_subclass/cprofile_profiling.py +29 -0
- tests/profile_subclass/min_classes.py +30 -0
- tests/profile_subclass/profiling.py +73 -0
- tests/profile_subclass/pyspy_profiling.py +41 -0
- tests/transforms/__init__.py +10 -0
- tests/transforms/compose/__init__.py +10 -0
- tests/transforms/compose/test_compose.py +758 -0
- tests/transforms/compose/test_some_of.py +258 -0
- tests/transforms/croppad/__init__.py +10 -0
- tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
- tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
- tests/transforms/functional/__init__.py +10 -0
- tests/transforms/functional/test_apply.py +75 -0
- tests/transforms/functional/test_resample.py +50 -0
- tests/transforms/intensity/__init__.py +10 -0
- tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
- tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
- tests/transforms/intensity/test_foreground_mask.py +98 -0
- tests/transforms/intensity/test_foreground_maskd.py +106 -0
- tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
- tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
- tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
- tests/transforms/inverse/__init__.py +10 -0
- tests/transforms/inverse/test_inverse_array.py +76 -0
- tests/transforms/inverse/test_traceable_transform.py +59 -0
- tests/transforms/post/__init__.py +10 -0
- tests/transforms/post/test_label_filterd.py +78 -0
- tests/transforms/post/test_probnms.py +72 -0
- tests/transforms/post/test_probnmsd.py +79 -0
- tests/transforms/post/test_remove_small_objects.py +102 -0
- tests/transforms/spatial/__init__.py +10 -0
- tests/transforms/spatial/test_convert_box_points.py +119 -0
- tests/transforms/spatial/test_grid_patch.py +134 -0
- tests/transforms/spatial/test_grid_patchd.py +102 -0
- tests/transforms/spatial/test_rand_grid_patch.py +150 -0
- tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
- tests/transforms/spatial/test_spatial_resampled.py +124 -0
- tests/transforms/test_activations.py +120 -0
- tests/transforms/test_activationsd.py +64 -0
- tests/transforms/test_adaptors.py +160 -0
- tests/transforms/test_add_coordinate_channels.py +53 -0
- tests/transforms/test_add_coordinate_channelsd.py +67 -0
- tests/transforms/test_add_extreme_points_channel.py +80 -0
- tests/transforms/test_add_extreme_points_channeld.py +77 -0
- tests/transforms/test_adjust_contrast.py +70 -0
- tests/transforms/test_adjust_contrastd.py +64 -0
- tests/transforms/test_affine.py +245 -0
- tests/transforms/test_affine_grid.py +152 -0
- tests/transforms/test_affined.py +190 -0
- tests/transforms/test_as_channel_last.py +38 -0
- tests/transforms/test_as_channel_lastd.py +44 -0
- tests/transforms/test_as_discrete.py +81 -0
- tests/transforms/test_as_discreted.py +82 -0
- tests/transforms/test_border_pad.py +49 -0
- tests/transforms/test_border_padd.py +45 -0
- tests/transforms/test_bounding_rect.py +54 -0
- tests/transforms/test_bounding_rectd.py +53 -0
- tests/transforms/test_cast_to_type.py +63 -0
- tests/transforms/test_cast_to_typed.py +74 -0
- tests/transforms/test_center_scale_crop.py +55 -0
- tests/transforms/test_center_scale_cropd.py +56 -0
- tests/transforms/test_center_spatial_crop.py +56 -0
- tests/transforms/test_center_spatial_cropd.py +63 -0
- tests/transforms/test_classes_to_indices.py +93 -0
- tests/transforms/test_classes_to_indicesd.py +110 -0
- tests/transforms/test_clip_intensity_percentiles.py +196 -0
- tests/transforms/test_clip_intensity_percentilesd.py +193 -0
- tests/transforms/test_compose_get_number_conversions.py +127 -0
- tests/transforms/test_concat_itemsd.py +82 -0
- tests/transforms/test_convert_to_multi_channel.py +59 -0
- tests/transforms/test_convert_to_multi_channeld.py +37 -0
- tests/transforms/test_copy_itemsd.py +86 -0
- tests/transforms/test_create_grid_and_affine.py +274 -0
- tests/transforms/test_crop_foreground.py +164 -0
- tests/transforms/test_crop_foregroundd.py +205 -0
- tests/transforms/test_cucim_dict_transform.py +142 -0
- tests/transforms/test_cucim_transform.py +141 -0
- tests/transforms/test_data_stats.py +221 -0
- tests/transforms/test_data_statsd.py +249 -0
- tests/transforms/test_delete_itemsd.py +58 -0
- tests/transforms/test_detect_envelope.py +159 -0
- tests/transforms/test_distance_transform_edt.py +202 -0
- tests/transforms/test_divisible_pad.py +49 -0
- tests/transforms/test_divisible_padd.py +42 -0
- tests/transforms/test_ensure_channel_first.py +113 -0
- tests/transforms/test_ensure_channel_firstd.py +85 -0
- tests/transforms/test_ensure_type.py +94 -0
- tests/transforms/test_ensure_typed.py +110 -0
- tests/transforms/test_fg_bg_to_indices.py +83 -0
- tests/transforms/test_fg_bg_to_indicesd.py +78 -0
- tests/transforms/test_fill_holes.py +207 -0
- tests/transforms/test_fill_holesd.py +209 -0
- tests/transforms/test_flatten_sub_keysd.py +64 -0
- tests/transforms/test_flip.py +83 -0
- tests/transforms/test_flipd.py +90 -0
- tests/transforms/test_fourier.py +70 -0
- tests/transforms/test_gaussian_sharpen.py +92 -0
- tests/transforms/test_gaussian_sharpend.py +92 -0
- tests/transforms/test_gaussian_smooth.py +96 -0
- tests/transforms/test_gaussian_smoothd.py +96 -0
- tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
- tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
- tests/transforms/test_generate_spatial_bounding_box.py +114 -0
- tests/transforms/test_get_extreme_points.py +57 -0
- tests/transforms/test_gibbs_noise.py +75 -0
- tests/transforms/test_gibbs_noised.py +88 -0
- tests/transforms/test_grid_distortion.py +113 -0
- tests/transforms/test_grid_distortiond.py +87 -0
- tests/transforms/test_grid_split.py +88 -0
- tests/transforms/test_grid_splitd.py +96 -0
- tests/transforms/test_histogram_normalize.py +59 -0
- tests/transforms/test_histogram_normalized.py +59 -0
- tests/transforms/test_image_filter.py +259 -0
- tests/transforms/test_intensity_stats.py +73 -0
- tests/transforms/test_intensity_statsd.py +90 -0
- tests/transforms/test_inverse.py +521 -0
- tests/transforms/test_inverse_collation.py +147 -0
- tests/transforms/test_invert.py +105 -0
- tests/transforms/test_invertd.py +142 -0
- tests/transforms/test_k_space_spike_noise.py +81 -0
- tests/transforms/test_k_space_spike_noised.py +98 -0
- tests/transforms/test_keep_largest_connected_component.py +419 -0
- tests/transforms/test_keep_largest_connected_componentd.py +348 -0
- tests/transforms/test_label_filter.py +78 -0
- tests/transforms/test_label_to_contour.py +179 -0
- tests/transforms/test_label_to_contourd.py +182 -0
- tests/transforms/test_label_to_mask.py +69 -0
- tests/transforms/test_label_to_maskd.py +70 -0
- tests/transforms/test_load_image.py +502 -0
- tests/transforms/test_load_imaged.py +198 -0
- tests/transforms/test_load_spacing_orientation.py +149 -0
- tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
- tests/transforms/test_map_binary_to_indices.py +75 -0
- tests/transforms/test_map_classes_to_indices.py +135 -0
- tests/transforms/test_map_label_value.py +89 -0
- tests/transforms/test_map_label_valued.py +85 -0
- tests/transforms/test_map_transform.py +45 -0
- tests/transforms/test_mask_intensity.py +74 -0
- tests/transforms/test_mask_intensityd.py +68 -0
- tests/transforms/test_mean_ensemble.py +77 -0
- tests/transforms/test_mean_ensembled.py +91 -0
- tests/transforms/test_median_smooth.py +41 -0
- tests/transforms/test_median_smoothd.py +65 -0
- tests/transforms/test_morphological_ops.py +101 -0
- tests/transforms/test_nifti_endianness.py +107 -0
- tests/transforms/test_normalize_intensity.py +143 -0
- tests/transforms/test_normalize_intensityd.py +81 -0
- tests/transforms/test_nvtx_decorator.py +289 -0
- tests/transforms/test_nvtx_transform.py +143 -0
- tests/transforms/test_orientation.py +247 -0
- tests/transforms/test_orientationd.py +112 -0
- tests/transforms/test_rand_adjust_contrast.py +45 -0
- tests/transforms/test_rand_adjust_contrastd.py +44 -0
- tests/transforms/test_rand_affine.py +201 -0
- tests/transforms/test_rand_affine_grid.py +212 -0
- tests/transforms/test_rand_affined.py +281 -0
- tests/transforms/test_rand_axis_flip.py +50 -0
- tests/transforms/test_rand_axis_flipd.py +50 -0
- tests/transforms/test_rand_bias_field.py +69 -0
- tests/transforms/test_rand_bias_fieldd.py +65 -0
- tests/transforms/test_rand_coarse_dropout.py +110 -0
- tests/transforms/test_rand_coarse_dropoutd.py +107 -0
- tests/transforms/test_rand_coarse_shuffle.py +65 -0
- tests/transforms/test_rand_coarse_shuffled.py +59 -0
- tests/transforms/test_rand_crop_by_label_classes.py +170 -0
- tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
- tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
- tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
- tests/transforms/test_rand_cucim_dict_transform.py +162 -0
- tests/transforms/test_rand_cucim_transform.py +162 -0
- tests/transforms/test_rand_deform_grid.py +138 -0
- tests/transforms/test_rand_elastic_2d.py +127 -0
- tests/transforms/test_rand_elastic_3d.py +104 -0
- tests/transforms/test_rand_elasticd_2d.py +177 -0
- tests/transforms/test_rand_elasticd_3d.py +156 -0
- tests/transforms/test_rand_flip.py +60 -0
- tests/transforms/test_rand_flipd.py +55 -0
- tests/transforms/test_rand_gaussian_noise.py +48 -0
- tests/transforms/test_rand_gaussian_noised.py +54 -0
- tests/transforms/test_rand_gaussian_sharpen.py +140 -0
- tests/transforms/test_rand_gaussian_sharpend.py +143 -0
- tests/transforms/test_rand_gaussian_smooth.py +98 -0
- tests/transforms/test_rand_gaussian_smoothd.py +98 -0
- tests/transforms/test_rand_gibbs_noise.py +103 -0
- tests/transforms/test_rand_gibbs_noised.py +117 -0
- tests/transforms/test_rand_grid_distortion.py +99 -0
- tests/transforms/test_rand_grid_distortiond.py +90 -0
- tests/transforms/test_rand_histogram_shift.py +92 -0
- tests/transforms/test_rand_k_space_spike_noise.py +92 -0
- tests/transforms/test_rand_k_space_spike_noised.py +76 -0
- tests/transforms/test_rand_rician_noise.py +52 -0
- tests/transforms/test_rand_rician_noised.py +52 -0
- tests/transforms/test_rand_rotate.py +166 -0
- tests/transforms/test_rand_rotate90.py +100 -0
- tests/transforms/test_rand_rotate90d.py +112 -0
- tests/transforms/test_rand_rotated.py +187 -0
- tests/transforms/test_rand_scale_crop.py +78 -0
- tests/transforms/test_rand_scale_cropd.py +98 -0
- tests/transforms/test_rand_scale_intensity.py +54 -0
- tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
- tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
- tests/transforms/test_rand_scale_intensityd.py +53 -0
- tests/transforms/test_rand_shift_intensity.py +52 -0
- tests/transforms/test_rand_shift_intensityd.py +67 -0
- tests/transforms/test_rand_simulate_low_resolution.py +83 -0
- tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
- tests/transforms/test_rand_spatial_crop.py +107 -0
- tests/transforms/test_rand_spatial_crop_samples.py +128 -0
- tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
- tests/transforms/test_rand_spatial_cropd.py +112 -0
- tests/transforms/test_rand_std_shift_intensity.py +43 -0
- tests/transforms/test_rand_std_shift_intensityd.py +38 -0
- tests/transforms/test_rand_zoom.py +105 -0
- tests/transforms/test_rand_zoomd.py +108 -0
- tests/transforms/test_randidentity.py +49 -0
- tests/transforms/test_random_order.py +144 -0
- tests/transforms/test_randtorchvisiond.py +65 -0
- tests/transforms/test_regularization.py +139 -0
- tests/transforms/test_remove_repeated_channel.py +34 -0
- tests/transforms/test_remove_repeated_channeld.py +44 -0
- tests/transforms/test_repeat_channel.py +34 -0
- tests/transforms/test_repeat_channeld.py +41 -0
- tests/transforms/test_resample_backends.py +65 -0
- tests/transforms/test_resample_to_match.py +110 -0
- tests/transforms/test_resample_to_matchd.py +93 -0
- tests/transforms/test_resampler.py +165 -0
- tests/transforms/test_resize.py +140 -0
- tests/transforms/test_resize_with_pad_or_crop.py +91 -0
- tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
- tests/transforms/test_resized.py +163 -0
- tests/transforms/test_rotate.py +160 -0
- tests/transforms/test_rotate90.py +212 -0
- tests/transforms/test_rotate90d.py +106 -0
- tests/transforms/test_rotated.py +179 -0
- tests/transforms/test_save_classificationd.py +109 -0
- tests/transforms/test_save_image.py +80 -0
- tests/transforms/test_save_imaged.py +130 -0
- tests/transforms/test_savitzky_golay_smooth.py +73 -0
- tests/transforms/test_savitzky_golay_smoothd.py +73 -0
- tests/transforms/test_scale_intensity.py +76 -0
- tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
- tests/transforms/test_scale_intensity_range.py +41 -0
- tests/transforms/test_scale_intensity_ranged.py +40 -0
- tests/transforms/test_scale_intensityd.py +57 -0
- tests/transforms/test_select_itemsd.py +41 -0
- tests/transforms/test_shift_intensity.py +31 -0
- tests/transforms/test_shift_intensityd.py +44 -0
- tests/transforms/test_signal_continuouswavelet.py +44 -0
- tests/transforms/test_signal_fillempty.py +52 -0
- tests/transforms/test_signal_fillemptyd.py +60 -0
- tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
- tests/transforms/test_signal_rand_add_sine.py +52 -0
- tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
- tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
- tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
- tests/transforms/test_signal_rand_drop.py +50 -0
- tests/transforms/test_signal_rand_scale.py +52 -0
- tests/transforms/test_signal_rand_shift.py +55 -0
- tests/transforms/test_signal_remove_frequency.py +71 -0
- tests/transforms/test_smooth_field.py +177 -0
- tests/transforms/test_sobel_gradient.py +189 -0
- tests/transforms/test_sobel_gradientd.py +212 -0
- tests/transforms/test_spacing.py +381 -0
- tests/transforms/test_spacingd.py +178 -0
- tests/transforms/test_spatial_crop.py +82 -0
- tests/transforms/test_spatial_cropd.py +74 -0
- tests/transforms/test_spatial_pad.py +57 -0
- tests/transforms/test_spatial_padd.py +43 -0
- tests/transforms/test_spatial_resample.py +235 -0
- tests/transforms/test_squeezedim.py +62 -0
- tests/transforms/test_squeezedimd.py +98 -0
- tests/transforms/test_std_shift_intensity.py +76 -0
- tests/transforms/test_std_shift_intensityd.py +74 -0
- tests/transforms/test_threshold_intensity.py +38 -0
- tests/transforms/test_threshold_intensityd.py +58 -0
- tests/transforms/test_to_contiguous.py +47 -0
- tests/transforms/test_to_cupy.py +112 -0
- tests/transforms/test_to_cupyd.py +76 -0
- tests/transforms/test_to_device.py +42 -0
- tests/transforms/test_to_deviced.py +37 -0
- tests/transforms/test_to_numpy.py +85 -0
- tests/transforms/test_to_numpyd.py +68 -0
- tests/transforms/test_to_pil.py +52 -0
- tests/transforms/test_to_pild.py +50 -0
- tests/transforms/test_to_tensor.py +60 -0
- tests/transforms/test_to_tensord.py +71 -0
- tests/transforms/test_torchvision.py +66 -0
- tests/transforms/test_torchvisiond.py +63 -0
- tests/transforms/test_transform.py +62 -0
- tests/transforms/test_transpose.py +41 -0
- tests/transforms/test_transposed.py +52 -0
- tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
- tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
- tests/transforms/test_vote_ensemble.py +84 -0
- tests/transforms/test_vote_ensembled.py +107 -0
- tests/transforms/test_with_allow_missing_keys.py +76 -0
- tests/transforms/test_zoom.py +120 -0
- tests/transforms/test_zoomd.py +94 -0
- tests/transforms/transform/__init__.py +10 -0
- tests/transforms/transform/test_randomizable.py +52 -0
- tests/transforms/transform/test_randomizable_transform_type.py +37 -0
- tests/transforms/utility/__init__.py +10 -0
- tests/transforms/utility/test_apply_transform_to_points.py +81 -0
- tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
- tests/transforms/utility/test_identity.py +29 -0
- tests/transforms/utility/test_identityd.py +30 -0
- tests/transforms/utility/test_lambda.py +71 -0
- tests/transforms/utility/test_lambdad.py +83 -0
- tests/transforms/utility/test_rand_lambda.py +87 -0
- tests/transforms/utility/test_rand_lambdad.py +77 -0
- tests/transforms/utility/test_simulatedelay.py +36 -0
- tests/transforms/utility/test_simulatedelayd.py +36 -0
- tests/transforms/utility/test_splitdim.py +52 -0
- tests/transforms/utility/test_splitdimd.py +96 -0
- tests/transforms/utils/__init__.py +10 -0
- tests/transforms/utils/test_correct_crop_centers.py +36 -0
- tests/transforms/utils/test_get_unique_labels.py +45 -0
- tests/transforms/utils/test_print_transform_backends.py +29 -0
- tests/transforms/utils/test_soft_clip.py +125 -0
- tests/utils/__init__.py +10 -0
- tests/utils/enums/__init__.py +10 -0
- tests/utils/enums/test_hovernet_loss.py +190 -0
- tests/utils/enums/test_ordering.py +289 -0
- tests/utils/enums/test_wsireader.py +663 -0
- tests/utils/misc/__init__.py +10 -0
- tests/utils/misc/test_ensure_tuple.py +53 -0
- tests/utils/misc/test_monai_env_vars.py +44 -0
- tests/utils/misc/test_monai_utils_misc.py +103 -0
- tests/utils/misc/test_str2bool.py +34 -0
- tests/utils/misc/test_str2list.py +33 -0
- tests/utils/test_alias.py +44 -0
- tests/utils/test_component_store.py +73 -0
- tests/utils/test_deprecated.py +455 -0
- tests/utils/test_enum_bound_interp.py +75 -0
- tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
- tests/utils/test_get_package_version.py +34 -0
- tests/utils/test_handler_logfile.py +84 -0
- tests/utils/test_handler_metric_logger.py +62 -0
- tests/utils/test_list_to_dict.py +43 -0
- tests/utils/test_look_up_option.py +87 -0
- tests/utils/test_optional_import.py +80 -0
- tests/utils/test_pad_mode.py +39 -0
- tests/utils/test_profiling.py +208 -0
- tests/utils/test_rankfilter_dist.py +77 -0
- tests/utils/test_require_pkg.py +83 -0
- tests/utils/test_sample_slices.py +43 -0
- tests/utils/test_set_determinism.py +74 -0
- tests/utils/test_squeeze_unsqueeze.py +71 -0
- tests/utils/test_state_cacher.py +67 -0
- tests/utils/test_torchscript_utils.py +113 -0
- tests/utils/test_version.py +91 -0
- tests/utils/test_version_after.py +65 -0
- tests/utils/type_conversion/__init__.py +10 -0
- tests/utils/type_conversion/test_convert_data_type.py +152 -0
- tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
- tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
- tests/visualize/__init__.py +10 -0
- tests/visualize/test_img2tensorboard.py +46 -0
- tests/visualize/test_occlusion_sensitivity.py +128 -0
- tests/visualize/test_plot_2d_or_3d_image.py +74 -0
- tests/visualize/test_vis_cam.py +98 -0
- tests/visualize/test_vis_gradcam.py +211 -0
- tests/visualize/utils/__init__.py +10 -0
- tests/visualize/utils/test_blend_images.py +63 -0
- tests/visualize/utils/test_matshow3d.py +133 -0
- monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/WHEEL +0 -0
@@ -0,0 +1,104 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.networks.schedulers import DDPMScheduler
|
20
|
+
from tests.test_utils import assert_allclose
|
21
|
+
|
22
|
+
TEST_2D_CASE = []
|
23
|
+
for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
|
24
|
+
for variance_type in ["fixed_small", "fixed_large"]:
|
25
|
+
TEST_2D_CASE.append(
|
26
|
+
[{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)]
|
27
|
+
)
|
28
|
+
|
29
|
+
TEST_3D_CASE = []
|
30
|
+
for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
|
31
|
+
for variance_type in ["fixed_small", "fixed_large"]:
|
32
|
+
TEST_3D_CASE.append(
|
33
|
+
[{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]
|
34
|
+
)
|
35
|
+
|
36
|
+
TEST_CASES = TEST_2D_CASE + TEST_3D_CASE
|
37
|
+
|
38
|
+
TEST_FULl_LOOP = [
|
39
|
+
[{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-1.0153, -0.3218], [0.8454, -0.7870]]]])]
|
40
|
+
]
|
41
|
+
|
42
|
+
|
43
|
+
class TestDDPMScheduler(unittest.TestCase):
|
44
|
+
@parameterized.expand(TEST_CASES)
|
45
|
+
def test_add_noise(self, input_param, input_shape, expected_shape):
|
46
|
+
scheduler = DDPMScheduler(**input_param)
|
47
|
+
original_sample = torch.zeros(input_shape)
|
48
|
+
noise = torch.randn_like(original_sample)
|
49
|
+
timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long()
|
50
|
+
|
51
|
+
noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)
|
52
|
+
self.assertEqual(noisy.shape, expected_shape)
|
53
|
+
|
54
|
+
@parameterized.expand(TEST_CASES)
|
55
|
+
def test_step_shape(self, input_param, input_shape, expected_shape):
|
56
|
+
scheduler = DDPMScheduler(**input_param)
|
57
|
+
model_output = torch.randn(input_shape)
|
58
|
+
sample = torch.randn(input_shape)
|
59
|
+
output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
|
60
|
+
self.assertEqual(output_step[0].shape, expected_shape)
|
61
|
+
self.assertEqual(output_step[1].shape, expected_shape)
|
62
|
+
|
63
|
+
@parameterized.expand(TEST_FULl_LOOP)
|
64
|
+
def test_full_timestep_loop(self, input_param, input_shape, expected_output):
|
65
|
+
scheduler = DDPMScheduler(**input_param)
|
66
|
+
scheduler.set_timesteps(50)
|
67
|
+
torch.manual_seed(42)
|
68
|
+
model_output = torch.randn(input_shape)
|
69
|
+
sample = torch.randn(input_shape)
|
70
|
+
for t in range(50):
|
71
|
+
sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)
|
72
|
+
assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3)
|
73
|
+
|
74
|
+
@parameterized.expand(TEST_CASES)
|
75
|
+
def test_get_velocity_shape(self, input_param, input_shape, expected_shape):
|
76
|
+
scheduler = DDPMScheduler(**input_param)
|
77
|
+
sample = torch.randn(input_shape)
|
78
|
+
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],)).long()
|
79
|
+
velocity = scheduler.get_velocity(sample=sample, noise=sample, timesteps=timesteps)
|
80
|
+
self.assertEqual(velocity.shape, expected_shape)
|
81
|
+
|
82
|
+
def test_step_learned(self):
|
83
|
+
for variance_type in ["learned", "learned_range"]:
|
84
|
+
scheduler = DDPMScheduler(variance_type=variance_type)
|
85
|
+
model_output = torch.randn(2, 6, 16, 16)
|
86
|
+
sample = torch.randn(2, 3, 16, 16)
|
87
|
+
output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
|
88
|
+
self.assertEqual(output_step[0].shape, sample.shape)
|
89
|
+
self.assertEqual(output_step[1].shape, sample.shape)
|
90
|
+
|
91
|
+
def test_set_timesteps(self):
|
92
|
+
scheduler = DDPMScheduler(num_train_timesteps=1000)
|
93
|
+
scheduler.set_timesteps(num_inference_steps=100)
|
94
|
+
self.assertEqual(scheduler.num_inference_steps, 100)
|
95
|
+
self.assertEqual(len(scheduler.timesteps), 100)
|
96
|
+
|
97
|
+
def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
|
98
|
+
scheduler = DDPMScheduler(num_train_timesteps=1000)
|
99
|
+
with self.assertRaises(ValueError):
|
100
|
+
scheduler.set_timesteps(num_inference_steps=2000)
|
101
|
+
|
102
|
+
|
103
|
+
if __name__ == "__main__":
|
104
|
+
unittest.main()
|
@@ -0,0 +1,108 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.networks.schedulers import PNDMScheduler
|
20
|
+
from tests.test_utils import assert_allclose
|
21
|
+
|
22
|
+
TEST_2D_CASE = []
|
23
|
+
for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
|
24
|
+
TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)])
|
25
|
+
|
26
|
+
TEST_3D_CASE = []
|
27
|
+
for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
|
28
|
+
TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)])
|
29
|
+
|
30
|
+
TEST_CASES = TEST_2D_CASE + TEST_3D_CASE
|
31
|
+
|
32
|
+
TEST_FULl_LOOP = [
|
33
|
+
[
|
34
|
+
{"schedule": "linear_beta"},
|
35
|
+
(1, 1, 2, 2),
|
36
|
+
torch.Tensor([[[[-2123055.2500, -459014.2812], [2863438.0000, -1263401.7500]]]]),
|
37
|
+
]
|
38
|
+
]
|
39
|
+
|
40
|
+
|
41
|
+
class TestDDPMScheduler(unittest.TestCase):
|
42
|
+
@parameterized.expand(TEST_CASES)
|
43
|
+
def test_add_noise(self, input_param, input_shape, expected_shape):
|
44
|
+
scheduler = PNDMScheduler(**input_param)
|
45
|
+
original_sample = torch.zeros(input_shape)
|
46
|
+
noise = torch.randn_like(original_sample)
|
47
|
+
timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long()
|
48
|
+
noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)
|
49
|
+
self.assertEqual(noisy.shape, expected_shape)
|
50
|
+
|
51
|
+
@parameterized.expand(TEST_CASES)
|
52
|
+
def test_step_shape(self, input_param, input_shape, expected_shape):
|
53
|
+
scheduler = PNDMScheduler(**input_param)
|
54
|
+
scheduler.set_timesteps(600)
|
55
|
+
model_output = torch.randn(input_shape)
|
56
|
+
sample = torch.randn(input_shape)
|
57
|
+
output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
|
58
|
+
self.assertEqual(output_step[0].shape, expected_shape)
|
59
|
+
self.assertEqual(output_step[1], None)
|
60
|
+
|
61
|
+
@parameterized.expand(TEST_FULl_LOOP)
|
62
|
+
def test_full_timestep_loop(self, input_param, input_shape, expected_output):
|
63
|
+
scheduler = PNDMScheduler(**input_param)
|
64
|
+
scheduler.set_timesteps(50)
|
65
|
+
torch.manual_seed(42)
|
66
|
+
model_output = torch.randn(input_shape)
|
67
|
+
sample = torch.randn(input_shape)
|
68
|
+
for t in range(50):
|
69
|
+
sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)
|
70
|
+
assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3)
|
71
|
+
|
72
|
+
@parameterized.expand(TEST_FULl_LOOP)
|
73
|
+
def test_timestep_two_loops(self, input_param, input_shape, expected_output):
|
74
|
+
scheduler = PNDMScheduler(**input_param)
|
75
|
+
scheduler.set_timesteps(50)
|
76
|
+
torch.manual_seed(42)
|
77
|
+
model_output = torch.randn(input_shape)
|
78
|
+
sample = torch.randn(input_shape)
|
79
|
+
for t in range(50):
|
80
|
+
sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)
|
81
|
+
torch.manual_seed(42)
|
82
|
+
model_output2 = torch.randn(input_shape)
|
83
|
+
sample2 = torch.randn(input_shape)
|
84
|
+
scheduler.set_timesteps(50)
|
85
|
+
for t in range(50):
|
86
|
+
sample2, _ = scheduler.step(model_output=model_output2, timestep=t, sample=sample2)
|
87
|
+
assert_allclose(sample, sample2, rtol=1e-3, atol=1e-3)
|
88
|
+
|
89
|
+
def test_set_timesteps(self):
|
90
|
+
scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True)
|
91
|
+
scheduler.set_timesteps(num_inference_steps=100)
|
92
|
+
self.assertEqual(scheduler.num_inference_steps, 100)
|
93
|
+
self.assertEqual(len(scheduler.timesteps), 100)
|
94
|
+
|
95
|
+
def test_set_timesteps_prk(self):
|
96
|
+
scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=False)
|
97
|
+
scheduler.set_timesteps(num_inference_steps=100)
|
98
|
+
self.assertEqual(scheduler.num_inference_steps, 109)
|
99
|
+
self.assertEqual(len(scheduler.timesteps), 109)
|
100
|
+
|
101
|
+
def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
|
102
|
+
scheduler = PNDMScheduler(num_train_timesteps=1000)
|
103
|
+
with self.assertRaises(ValueError):
|
104
|
+
scheduler.set_timesteps(num_inference_steps=2000)
|
105
|
+
|
106
|
+
|
107
|
+
if __name__ == "__main__":
|
108
|
+
unittest.main()
|
@@ -0,0 +1,71 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import os
|
15
|
+
import tempfile
|
16
|
+
import unittest
|
17
|
+
from pathlib import Path
|
18
|
+
|
19
|
+
from parameterized import parameterized
|
20
|
+
|
21
|
+
from monai.bundle import ConfigParser
|
22
|
+
from monai.networks import save_state
|
23
|
+
from tests.test_utils import SkipIfBeforePyTorchVersion, SkipIfNoModule, command_line_tests, skip_if_windows
|
24
|
+
|
25
|
+
TEST_CASE_1 = ["True"]
|
26
|
+
TEST_CASE_2 = ["False"]
|
27
|
+
|
28
|
+
|
29
|
+
@skip_if_windows
|
30
|
+
@SkipIfNoModule("onnx")
|
31
|
+
@SkipIfBeforePyTorchVersion((1, 10))
|
32
|
+
class TestONNXExport(unittest.TestCase):
|
33
|
+
def setUp(self):
|
34
|
+
self.device = os.environ.get("CUDA_VISIBLE_DEVICES")
|
35
|
+
if not self.device:
|
36
|
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # default
|
37
|
+
|
38
|
+
def tearDown(self):
|
39
|
+
if self.device is not None:
|
40
|
+
os.environ["CUDA_VISIBLE_DEVICES"] = self.device
|
41
|
+
else:
|
42
|
+
del os.environ["CUDA_VISIBLE_DEVICES"] # previously unset
|
43
|
+
|
44
|
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
|
45
|
+
def test_onnx_export(self, use_trace):
|
46
|
+
tests_path = Path(__file__).parents[1]
|
47
|
+
meta_file = os.path.join(tests_path, "testing_data", "metadata.json")
|
48
|
+
config_file = os.path.join(tests_path, "testing_data", "inference.json")
|
49
|
+
with tempfile.TemporaryDirectory() as tempdir:
|
50
|
+
def_args = {"meta_file": "will be replaced by `meta_file` arg"}
|
51
|
+
def_args_file = os.path.join(tempdir, "def_args.yaml")
|
52
|
+
|
53
|
+
ckpt_file = os.path.join(tempdir, "model.pt")
|
54
|
+
onnx_file = os.path.join(tempdir, "model.onnx")
|
55
|
+
|
56
|
+
parser = ConfigParser()
|
57
|
+
parser.export_config_file(config=def_args, filepath=def_args_file)
|
58
|
+
parser.read_config(config_file)
|
59
|
+
net = parser.get_parsed_content("network_def")
|
60
|
+
save_state(src=net, path=ckpt_file)
|
61
|
+
|
62
|
+
cmd = ["python", "-m", "monai.bundle", "onnx_export", "network_def", "--filepath", onnx_file]
|
63
|
+
cmd += ["--meta_file", meta_file, "--config_file", f"['{config_file}','{def_args_file}']"]
|
64
|
+
cmd += ["--ckpt_file", ckpt_file, "--args_file", def_args_file, "--input_shape", "[1, 1, 96, 96, 96]"]
|
65
|
+
cmd += ["--use_trace", use_trace]
|
66
|
+
command_line_tests(cmd)
|
67
|
+
self.assertTrue(os.path.exists(onnx_file))
|
68
|
+
|
69
|
+
|
70
|
+
if __name__ == "__main__":
|
71
|
+
unittest.main()
|
@@ -0,0 +1,106 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import itertools
|
15
|
+
import platform
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import torch
|
19
|
+
from parameterized import parameterized
|
20
|
+
|
21
|
+
from monai.networks import convert_to_onnx
|
22
|
+
from monai.networks.nets import SegResNet, UNet
|
23
|
+
from tests.test_utils import SkipIfBeforePyTorchVersion, SkipIfNoModule, optional_import, skip_if_quick
|
24
|
+
|
25
|
+
if torch.cuda.is_available():
|
26
|
+
TORCH_DEVICE_OPTIONS = ["cpu", "cuda"]
|
27
|
+
else:
|
28
|
+
TORCH_DEVICE_OPTIONS = ["cpu"]
|
29
|
+
TESTS = list(itertools.product(TORCH_DEVICE_OPTIONS, [True, False], [True, False]))
|
30
|
+
TESTS_ORT = list(itertools.product(TORCH_DEVICE_OPTIONS, [True]))
|
31
|
+
|
32
|
+
ON_AARCH64 = platform.machine() == "aarch64"
|
33
|
+
if ON_AARCH64:
|
34
|
+
rtol, atol = 1e-1, 1e-2
|
35
|
+
else:
|
36
|
+
rtol, atol = 1e-3, 1e-4
|
37
|
+
|
38
|
+
onnx, _ = optional_import("onnx")
|
39
|
+
|
40
|
+
|
41
|
+
@SkipIfNoModule("onnx")
|
42
|
+
@SkipIfBeforePyTorchVersion((1, 9))
|
43
|
+
@skip_if_quick
|
44
|
+
class TestConvertToOnnx(unittest.TestCase):
|
45
|
+
@parameterized.expand(TESTS)
|
46
|
+
def test_unet(self, device, use_trace, use_ort):
|
47
|
+
if use_ort:
|
48
|
+
_, has_onnxruntime = optional_import("onnxruntime")
|
49
|
+
if not has_onnxruntime:
|
50
|
+
self.skipTest("onnxruntime is not installed probably due to python version >= 3.11.")
|
51
|
+
model = UNet(
|
52
|
+
spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
|
53
|
+
)
|
54
|
+
if use_trace:
|
55
|
+
onnx_model = convert_to_onnx(
|
56
|
+
model=model,
|
57
|
+
inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],
|
58
|
+
input_names=["x"],
|
59
|
+
output_names=["y"],
|
60
|
+
verify=True,
|
61
|
+
device=device,
|
62
|
+
use_ort=use_ort,
|
63
|
+
use_trace=use_trace,
|
64
|
+
rtol=rtol,
|
65
|
+
atol=atol,
|
66
|
+
)
|
67
|
+
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))
|
68
|
+
|
69
|
+
@parameterized.expand(TESTS_ORT)
|
70
|
+
@SkipIfBeforePyTorchVersion((1, 12))
|
71
|
+
def test_seg_res_net(self, device, use_ort):
|
72
|
+
if use_ort:
|
73
|
+
_, has_onnxruntime = optional_import("onnxruntime")
|
74
|
+
if not has_onnxruntime:
|
75
|
+
self.skipTest("onnxruntime is not installed probably due to python version >= 3.11.")
|
76
|
+
model = SegResNet(
|
77
|
+
spatial_dims=3,
|
78
|
+
init_filters=32,
|
79
|
+
in_channels=1,
|
80
|
+
out_channels=105,
|
81
|
+
dropout_prob=0.2,
|
82
|
+
act=("RELU", {"inplace": True}),
|
83
|
+
norm=("GROUP", {"num_groups": 8}),
|
84
|
+
norm_name="",
|
85
|
+
num_groups=8,
|
86
|
+
use_conv_final=True,
|
87
|
+
blocks_down=[1, 2, 2, 4],
|
88
|
+
blocks_up=[1, 1, 1],
|
89
|
+
)
|
90
|
+
onnx_model = convert_to_onnx(
|
91
|
+
model=model,
|
92
|
+
inputs=[torch.randn((1, 1, 24, 24, 24), requires_grad=False)],
|
93
|
+
input_names=["x"],
|
94
|
+
output_names=["y"],
|
95
|
+
verify=True,
|
96
|
+
device=device,
|
97
|
+
use_ort=use_ort,
|
98
|
+
use_trace=True,
|
99
|
+
rtol=rtol,
|
100
|
+
atol=atol,
|
101
|
+
)
|
102
|
+
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))
|
103
|
+
|
104
|
+
|
105
|
+
if __name__ == "__main__":
|
106
|
+
unittest.main()
|
@@ -0,0 +1,46 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import os
|
15
|
+
import tempfile
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import torch
|
19
|
+
|
20
|
+
from monai.networks import convert_to_torchscript
|
21
|
+
from monai.networks.nets import UNet
|
22
|
+
|
23
|
+
|
24
|
+
class TestConvertToTorchScript(unittest.TestCase):
|
25
|
+
|
26
|
+
def test_value(self):
|
27
|
+
model = UNet(
|
28
|
+
spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
|
29
|
+
)
|
30
|
+
with tempfile.TemporaryDirectory() as tempdir:
|
31
|
+
torchscript_model = convert_to_torchscript(
|
32
|
+
model=model,
|
33
|
+
filename_or_obj=os.path.join(tempdir, "model.ts"),
|
34
|
+
extra_files={"foo.txt": b"bar"},
|
35
|
+
verify=True,
|
36
|
+
inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],
|
37
|
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
38
|
+
rtol=1e-3,
|
39
|
+
atol=1e-4,
|
40
|
+
optimize=None,
|
41
|
+
)
|
42
|
+
self.assertTrue(isinstance(torchscript_model, torch.nn.Module))
|
43
|
+
|
44
|
+
|
45
|
+
if __name__ == "__main__":
|
46
|
+
unittest.main()
|
@@ -0,0 +1,79 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import tempfile
|
15
|
+
import unittest
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.networks import convert_to_trt
|
21
|
+
from monai.networks.nets import UNet
|
22
|
+
from monai.utils import optional_import
|
23
|
+
from tests.test_utils import SkipIfBeforeComputeCapabilityVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows
|
24
|
+
|
25
|
+
_, has_torchtrt = optional_import(
|
26
|
+
"torch_tensorrt",
|
27
|
+
version="1.4.0",
|
28
|
+
descriptor="Torch-TRT is not installed. Are you sure you have a Torch-TensorRT compilation?",
|
29
|
+
)
|
30
|
+
_, has_tensorrt = optional_import(
|
31
|
+
"tensorrt", descriptor="TensorRT is not installed. Are you sure you have a TensorRT compilation?"
|
32
|
+
)
|
33
|
+
|
34
|
+
TEST_CASE_1 = ["fp32"]
|
35
|
+
TEST_CASE_2 = ["fp16"]
|
36
|
+
|
37
|
+
|
38
|
+
@skip_if_windows
|
39
|
+
@skip_if_no_cuda
|
40
|
+
@skip_if_quick
|
41
|
+
@SkipIfBeforeComputeCapabilityVersion((7, 5))
|
42
|
+
class TestConvertToTRT(unittest.TestCase):
|
43
|
+
def setUp(self):
|
44
|
+
self.gpu_device = torch.cuda.current_device()
|
45
|
+
|
46
|
+
def tearDown(self):
|
47
|
+
current_device = torch.cuda.current_device()
|
48
|
+
if current_device != self.gpu_device:
|
49
|
+
torch.cuda.set_device(self.gpu_device)
|
50
|
+
|
51
|
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
|
52
|
+
@unittest.skipUnless(has_torchtrt and has_tensorrt, "Torch-TensorRT is required for convert!")
|
53
|
+
def test_value(self, precision):
|
54
|
+
model = UNet(
|
55
|
+
spatial_dims=3,
|
56
|
+
in_channels=1,
|
57
|
+
out_channels=2,
|
58
|
+
channels=(2, 2, 4, 8, 4),
|
59
|
+
strides=(2, 2, 2, 2),
|
60
|
+
num_res_units=2,
|
61
|
+
norm="batch",
|
62
|
+
)
|
63
|
+
with tempfile.TemporaryDirectory() as _:
|
64
|
+
torchscript_model = convert_to_trt(
|
65
|
+
model=model,
|
66
|
+
precision=precision,
|
67
|
+
input_shape=[1, 1, 96, 96, 96],
|
68
|
+
dynamic_batchsize=[1, 4, 8],
|
69
|
+
use_trace=False,
|
70
|
+
verify=True,
|
71
|
+
device=0,
|
72
|
+
rtol=1e-2,
|
73
|
+
atol=1e-2,
|
74
|
+
)
|
75
|
+
self.assertTrue(isinstance(torchscript_model, torch.nn.Module))
|
76
|
+
|
77
|
+
|
78
|
+
if __name__ == "__main__":
|
79
|
+
unittest.main()
|
@@ -0,0 +1,73 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import os
|
15
|
+
import tempfile
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import torch
|
19
|
+
import torch.optim as optim
|
20
|
+
from parameterized import parameterized
|
21
|
+
|
22
|
+
from monai.networks import save_state
|
23
|
+
|
24
|
+
# Each case is [src, expected_keys, create_dir, atomic, func, kwargs]; trailing entries may be
# omitted and then take the defaults of ``TestSaveState.test_file``.
TEST_CASE_1 = [torch.nn.PReLU(), ["weight"]]

TEST_CASE_2 = [{"net": torch.nn.PReLU()}, ["net"]]

TEST_CASE_3 = [{"net": torch.nn.PReLU(), "opt": optim.SGD(torch.nn.PReLU().parameters(), lr=0.02)}, ["net", "opt"]]

TEST_CASE_4 = [torch.nn.DataParallel(torch.nn.PReLU()), ["weight"]]

TEST_CASE_5 = [{"net": torch.nn.DataParallel(torch.nn.PReLU())}, ["net"]]

TEST_CASE_6 = [torch.nn.PReLU(), ["weight"], True, True, None, {"pickle_protocol": 2}]

TEST_CASE_7 = [torch.nn.PReLU().state_dict(), ["weight"]]

TEST_CASE_8 = [torch.nn.PReLU(), ["weight"], False]

TEST_CASE_9 = [torch.nn.PReLU(), ["weight"], True, False]

TEST_CASE_10 = [torch.nn.PReLU(), ["weight"], True, True, torch.save]


class TestSaveState(unittest.TestCase):
    """Save objects with ``save_state`` into a temporary file and check the reloaded keys."""

    @parameterized.expand(
        [
            TEST_CASE_1,
            TEST_CASE_2,
            TEST_CASE_3,
            TEST_CASE_4,
            TEST_CASE_5,
            TEST_CASE_6,
            TEST_CASE_7,
            TEST_CASE_8,
            TEST_CASE_9,
            TEST_CASE_10,
        ]
    )
    def test_file(self, src, expected_keys, create_dir=True, atomic=True, func=None, kwargs=None):
        """Save ``src`` with the given options, reload it with ``torch.load`` and compare keys.

        Args:
            src: object handed to ``save_state`` as ``src``.
            expected_keys: exact top-level keys the reloaded checkpoint must contain.
            create_dir, atomic, func: forwarded unchanged to ``save_state``.
            kwargs: extra keyword arguments for ``save_state``; ``None`` means none.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            path = os.path.join(tempdir, "test_ckpt.pt")
            save_state(src=src, path=path, create_dir=create_dir, atomic=atomic, func=func, **(kwargs or {}))
            ckpt = dict(torch.load(path))
            # Compare as sets: checking only that each saved key is expected would also
            # accept an empty or partially written checkpoint.
            self.assertEqual(set(ckpt), set(expected_keys))


if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,63 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.networks import one_hot
|
21
|
+
|
22
|
+
# Each case is [one_hot kwargs, expected output shape, optional expected output values].
TEST_CASE_1 = [  # single channel 2D, batch 2, shape (2, 1, 2, 2)
    {"labels": torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), "num_classes": 3},
    (2, 3, 2, 2),
    np.array(
        [
            [[[1, 0], [0, 0]], [[0, 1], [1, 0]], [[0, 0], [0, 1]]],
            [[[0, 0], [0, 1]], [[0, 1], [1, 0]], [[1, 0], [0, 0]]],
        ]
    ),
]

TEST_CASE_2 = [  # single channel 1D, batch 2, shape (2, 1, 4)
    {"labels": torch.tensor([[[1, 2, 2, 0]], [[2, 1, 0, 1]]]), "num_classes": 3},
    (2, 3, 4),
    np.array([[[0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 1, 0]], [[0, 0, 1, 0], [0, 1, 0, 1], [1, 0, 0, 0]]]),
]

TEST_CASE_3 = [  # single channel 0D, batch 2, shape (2, 1)
    {"labels": torch.tensor([[1.0], [2.0]]), "num_classes": 3},
    (2, 3),
    np.array([[0, 1, 0], [0, 0, 1]]),
]

TEST_CASE_4 = [  # no channel 0D, batch 3, shape (3,)
    {"labels": torch.tensor([1, 2, 0]), "num_classes": 3, "dtype": torch.long},
    (3, 3),
    np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]),
]


class TestToOneHot(unittest.TestCase):
    """Check output shape, values and dtype of ``monai.networks.one_hot``."""

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
    def test_shape(self, input_data, expected_shape, expected_result=None):
        """Run ``one_hot(**input_data)`` and compare shape, optional values and dtype.

        Args:
            input_data: keyword arguments passed to ``one_hot``.
            expected_shape: required shape of the result.
            expected_result: optional array the result must match element-wise.
        """
        result = one_hot(**input_data)
        self.assertEqual(result.shape, expected_shape)
        if expected_result is not None:
            # assert_allclose reports the mismatching elements, unlike assertTrue(np.allclose(...)).
            np.testing.assert_allclose(result.numpy(), expected_result)

        if "dtype" in input_data:
            self.assertEqual(result.dtype, input_data["dtype"])
        else:
            # by default, expecting float type
            self.assertEqual(result.dtype, torch.float)


if __name__ == "__main__":
    unittest.main()
|