monai-weekly 1.5.dev2505__py3-none-any.whl → 1.5.dev2507__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (779)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/transforms.py +1 -4
  4. monai/data/meta_tensor.py +5 -0
  5. monai/data/utils.py +6 -13
  6. monai/inferers/utils.py +1 -2
  7. monai/losses/dice.py +2 -14
  8. monai/losses/ds_loss.py +1 -3
  9. monai/networks/layers/simplelayers.py +2 -14
  10. monai/networks/utils.py +4 -16
  11. monai/transforms/compose.py +28 -11
  12. monai/transforms/croppad/array.py +1 -6
  13. monai/transforms/io/array.py +0 -1
  14. monai/transforms/transform.py +15 -6
  15. monai/transforms/utils.py +1 -2
  16. monai/utils/jupyter_utils.py +1 -1
  17. monai/utils/tf32.py +0 -10
  18. monai/visualize/class_activation_maps.py +5 -8
  19. monai/visualize/img2tensorboard.py +2 -2
  20. {monai_weekly-1.5.dev2505.dist-info → monai_weekly-1.5.dev2507.dist-info}/METADATA +2 -2
  21. monai_weekly-1.5.dev2507.dist-info/RECORD +1181 -0
  22. {monai_weekly-1.5.dev2505.dist-info → monai_weekly-1.5.dev2507.dist-info}/top_level.txt +1 -0
  23. tests/apps/__init__.py +10 -0
  24. tests/apps/deepedit/__init__.py +10 -0
  25. tests/apps/deepedit/test_deepedit_transforms.py +314 -0
  26. tests/apps/deepgrow/__init__.py +10 -0
  27. tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
  28. tests/apps/deepgrow/transforms/__init__.py +10 -0
  29. tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
  30. tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
  31. tests/apps/detection/__init__.py +10 -0
  32. tests/apps/detection/metrics/__init__.py +10 -0
  33. tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
  34. tests/apps/detection/networks/__init__.py +10 -0
  35. tests/apps/detection/networks/test_retinanet.py +210 -0
  36. tests/apps/detection/networks/test_retinanet_detector.py +203 -0
  37. tests/apps/detection/test_box_transform.py +370 -0
  38. tests/apps/detection/utils/__init__.py +10 -0
  39. tests/apps/detection/utils/test_anchor_box.py +88 -0
  40. tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
  41. tests/apps/detection/utils/test_box_coder.py +43 -0
  42. tests/apps/detection/utils/test_detector_boxselector.py +67 -0
  43. tests/apps/detection/utils/test_detector_utils.py +96 -0
  44. tests/apps/detection/utils/test_hardnegsampler.py +54 -0
  45. tests/apps/nuclick/__init__.py +10 -0
  46. tests/apps/nuclick/test_nuclick_transforms.py +259 -0
  47. tests/apps/pathology/__init__.py +10 -0
  48. tests/apps/pathology/handlers/__init__.py +10 -0
  49. tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
  50. tests/apps/pathology/test_lesion_froc.py +333 -0
  51. tests/apps/pathology/test_pathology_prob_nms.py +55 -0
  52. tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
  53. tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
  54. tests/apps/pathology/transforms/__init__.py +10 -0
  55. tests/apps/pathology/transforms/post/__init__.py +10 -0
  56. tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
  57. tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
  58. tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
  59. tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
  60. tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
  61. tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
  62. tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
  63. tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
  64. tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
  65. tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
  66. tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
  67. tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
  68. tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
  69. tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
  70. tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
  71. tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
  72. tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
  73. tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
  74. tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
  75. tests/apps/pathology/transforms/post/test_watershed.py +60 -0
  76. tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
  77. tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
  78. tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
  79. tests/apps/reconstruction/__init__.py +10 -0
  80. tests/apps/reconstruction/nets/__init__.py +10 -0
  81. tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
  82. tests/apps/reconstruction/test_complex_utils.py +77 -0
  83. tests/apps/reconstruction/test_fastmri_reader.py +82 -0
  84. tests/apps/reconstruction/test_mri_utils.py +37 -0
  85. tests/apps/reconstruction/transforms/__init__.py +10 -0
  86. tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
  87. tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
  88. tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
  89. tests/apps/test_auto3dseg_bundlegen.py +156 -0
  90. tests/apps/test_check_hash.py +53 -0
  91. tests/apps/test_cross_validation.py +74 -0
  92. tests/apps/test_decathlondataset.py +93 -0
  93. tests/apps/test_download_and_extract.py +70 -0
  94. tests/apps/test_download_url_yandex.py +45 -0
  95. tests/apps/test_mednistdataset.py +72 -0
  96. tests/apps/test_mmar_download.py +154 -0
  97. tests/apps/test_tciadataset.py +123 -0
  98. tests/apps/vista3d/__init__.py +10 -0
  99. tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
  100. tests/apps/vista3d/test_vista3d_sampler.py +100 -0
  101. tests/apps/vista3d/test_vista3d_transforms.py +94 -0
  102. tests/bundle/__init__.py +10 -0
  103. tests/bundle/test_bundle_ckpt_export.py +107 -0
  104. tests/bundle/test_bundle_download.py +435 -0
  105. tests/bundle/test_bundle_get_data.py +94 -0
  106. tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
  107. tests/bundle/test_bundle_trt_export.py +147 -0
  108. tests/bundle/test_bundle_utils.py +149 -0
  109. tests/bundle/test_bundle_verify_metadata.py +66 -0
  110. tests/bundle/test_bundle_verify_net.py +76 -0
  111. tests/bundle/test_bundle_workflow.py +272 -0
  112. tests/bundle/test_component_locator.py +38 -0
  113. tests/bundle/test_config_item.py +138 -0
  114. tests/bundle/test_config_parser.py +392 -0
  115. tests/bundle/test_reference_resolver.py +114 -0
  116. tests/config/__init__.py +10 -0
  117. tests/config/test_cv2_dist.py +53 -0
  118. tests/engines/__init__.py +10 -0
  119. tests/engines/test_ensemble_evaluator.py +94 -0
  120. tests/engines/test_prepare_batch_default.py +76 -0
  121. tests/engines/test_prepare_batch_default_dist.py +76 -0
  122. tests/engines/test_prepare_batch_diffusion.py +104 -0
  123. tests/engines/test_prepare_batch_extra_input.py +80 -0
  124. tests/fl/__init__.py +10 -0
  125. tests/fl/monai_algo/__init__.py +10 -0
  126. tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
  127. tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
  128. tests/fl/test_fl_monai_algo_stats.py +81 -0
  129. tests/fl/utils/__init__.py +10 -0
  130. tests/fl/utils/test_fl_exchange_object.py +63 -0
  131. tests/handlers/__init__.py +10 -0
  132. tests/handlers/test_handler_checkpoint_loader.py +182 -0
  133. tests/handlers/test_handler_checkpoint_saver.py +233 -0
  134. tests/handlers/test_handler_classification_saver.py +64 -0
  135. tests/handlers/test_handler_classification_saver_dist.py +77 -0
  136. tests/handlers/test_handler_clearml_image.py +65 -0
  137. tests/handlers/test_handler_clearml_stats.py +65 -0
  138. tests/handlers/test_handler_confusion_matrix.py +104 -0
  139. tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
  140. tests/handlers/test_handler_decollate_batch.py +66 -0
  141. tests/handlers/test_handler_early_stop.py +68 -0
  142. tests/handlers/test_handler_garbage_collector.py +73 -0
  143. tests/handlers/test_handler_hausdorff_distance.py +111 -0
  144. tests/handlers/test_handler_ignite_metric.py +191 -0
  145. tests/handlers/test_handler_lr_scheduler.py +94 -0
  146. tests/handlers/test_handler_mean_dice.py +98 -0
  147. tests/handlers/test_handler_mean_iou.py +76 -0
  148. tests/handlers/test_handler_metrics_reloaded.py +149 -0
  149. tests/handlers/test_handler_metrics_saver.py +89 -0
  150. tests/handlers/test_handler_metrics_saver_dist.py +120 -0
  151. tests/handlers/test_handler_mlflow.py +296 -0
  152. tests/handlers/test_handler_nvtx.py +93 -0
  153. tests/handlers/test_handler_panoptic_quality.py +89 -0
  154. tests/handlers/test_handler_parameter_scheduler.py +136 -0
  155. tests/handlers/test_handler_post_processing.py +74 -0
  156. tests/handlers/test_handler_prob_map_producer.py +111 -0
  157. tests/handlers/test_handler_regression_metrics.py +160 -0
  158. tests/handlers/test_handler_regression_metrics_dist.py +245 -0
  159. tests/handlers/test_handler_rocauc.py +48 -0
  160. tests/handlers/test_handler_rocauc_dist.py +54 -0
  161. tests/handlers/test_handler_stats.py +281 -0
  162. tests/handlers/test_handler_surface_distance.py +113 -0
  163. tests/handlers/test_handler_tb_image.py +61 -0
  164. tests/handlers/test_handler_tb_stats.py +166 -0
  165. tests/handlers/test_handler_validation.py +59 -0
  166. tests/handlers/test_trt_compile.py +145 -0
  167. tests/handlers/test_write_metrics_reports.py +68 -0
  168. tests/inferers/__init__.py +10 -0
  169. tests/inferers/test_avg_merger.py +179 -0
  170. tests/inferers/test_controlnet_inferers.py +1310 -0
  171. tests/inferers/test_diffusion_inferer.py +236 -0
  172. tests/inferers/test_latent_diffusion_inferer.py +824 -0
  173. tests/inferers/test_patch_inferer.py +309 -0
  174. tests/inferers/test_saliency_inferer.py +55 -0
  175. tests/inferers/test_slice_inferer.py +57 -0
  176. tests/inferers/test_sliding_window_inference.py +377 -0
  177. tests/inferers/test_sliding_window_splitter.py +284 -0
  178. tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
  179. tests/inferers/test_zarr_avg_merger.py +326 -0
  180. tests/integration/__init__.py +10 -0
  181. tests/integration/test_auto3dseg_ensemble.py +211 -0
  182. tests/integration/test_auto3dseg_hpo.py +189 -0
  183. tests/integration/test_deepedit_interaction.py +122 -0
  184. tests/integration/test_downsample_block.py +50 -0
  185. tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
  186. tests/integration/test_integration_autorunner.py +201 -0
  187. tests/integration/test_integration_bundle_run.py +240 -0
  188. tests/integration/test_integration_classification_2d.py +282 -0
  189. tests/integration/test_integration_determinism.py +95 -0
  190. tests/integration/test_integration_fast_train.py +231 -0
  191. tests/integration/test_integration_gpu_customization.py +159 -0
  192. tests/integration/test_integration_lazy_samples.py +219 -0
  193. tests/integration/test_integration_nnunetv2_runner.py +96 -0
  194. tests/integration/test_integration_segmentation_3d.py +304 -0
  195. tests/integration/test_integration_sliding_window.py +100 -0
  196. tests/integration/test_integration_stn.py +133 -0
  197. tests/integration/test_integration_unet_2d.py +67 -0
  198. tests/integration/test_integration_workers.py +61 -0
  199. tests/integration/test_integration_workflows.py +365 -0
  200. tests/integration/test_integration_workflows_adversarial.py +173 -0
  201. tests/integration/test_integration_workflows_gan.py +158 -0
  202. tests/integration/test_loader_semaphore.py +48 -0
  203. tests/integration/test_mapping_filed.py +122 -0
  204. tests/integration/test_meta_affine.py +183 -0
  205. tests/integration/test_metatensor_integration.py +114 -0
  206. tests/integration/test_module_list.py +76 -0
  207. tests/integration/test_one_of.py +283 -0
  208. tests/integration/test_pad_collation.py +124 -0
  209. tests/integration/test_reg_loss_integration.py +107 -0
  210. tests/integration/test_retinanet_predict_utils.py +154 -0
  211. tests/integration/test_seg_loss_integration.py +159 -0
  212. tests/integration/test_spatial_combine_transforms.py +185 -0
  213. tests/integration/test_testtimeaugmentation.py +186 -0
  214. tests/integration/test_vis_gradbased.py +69 -0
  215. tests/integration/test_vista3d_utils.py +159 -0
  216. tests/losses/__init__.py +10 -0
  217. tests/losses/deform/__init__.py +10 -0
  218. tests/losses/deform/test_bending_energy.py +88 -0
  219. tests/losses/deform/test_diffusion_loss.py +117 -0
  220. tests/losses/image_dissimilarity/__init__.py +10 -0
  221. tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
  222. tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
  223. tests/losses/test_adversarial_loss.py +94 -0
  224. tests/losses/test_barlow_twins_loss.py +109 -0
  225. tests/losses/test_cldice_loss.py +51 -0
  226. tests/losses/test_contrastive_loss.py +86 -0
  227. tests/losses/test_dice_ce_loss.py +123 -0
  228. tests/losses/test_dice_focal_loss.py +124 -0
  229. tests/losses/test_dice_loss.py +227 -0
  230. tests/losses/test_ds_loss.py +189 -0
  231. tests/losses/test_focal_loss.py +379 -0
  232. tests/losses/test_generalized_dice_focal_loss.py +85 -0
  233. tests/losses/test_generalized_dice_loss.py +221 -0
  234. tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
  235. tests/losses/test_giou_loss.py +62 -0
  236. tests/losses/test_hausdorff_loss.py +264 -0
  237. tests/losses/test_masked_dice_loss.py +152 -0
  238. tests/losses/test_masked_loss.py +87 -0
  239. tests/losses/test_multi_scale.py +86 -0
  240. tests/losses/test_nacl_loss.py +167 -0
  241. tests/losses/test_perceptual_loss.py +122 -0
  242. tests/losses/test_spectral_loss.py +86 -0
  243. tests/losses/test_ssim_loss.py +59 -0
  244. tests/losses/test_sure_loss.py +72 -0
  245. tests/losses/test_tversky_loss.py +198 -0
  246. tests/losses/test_unified_focal_loss.py +66 -0
  247. tests/metrics/__init__.py +10 -0
  248. tests/metrics/test_compute_confusion_matrix.py +294 -0
  249. tests/metrics/test_compute_f_beta.py +80 -0
  250. tests/metrics/test_compute_fid_metric.py +40 -0
  251. tests/metrics/test_compute_froc.py +143 -0
  252. tests/metrics/test_compute_generalized_dice.py +240 -0
  253. tests/metrics/test_compute_meandice.py +306 -0
  254. tests/metrics/test_compute_meaniou.py +223 -0
  255. tests/metrics/test_compute_mmd_metric.py +56 -0
  256. tests/metrics/test_compute_multiscalessim_metric.py +83 -0
  257. tests/metrics/test_compute_panoptic_quality.py +113 -0
  258. tests/metrics/test_compute_regression_metrics.py +196 -0
  259. tests/metrics/test_compute_roc_auc.py +155 -0
  260. tests/metrics/test_compute_variance.py +147 -0
  261. tests/metrics/test_cumulative.py +63 -0
  262. tests/metrics/test_cumulative_average.py +74 -0
  263. tests/metrics/test_cumulative_average_dist.py +48 -0
  264. tests/metrics/test_hausdorff_distance.py +209 -0
  265. tests/metrics/test_label_quality_score.py +134 -0
  266. tests/metrics/test_loss_metric.py +57 -0
  267. tests/metrics/test_metrics_reloaded.py +96 -0
  268. tests/metrics/test_ssim_metric.py +78 -0
  269. tests/metrics/test_surface_dice.py +416 -0
  270. tests/metrics/test_surface_distance.py +186 -0
  271. tests/networks/__init__.py +10 -0
  272. tests/networks/blocks/__init__.py +10 -0
  273. tests/networks/blocks/dints_block/__init__.py +10 -0
  274. tests/networks/blocks/dints_block/test_acn_block.py +41 -0
  275. tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
  276. tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
  277. tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
  278. tests/networks/blocks/test_adn.py +86 -0
  279. tests/networks/blocks/test_convolutions.py +156 -0
  280. tests/networks/blocks/test_crf_cpu.py +513 -0
  281. tests/networks/blocks/test_crf_cuda.py +528 -0
  282. tests/networks/blocks/test_crossattention.py +185 -0
  283. tests/networks/blocks/test_denseblock.py +105 -0
  284. tests/networks/blocks/test_dynunet_block.py +116 -0
  285. tests/networks/blocks/test_fpn_block.py +88 -0
  286. tests/networks/blocks/test_localnet_block.py +121 -0
  287. tests/networks/blocks/test_mlp.py +78 -0
  288. tests/networks/blocks/test_patchembedding.py +212 -0
  289. tests/networks/blocks/test_regunet_block.py +103 -0
  290. tests/networks/blocks/test_se_block.py +85 -0
  291. tests/networks/blocks/test_se_blocks.py +78 -0
  292. tests/networks/blocks/test_segresnet_block.py +57 -0
  293. tests/networks/blocks/test_selfattention.py +232 -0
  294. tests/networks/blocks/test_simple_aspp.py +87 -0
  295. tests/networks/blocks/test_spatialattention.py +55 -0
  296. tests/networks/blocks/test_subpixel_upsample.py +87 -0
  297. tests/networks/blocks/test_text_encoding.py +49 -0
  298. tests/networks/blocks/test_transformerblock.py +90 -0
  299. tests/networks/blocks/test_unetr_block.py +158 -0
  300. tests/networks/blocks/test_upsample_block.py +134 -0
  301. tests/networks/blocks/warp/__init__.py +10 -0
  302. tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
  303. tests/networks/blocks/warp/test_warp.py +250 -0
  304. tests/networks/layers/__init__.py +10 -0
  305. tests/networks/layers/filtering/__init__.py +10 -0
  306. tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
  307. tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
  308. tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
  309. tests/networks/layers/filtering/test_phl_cpu.py +259 -0
  310. tests/networks/layers/filtering/test_phl_cuda.py +167 -0
  311. tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
  312. tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
  313. tests/networks/layers/test_affine_transform.py +385 -0
  314. tests/networks/layers/test_apply_filter.py +89 -0
  315. tests/networks/layers/test_channel_pad.py +51 -0
  316. tests/networks/layers/test_conjugate_gradient.py +56 -0
  317. tests/networks/layers/test_drop_path.py +46 -0
  318. tests/networks/layers/test_gaussian.py +317 -0
  319. tests/networks/layers/test_gaussian_filter.py +206 -0
  320. tests/networks/layers/test_get_layers.py +65 -0
  321. tests/networks/layers/test_gmm.py +314 -0
  322. tests/networks/layers/test_grid_pull.py +93 -0
  323. tests/networks/layers/test_hilbert_transform.py +131 -0
  324. tests/networks/layers/test_lltm.py +62 -0
  325. tests/networks/layers/test_median_filter.py +52 -0
  326. tests/networks/layers/test_polyval.py +55 -0
  327. tests/networks/layers/test_preset_filters.py +136 -0
  328. tests/networks/layers/test_savitzky_golay_filter.py +141 -0
  329. tests/networks/layers/test_separable_filter.py +87 -0
  330. tests/networks/layers/test_skip_connection.py +48 -0
  331. tests/networks/layers/test_vector_quantizer.py +89 -0
  332. tests/networks/layers/test_weight_init.py +50 -0
  333. tests/networks/nets/__init__.py +10 -0
  334. tests/networks/nets/dints/__init__.py +10 -0
  335. tests/networks/nets/dints/test_dints_cell.py +110 -0
  336. tests/networks/nets/dints/test_dints_mixop.py +84 -0
  337. tests/networks/nets/regunet/__init__.py +10 -0
  338. tests/networks/nets/regunet/test_localnet.py +86 -0
  339. tests/networks/nets/regunet/test_regunet.py +88 -0
  340. tests/networks/nets/test_ahnet.py +224 -0
  341. tests/networks/nets/test_attentionunet.py +88 -0
  342. tests/networks/nets/test_autoencoder.py +95 -0
  343. tests/networks/nets/test_autoencoderkl.py +337 -0
  344. tests/networks/nets/test_basic_unet.py +102 -0
  345. tests/networks/nets/test_basic_unetplusplus.py +109 -0
  346. tests/networks/nets/test_bundle_init_bundle.py +55 -0
  347. tests/networks/nets/test_cell_sam_wrapper.py +58 -0
  348. tests/networks/nets/test_controlnet.py +215 -0
  349. tests/networks/nets/test_daf3d.py +62 -0
  350. tests/networks/nets/test_densenet.py +121 -0
  351. tests/networks/nets/test_diffusion_model_unet.py +585 -0
  352. tests/networks/nets/test_dints_network.py +168 -0
  353. tests/networks/nets/test_discriminator.py +59 -0
  354. tests/networks/nets/test_dynunet.py +181 -0
  355. tests/networks/nets/test_efficientnet.py +400 -0
  356. tests/networks/nets/test_flexible_unet.py +341 -0
  357. tests/networks/nets/test_fullyconnectednet.py +69 -0
  358. tests/networks/nets/test_generator.py +59 -0
  359. tests/networks/nets/test_globalnet.py +103 -0
  360. tests/networks/nets/test_highresnet.py +67 -0
  361. tests/networks/nets/test_hovernet.py +218 -0
  362. tests/networks/nets/test_mednext.py +122 -0
  363. tests/networks/nets/test_milmodel.py +92 -0
  364. tests/networks/nets/test_net_adapter.py +68 -0
  365. tests/networks/nets/test_network_consistency.py +86 -0
  366. tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
  367. tests/networks/nets/test_quicknat.py +57 -0
  368. tests/networks/nets/test_resnet.py +340 -0
  369. tests/networks/nets/test_segresnet.py +120 -0
  370. tests/networks/nets/test_segresnet_ds.py +156 -0
  371. tests/networks/nets/test_senet.py +151 -0
  372. tests/networks/nets/test_spade_autoencoderkl.py +295 -0
  373. tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
  374. tests/networks/nets/test_spade_vaegan.py +140 -0
  375. tests/networks/nets/test_swin_unetr.py +139 -0
  376. tests/networks/nets/test_torchvision_fc_model.py +201 -0
  377. tests/networks/nets/test_transchex.py +84 -0
  378. tests/networks/nets/test_transformer.py +108 -0
  379. tests/networks/nets/test_unet.py +208 -0
  380. tests/networks/nets/test_unetr.py +137 -0
  381. tests/networks/nets/test_varautoencoder.py +127 -0
  382. tests/networks/nets/test_vista3d.py +84 -0
  383. tests/networks/nets/test_vit.py +139 -0
  384. tests/networks/nets/test_vitautoenc.py +112 -0
  385. tests/networks/nets/test_vnet.py +81 -0
  386. tests/networks/nets/test_voxelmorph.py +280 -0
  387. tests/networks/nets/test_vqvae.py +274 -0
  388. tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
  389. tests/networks/schedulers/__init__.py +10 -0
  390. tests/networks/schedulers/test_scheduler_ddim.py +83 -0
  391. tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
  392. tests/networks/schedulers/test_scheduler_pndm.py +108 -0
  393. tests/networks/test_bundle_onnx_export.py +71 -0
  394. tests/networks/test_convert_to_onnx.py +106 -0
  395. tests/networks/test_convert_to_torchscript.py +46 -0
  396. tests/networks/test_convert_to_trt.py +79 -0
  397. tests/networks/test_save_state.py +73 -0
  398. tests/networks/test_to_onehot.py +63 -0
  399. tests/networks/test_varnet.py +63 -0
  400. tests/networks/utils/__init__.py +10 -0
  401. tests/networks/utils/test_copy_model_state.py +187 -0
  402. tests/networks/utils/test_eval_mode.py +34 -0
  403. tests/networks/utils/test_freeze_layers.py +61 -0
  404. tests/networks/utils/test_replace_module.py +98 -0
  405. tests/networks/utils/test_train_mode.py +34 -0
  406. tests/optimizers/__init__.py +10 -0
  407. tests/optimizers/test_generate_param_groups.py +105 -0
  408. tests/optimizers/test_lr_finder.py +108 -0
  409. tests/optimizers/test_lr_scheduler.py +71 -0
  410. tests/optimizers/test_optim_novograd.py +100 -0
  411. tests/profile_subclass/__init__.py +10 -0
  412. tests/profile_subclass/cprofile_profiling.py +29 -0
  413. tests/profile_subclass/min_classes.py +30 -0
  414. tests/profile_subclass/profiling.py +73 -0
  415. tests/profile_subclass/pyspy_profiling.py +41 -0
  416. tests/transforms/__init__.py +10 -0
  417. tests/transforms/compose/__init__.py +10 -0
  418. tests/transforms/compose/test_compose.py +758 -0
  419. tests/transforms/compose/test_some_of.py +258 -0
  420. tests/transforms/croppad/__init__.py +10 -0
  421. tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
  422. tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
  423. tests/transforms/functional/__init__.py +10 -0
  424. tests/transforms/functional/test_apply.py +75 -0
  425. tests/transforms/functional/test_resample.py +50 -0
  426. tests/transforms/intensity/__init__.py +10 -0
  427. tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
  428. tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
  429. tests/transforms/intensity/test_foreground_mask.py +98 -0
  430. tests/transforms/intensity/test_foreground_maskd.py +106 -0
  431. tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
  432. tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
  433. tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
  434. tests/transforms/inverse/__init__.py +10 -0
  435. tests/transforms/inverse/test_inverse_array.py +76 -0
  436. tests/transforms/inverse/test_traceable_transform.py +59 -0
  437. tests/transforms/post/__init__.py +10 -0
  438. tests/transforms/post/test_label_filterd.py +78 -0
  439. tests/transforms/post/test_probnms.py +72 -0
  440. tests/transforms/post/test_probnmsd.py +79 -0
  441. tests/transforms/post/test_remove_small_objects.py +102 -0
  442. tests/transforms/spatial/__init__.py +10 -0
  443. tests/transforms/spatial/test_convert_box_points.py +119 -0
  444. tests/transforms/spatial/test_grid_patch.py +134 -0
  445. tests/transforms/spatial/test_grid_patchd.py +102 -0
  446. tests/transforms/spatial/test_rand_grid_patch.py +150 -0
  447. tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
  448. tests/transforms/spatial/test_spatial_resampled.py +124 -0
  449. tests/transforms/test_activations.py +120 -0
  450. tests/transforms/test_activationsd.py +64 -0
  451. tests/transforms/test_adaptors.py +160 -0
  452. tests/transforms/test_add_coordinate_channels.py +53 -0
  453. tests/transforms/test_add_coordinate_channelsd.py +67 -0
  454. tests/transforms/test_add_extreme_points_channel.py +80 -0
  455. tests/transforms/test_add_extreme_points_channeld.py +77 -0
  456. tests/transforms/test_adjust_contrast.py +70 -0
  457. tests/transforms/test_adjust_contrastd.py +64 -0
  458. tests/transforms/test_affine.py +245 -0
  459. tests/transforms/test_affine_grid.py +152 -0
  460. tests/transforms/test_affined.py +190 -0
  461. tests/transforms/test_as_channel_last.py +38 -0
  462. tests/transforms/test_as_channel_lastd.py +44 -0
  463. tests/transforms/test_as_discrete.py +81 -0
  464. tests/transforms/test_as_discreted.py +82 -0
  465. tests/transforms/test_border_pad.py +49 -0
  466. tests/transforms/test_border_padd.py +45 -0
  467. tests/transforms/test_bounding_rect.py +54 -0
  468. tests/transforms/test_bounding_rectd.py +53 -0
  469. tests/transforms/test_cast_to_type.py +63 -0
  470. tests/transforms/test_cast_to_typed.py +74 -0
  471. tests/transforms/test_center_scale_crop.py +55 -0
  472. tests/transforms/test_center_scale_cropd.py +56 -0
  473. tests/transforms/test_center_spatial_crop.py +56 -0
  474. tests/transforms/test_center_spatial_cropd.py +63 -0
  475. tests/transforms/test_classes_to_indices.py +93 -0
  476. tests/transforms/test_classes_to_indicesd.py +110 -0
  477. tests/transforms/test_clip_intensity_percentiles.py +196 -0
  478. tests/transforms/test_clip_intensity_percentilesd.py +193 -0
  479. tests/transforms/test_compose_get_number_conversions.py +127 -0
  480. tests/transforms/test_concat_itemsd.py +82 -0
  481. tests/transforms/test_convert_to_multi_channel.py +59 -0
  482. tests/transforms/test_convert_to_multi_channeld.py +37 -0
  483. tests/transforms/test_copy_itemsd.py +86 -0
  484. tests/transforms/test_create_grid_and_affine.py +274 -0
  485. tests/transforms/test_crop_foreground.py +164 -0
  486. tests/transforms/test_crop_foregroundd.py +205 -0
  487. tests/transforms/test_cucim_dict_transform.py +142 -0
  488. tests/transforms/test_cucim_transform.py +141 -0
  489. tests/transforms/test_data_stats.py +221 -0
  490. tests/transforms/test_data_statsd.py +249 -0
  491. tests/transforms/test_delete_itemsd.py +58 -0
  492. tests/transforms/test_detect_envelope.py +159 -0
  493. tests/transforms/test_distance_transform_edt.py +202 -0
  494. tests/transforms/test_divisible_pad.py +49 -0
  495. tests/transforms/test_divisible_padd.py +42 -0
  496. tests/transforms/test_ensure_channel_first.py +113 -0
  497. tests/transforms/test_ensure_channel_firstd.py +85 -0
  498. tests/transforms/test_ensure_type.py +94 -0
  499. tests/transforms/test_ensure_typed.py +110 -0
  500. tests/transforms/test_fg_bg_to_indices.py +83 -0
  501. tests/transforms/test_fg_bg_to_indicesd.py +78 -0
  502. tests/transforms/test_fill_holes.py +207 -0
  503. tests/transforms/test_fill_holesd.py +209 -0
  504. tests/transforms/test_flatten_sub_keysd.py +64 -0
  505. tests/transforms/test_flip.py +83 -0
  506. tests/transforms/test_flipd.py +90 -0
  507. tests/transforms/test_fourier.py +70 -0
  508. tests/transforms/test_gaussian_sharpen.py +92 -0
  509. tests/transforms/test_gaussian_sharpend.py +92 -0
  510. tests/transforms/test_gaussian_smooth.py +96 -0
  511. tests/transforms/test_gaussian_smoothd.py +96 -0
  512. tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
  513. tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
  514. tests/transforms/test_generate_spatial_bounding_box.py +114 -0
  515. tests/transforms/test_get_extreme_points.py +57 -0
  516. tests/transforms/test_gibbs_noise.py +75 -0
  517. tests/transforms/test_gibbs_noised.py +88 -0
  518. tests/transforms/test_grid_distortion.py +113 -0
  519. tests/transforms/test_grid_distortiond.py +87 -0
  520. tests/transforms/test_grid_split.py +88 -0
  521. tests/transforms/test_grid_splitd.py +96 -0
  522. tests/transforms/test_histogram_normalize.py +59 -0
  523. tests/transforms/test_histogram_normalized.py +59 -0
  524. tests/transforms/test_image_filter.py +259 -0
  525. tests/transforms/test_intensity_stats.py +73 -0
  526. tests/transforms/test_intensity_statsd.py +90 -0
  527. tests/transforms/test_inverse.py +521 -0
  528. tests/transforms/test_inverse_collation.py +147 -0
  529. tests/transforms/test_invert.py +105 -0
  530. tests/transforms/test_invertd.py +142 -0
  531. tests/transforms/test_k_space_spike_noise.py +81 -0
  532. tests/transforms/test_k_space_spike_noised.py +98 -0
  533. tests/transforms/test_keep_largest_connected_component.py +419 -0
  534. tests/transforms/test_keep_largest_connected_componentd.py +348 -0
  535. tests/transforms/test_label_filter.py +78 -0
  536. tests/transforms/test_label_to_contour.py +179 -0
  537. tests/transforms/test_label_to_contourd.py +182 -0
  538. tests/transforms/test_label_to_mask.py +69 -0
  539. tests/transforms/test_label_to_maskd.py +70 -0
  540. tests/transforms/test_load_image.py +502 -0
  541. tests/transforms/test_load_imaged.py +198 -0
  542. tests/transforms/test_load_spacing_orientation.py +149 -0
  543. tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
  544. tests/transforms/test_map_binary_to_indices.py +75 -0
  545. tests/transforms/test_map_classes_to_indices.py +135 -0
  546. tests/transforms/test_map_label_value.py +89 -0
  547. tests/transforms/test_map_label_valued.py +85 -0
  548. tests/transforms/test_map_transform.py +45 -0
  549. tests/transforms/test_mask_intensity.py +74 -0
  550. tests/transforms/test_mask_intensityd.py +68 -0
  551. tests/transforms/test_mean_ensemble.py +77 -0
  552. tests/transforms/test_mean_ensembled.py +91 -0
  553. tests/transforms/test_median_smooth.py +41 -0
  554. tests/transforms/test_median_smoothd.py +65 -0
  555. tests/transforms/test_morphological_ops.py +101 -0
  556. tests/transforms/test_nifti_endianness.py +107 -0
  557. tests/transforms/test_normalize_intensity.py +143 -0
  558. tests/transforms/test_normalize_intensityd.py +81 -0
  559. tests/transforms/test_nvtx_decorator.py +289 -0
  560. tests/transforms/test_nvtx_transform.py +143 -0
  561. tests/transforms/test_orientation.py +247 -0
  562. tests/transforms/test_orientationd.py +112 -0
  563. tests/transforms/test_rand_adjust_contrast.py +45 -0
  564. tests/transforms/test_rand_adjust_contrastd.py +44 -0
  565. tests/transforms/test_rand_affine.py +201 -0
  566. tests/transforms/test_rand_affine_grid.py +212 -0
  567. tests/transforms/test_rand_affined.py +281 -0
  568. tests/transforms/test_rand_axis_flip.py +50 -0
  569. tests/transforms/test_rand_axis_flipd.py +50 -0
  570. tests/transforms/test_rand_bias_field.py +69 -0
  571. tests/transforms/test_rand_bias_fieldd.py +65 -0
  572. tests/transforms/test_rand_coarse_dropout.py +110 -0
  573. tests/transforms/test_rand_coarse_dropoutd.py +107 -0
  574. tests/transforms/test_rand_coarse_shuffle.py +65 -0
  575. tests/transforms/test_rand_coarse_shuffled.py +59 -0
  576. tests/transforms/test_rand_crop_by_label_classes.py +170 -0
  577. tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
  578. tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
  579. tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
  580. tests/transforms/test_rand_cucim_dict_transform.py +162 -0
  581. tests/transforms/test_rand_cucim_transform.py +162 -0
  582. tests/transforms/test_rand_deform_grid.py +138 -0
  583. tests/transforms/test_rand_elastic_2d.py +127 -0
  584. tests/transforms/test_rand_elastic_3d.py +104 -0
  585. tests/transforms/test_rand_elasticd_2d.py +177 -0
  586. tests/transforms/test_rand_elasticd_3d.py +156 -0
  587. tests/transforms/test_rand_flip.py +60 -0
  588. tests/transforms/test_rand_flipd.py +55 -0
  589. tests/transforms/test_rand_gaussian_noise.py +48 -0
  590. tests/transforms/test_rand_gaussian_noised.py +54 -0
  591. tests/transforms/test_rand_gaussian_sharpen.py +140 -0
  592. tests/transforms/test_rand_gaussian_sharpend.py +143 -0
  593. tests/transforms/test_rand_gaussian_smooth.py +98 -0
  594. tests/transforms/test_rand_gaussian_smoothd.py +98 -0
  595. tests/transforms/test_rand_gibbs_noise.py +103 -0
  596. tests/transforms/test_rand_gibbs_noised.py +117 -0
  597. tests/transforms/test_rand_grid_distortion.py +99 -0
  598. tests/transforms/test_rand_grid_distortiond.py +90 -0
  599. tests/transforms/test_rand_histogram_shift.py +92 -0
  600. tests/transforms/test_rand_k_space_spike_noise.py +92 -0
  601. tests/transforms/test_rand_k_space_spike_noised.py +76 -0
  602. tests/transforms/test_rand_rician_noise.py +52 -0
  603. tests/transforms/test_rand_rician_noised.py +52 -0
  604. tests/transforms/test_rand_rotate.py +166 -0
  605. tests/transforms/test_rand_rotate90.py +100 -0
  606. tests/transforms/test_rand_rotate90d.py +112 -0
  607. tests/transforms/test_rand_rotated.py +187 -0
  608. tests/transforms/test_rand_scale_crop.py +78 -0
  609. tests/transforms/test_rand_scale_cropd.py +98 -0
  610. tests/transforms/test_rand_scale_intensity.py +54 -0
  611. tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
  612. tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
  613. tests/transforms/test_rand_scale_intensityd.py +53 -0
  614. tests/transforms/test_rand_shift_intensity.py +52 -0
  615. tests/transforms/test_rand_shift_intensityd.py +67 -0
  616. tests/transforms/test_rand_simulate_low_resolution.py +83 -0
  617. tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
  618. tests/transforms/test_rand_spatial_crop.py +107 -0
  619. tests/transforms/test_rand_spatial_crop_samples.py +128 -0
  620. tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
  621. tests/transforms/test_rand_spatial_cropd.py +112 -0
  622. tests/transforms/test_rand_std_shift_intensity.py +43 -0
  623. tests/transforms/test_rand_std_shift_intensityd.py +38 -0
  624. tests/transforms/test_rand_zoom.py +105 -0
  625. tests/transforms/test_rand_zoomd.py +108 -0
  626. tests/transforms/test_randidentity.py +49 -0
  627. tests/transforms/test_random_order.py +144 -0
  628. tests/transforms/test_randtorchvisiond.py +65 -0
  629. tests/transforms/test_regularization.py +139 -0
  630. tests/transforms/test_remove_repeated_channel.py +34 -0
  631. tests/transforms/test_remove_repeated_channeld.py +44 -0
  632. tests/transforms/test_repeat_channel.py +34 -0
  633. tests/transforms/test_repeat_channeld.py +41 -0
  634. tests/transforms/test_resample_backends.py +65 -0
  635. tests/transforms/test_resample_to_match.py +110 -0
  636. tests/transforms/test_resample_to_matchd.py +93 -0
  637. tests/transforms/test_resampler.py +165 -0
  638. tests/transforms/test_resize.py +140 -0
  639. tests/transforms/test_resize_with_pad_or_crop.py +91 -0
  640. tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
  641. tests/transforms/test_resized.py +163 -0
  642. tests/transforms/test_rotate.py +160 -0
  643. tests/transforms/test_rotate90.py +212 -0
  644. tests/transforms/test_rotate90d.py +106 -0
  645. tests/transforms/test_rotated.py +179 -0
  646. tests/transforms/test_save_classificationd.py +109 -0
  647. tests/transforms/test_save_image.py +80 -0
  648. tests/transforms/test_save_imaged.py +130 -0
  649. tests/transforms/test_savitzky_golay_smooth.py +73 -0
  650. tests/transforms/test_savitzky_golay_smoothd.py +73 -0
  651. tests/transforms/test_scale_intensity.py +76 -0
  652. tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
  653. tests/transforms/test_scale_intensity_range.py +41 -0
  654. tests/transforms/test_scale_intensity_ranged.py +40 -0
  655. tests/transforms/test_scale_intensityd.py +57 -0
  656. tests/transforms/test_select_itemsd.py +41 -0
  657. tests/transforms/test_shift_intensity.py +31 -0
  658. tests/transforms/test_shift_intensityd.py +44 -0
  659. tests/transforms/test_signal_continuouswavelet.py +44 -0
  660. tests/transforms/test_signal_fillempty.py +52 -0
  661. tests/transforms/test_signal_fillemptyd.py +60 -0
  662. tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
  663. tests/transforms/test_signal_rand_add_sine.py +52 -0
  664. tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
  665. tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
  666. tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
  667. tests/transforms/test_signal_rand_drop.py +50 -0
  668. tests/transforms/test_signal_rand_scale.py +52 -0
  669. tests/transforms/test_signal_rand_shift.py +55 -0
  670. tests/transforms/test_signal_remove_frequency.py +71 -0
  671. tests/transforms/test_smooth_field.py +177 -0
  672. tests/transforms/test_sobel_gradient.py +189 -0
  673. tests/transforms/test_sobel_gradientd.py +212 -0
  674. tests/transforms/test_spacing.py +381 -0
  675. tests/transforms/test_spacingd.py +178 -0
  676. tests/transforms/test_spatial_crop.py +82 -0
  677. tests/transforms/test_spatial_cropd.py +74 -0
  678. tests/transforms/test_spatial_pad.py +57 -0
  679. tests/transforms/test_spatial_padd.py +43 -0
  680. tests/transforms/test_spatial_resample.py +235 -0
  681. tests/transforms/test_squeezedim.py +62 -0
  682. tests/transforms/test_squeezedimd.py +98 -0
  683. tests/transforms/test_std_shift_intensity.py +76 -0
  684. tests/transforms/test_std_shift_intensityd.py +74 -0
  685. tests/transforms/test_threshold_intensity.py +38 -0
  686. tests/transforms/test_threshold_intensityd.py +58 -0
  687. tests/transforms/test_to_contiguous.py +47 -0
  688. tests/transforms/test_to_cupy.py +112 -0
  689. tests/transforms/test_to_cupyd.py +76 -0
  690. tests/transforms/test_to_device.py +42 -0
  691. tests/transforms/test_to_deviced.py +37 -0
  692. tests/transforms/test_to_numpy.py +85 -0
  693. tests/transforms/test_to_numpyd.py +68 -0
  694. tests/transforms/test_to_pil.py +52 -0
  695. tests/transforms/test_to_pild.py +50 -0
  696. tests/transforms/test_to_tensor.py +60 -0
  697. tests/transforms/test_to_tensord.py +71 -0
  698. tests/transforms/test_torchvision.py +66 -0
  699. tests/transforms/test_torchvisiond.py +63 -0
  700. tests/transforms/test_transform.py +62 -0
  701. tests/transforms/test_transpose.py +41 -0
  702. tests/transforms/test_transposed.py +52 -0
  703. tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
  704. tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
  705. tests/transforms/test_vote_ensemble.py +84 -0
  706. tests/transforms/test_vote_ensembled.py +107 -0
  707. tests/transforms/test_with_allow_missing_keys.py +76 -0
  708. tests/transforms/test_zoom.py +120 -0
  709. tests/transforms/test_zoomd.py +94 -0
  710. tests/transforms/transform/__init__.py +10 -0
  711. tests/transforms/transform/test_randomizable.py +52 -0
  712. tests/transforms/transform/test_randomizable_transform_type.py +37 -0
  713. tests/transforms/utility/__init__.py +10 -0
  714. tests/transforms/utility/test_apply_transform_to_points.py +81 -0
  715. tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
  716. tests/transforms/utility/test_identity.py +29 -0
  717. tests/transforms/utility/test_identityd.py +30 -0
  718. tests/transforms/utility/test_lambda.py +71 -0
  719. tests/transforms/utility/test_lambdad.py +83 -0
  720. tests/transforms/utility/test_rand_lambda.py +87 -0
  721. tests/transforms/utility/test_rand_lambdad.py +77 -0
  722. tests/transforms/utility/test_simulatedelay.py +36 -0
  723. tests/transforms/utility/test_simulatedelayd.py +36 -0
  724. tests/transforms/utility/test_splitdim.py +52 -0
  725. tests/transforms/utility/test_splitdimd.py +96 -0
  726. tests/transforms/utils/__init__.py +10 -0
  727. tests/transforms/utils/test_correct_crop_centers.py +36 -0
  728. tests/transforms/utils/test_get_unique_labels.py +45 -0
  729. tests/transforms/utils/test_print_transform_backends.py +29 -0
  730. tests/transforms/utils/test_soft_clip.py +125 -0
  731. tests/utils/__init__.py +10 -0
  732. tests/utils/enums/__init__.py +10 -0
  733. tests/utils/enums/test_hovernet_loss.py +190 -0
  734. tests/utils/enums/test_ordering.py +289 -0
  735. tests/utils/enums/test_wsireader.py +663 -0
  736. tests/utils/misc/__init__.py +10 -0
  737. tests/utils/misc/test_ensure_tuple.py +53 -0
  738. tests/utils/misc/test_monai_env_vars.py +44 -0
  739. tests/utils/misc/test_monai_utils_misc.py +103 -0
  740. tests/utils/misc/test_str2bool.py +34 -0
  741. tests/utils/misc/test_str2list.py +33 -0
  742. tests/utils/test_alias.py +44 -0
  743. tests/utils/test_component_store.py +73 -0
  744. tests/utils/test_deprecated.py +455 -0
  745. tests/utils/test_enum_bound_interp.py +75 -0
  746. tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
  747. tests/utils/test_get_package_version.py +34 -0
  748. tests/utils/test_handler_logfile.py +84 -0
  749. tests/utils/test_handler_metric_logger.py +62 -0
  750. tests/utils/test_list_to_dict.py +43 -0
  751. tests/utils/test_look_up_option.py +87 -0
  752. tests/utils/test_optional_import.py +80 -0
  753. tests/utils/test_pad_mode.py +39 -0
  754. tests/utils/test_profiling.py +208 -0
  755. tests/utils/test_rankfilter_dist.py +77 -0
  756. tests/utils/test_require_pkg.py +83 -0
  757. tests/utils/test_sample_slices.py +43 -0
  758. tests/utils/test_set_determinism.py +74 -0
  759. tests/utils/test_squeeze_unsqueeze.py +71 -0
  760. tests/utils/test_state_cacher.py +67 -0
  761. tests/utils/test_torchscript_utils.py +113 -0
  762. tests/utils/test_version.py +91 -0
  763. tests/utils/test_version_after.py +65 -0
  764. tests/utils/type_conversion/__init__.py +10 -0
  765. tests/utils/type_conversion/test_convert_data_type.py +152 -0
  766. tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
  767. tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
  768. tests/visualize/__init__.py +10 -0
  769. tests/visualize/test_img2tensorboard.py +46 -0
  770. tests/visualize/test_occlusion_sensitivity.py +128 -0
  771. tests/visualize/test_plot_2d_or_3d_image.py +74 -0
  772. tests/visualize/test_vis_cam.py +98 -0
  773. tests/visualize/test_vis_gradcam.py +211 -0
  774. tests/visualize/utils/__init__.py +10 -0
  775. tests/visualize/utils/test_blend_images.py +63 -0
  776. tests/visualize/utils/test_matshow3d.py +133 -0
  777. monai_weekly-1.5.dev2505.dist-info/RECORD +0 -427
  778. {monai_weekly-1.5.dev2505.dist-info → monai_weekly-1.5.dev2507.dist-info}/LICENSE +0 -0
  779. {monai_weekly-1.5.dev2505.dist-info → monai_weekly-1.5.dev2507.dist-info}/WHEEL +0 -0
@@ -0,0 +1,111 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import os
15
+ import unittest
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ import torch
20
+ from ignite.engine import Engine
21
+ from parameterized import parameterized
22
+
23
+ from monai.data import DataLoader, Dataset, MetaTensor
24
+ from monai.engines import Evaluator
25
+ from monai.handlers import ProbMapProducer, ValidationHandler
26
+ from monai.utils.enums import ProbMapKeys
27
+
28
+ TEST_CASE_0 = ["temp_image_inference_output_1", 1]
29
+ TEST_CASE_1 = ["temp_image_inference_output_2", 9]
30
+ TEST_CASE_2 = ["temp_image_inference_output_3", 100]
31
+
32
+
33
+ class TestDataset(Dataset):
34
+ __test__ = False # indicate to pytest that this class is not intended for collection
35
+
36
+ def __init__(self, name, size):
37
+ super().__init__(
38
+ data=[
39
+ {
40
+ "image": name,
41
+ ProbMapKeys.COUNT.value: size,
42
+ ProbMapKeys.SIZE.value: np.array([size + 1, size + 1]),
43
+ ProbMapKeys.LOCATION.value: np.array([i, i + 1]),
44
+ }
45
+ for i in range(size)
46
+ ]
47
+ )
48
+ self.image_data = [
49
+ {
50
+ ProbMapKeys.NAME.value: name,
51
+ ProbMapKeys.COUNT.value: size,
52
+ ProbMapKeys.SIZE.value: np.array([size + 1, size + 1]),
53
+ }
54
+ ]
55
+
56
+ def __getitem__(self, index):
57
+ image = np.ones((3, 2, 2)) * index
58
+ metadata = {
59
+ ProbMapKeys.COUNT.value: self.data[index][ProbMapKeys.COUNT.value],
60
+ ProbMapKeys.NAME.value: self.data[index]["image"],
61
+ ProbMapKeys.SIZE.value: self.data[index][ProbMapKeys.SIZE.value],
62
+ ProbMapKeys.LOCATION.value: self.data[index][ProbMapKeys.LOCATION.value],
63
+ }
64
+
65
+ return {"image": MetaTensor(x=image, meta=metadata), "pred": index + 1}
66
+
67
+
68
+ class TestEvaluator(Evaluator):
69
+ __test__ = False # indicate to pytest that this class is not intended for collection
70
+
71
+ def _iteration(self, engine, batchdata):
72
+ return batchdata
73
+
74
+
75
+ class TestHandlerProbMapGenerator(unittest.TestCase):
76
+ @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])
77
+ def test_prob_map_generator(self, name, size):
78
+ # set up dataset
79
+ dataset = TestDataset(name, size)
80
+ batch_size = 2
81
+ data_loader = DataLoader(dataset, batch_size=batch_size)
82
+
83
+ # set up engine
84
+ def inference(engine, batch):
85
+ pass
86
+
87
+ engine = Engine(inference)
88
+
89
+ tests_path = Path(__file__).parents[1].as_posix()
90
+ # add ProbMapProducer() to evaluator
91
+ output_dir = os.path.join(tests_path, "testing_data")
92
+ prob_map_gen = ProbMapProducer(output_dir=output_dir)
93
+
94
+ evaluator = TestEvaluator(
95
+ torch.device("cpu:0"), data_loader, np.ceil(size / batch_size), val_handlers=[prob_map_gen]
96
+ )
97
+
98
+ # set up validation handler
99
+ validation = ValidationHandler(interval=1, validator=None)
100
+ validation.attach(engine)
101
+ validation.set_validator(validator=evaluator)
102
+
103
+ engine.run(data_loader)
104
+
105
+ prob_map = np.load(os.path.join(output_dir, name + ".npy"))
106
+ self.assertListEqual(np.vstack(prob_map.nonzero()).T.tolist(), [[i, i + 1] for i in range(size)])
107
+ self.assertListEqual(prob_map[prob_map.nonzero()].tolist(), [i + 1 for i in range(size)])
108
+
109
+
110
+ if __name__ == "__main__":
111
+ unittest.main()
@@ -0,0 +1,160 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+ from functools import partial
16
+
17
+ import numpy as np
18
+ import torch
19
+ from ignite.engine import Engine
20
+
21
+ from monai.handlers import MeanAbsoluteError, MeanSquaredError, PeakSignalToNoiseRatio, RootMeanSquaredError
22
+ from monai.utils import set_determinism
23
+
24
+
25
+ # define a numpy flatten function that only preserves batch dimension
26
+ def flatten(data):
27
+ return np.reshape(data, [data.shape[0], -1])
28
+
29
+
30
+ # define metrics computation truth functions to check our monai metrics against
31
+ def msemetric_np(y_pred, y):
32
+ return np.mean((flatten(y_pred) - flatten(y)) ** 2)
33
+
34
+
35
+ def maemetric_np(y_pred, y):
36
+ return np.mean(np.abs(flatten(y_pred) - flatten(y)))
37
+
38
+
39
+ def rmsemetric_np(y_pred, y):
40
+ return np.mean(np.sqrt(np.mean((flatten(y_pred) - flatten(y)) ** 2, axis=1)))
41
+
42
+
43
+ def psnrmetric_np(max_val, y_pred, y):
44
+ mse = np.mean((flatten(y_pred) - flatten(y)) ** 2, axis=1)
45
+ return np.mean(20 * np.log10(max_val) - 10 * np.log10(mse))
46
+
47
+
48
+ class TestHandlerRegressionMetrics(unittest.TestCase):
49
+
50
+ def test_compute(self):
51
+ set_determinism(seed=123)
52
+ device = "cuda" if torch.cuda.is_available() else "cpu"
53
+
54
+ # regression metrics to check + truth metric function in numpy
55
+ metrics = [
56
+ MeanSquaredError,
57
+ MeanAbsoluteError,
58
+ RootMeanSquaredError,
59
+ partial(PeakSignalToNoiseRatio, max_val=1.0),
60
+ ]
61
+ metrics_np = [msemetric_np, maemetric_np, rmsemetric_np, partial(psnrmetric_np, max_val=1.0)]
62
+
63
+ # define variations in batch/base_dims/spatial_dims
64
+ batch_dims = [1, 2, 4, 16]
65
+ base_dims = [16, 32, 64]
66
+ spatial_dims = [2, 3, 4]
67
+
68
+ # iterate over all variations and check shapes for different reduction functions
69
+ for mt_fn, mt_fn_np in zip(metrics, metrics_np):
70
+ for batch in batch_dims:
71
+ for spatial in spatial_dims:
72
+ for base in base_dims:
73
+ mt_fn_obj = mt_fn(**{"save_details": False})
74
+
75
+ # create random tensor
76
+ in_tensor_a1 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)
77
+ in_tensor_b1 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)
78
+ mt_fn_obj.update([in_tensor_a1, in_tensor_b1])
79
+ out_tensor_np1 = mt_fn_np(y_pred=in_tensor_a1.cpu().numpy(), y=in_tensor_b1.cpu().numpy())
80
+
81
+ in_tensor_a2 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)
82
+ in_tensor_b2 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)
83
+ mt_fn_obj.update([in_tensor_a2, in_tensor_b2])
84
+ out_tensor_np2 = mt_fn_np(y_pred=in_tensor_a2.cpu().numpy(), y=in_tensor_b2.cpu().numpy())
85
+
86
+ out_tensor = mt_fn_obj.compute()
87
+ out_tensor_np = (out_tensor_np1 + out_tensor_np2) / 2.0
88
+
89
+ np.testing.assert_allclose(out_tensor, out_tensor_np, atol=1e-4)
90
+
91
+ def test_compute_engine(self):
92
+ set_determinism(seed=123)
93
+ device = "cuda" if torch.cuda.is_available() else "cpu"
94
+
95
+ # regression metrics to check + truth metric function in numpy
96
+ metrics_names = ["MSE", "MAE", "RMSE", "PSNR"]
97
+ metrics = [
98
+ MeanSquaredError,
99
+ MeanAbsoluteError,
100
+ RootMeanSquaredError,
101
+ partial(PeakSignalToNoiseRatio, max_val=1.0),
102
+ ]
103
+ metrics_np = [msemetric_np, maemetric_np, rmsemetric_np, partial(psnrmetric_np, max_val=1.0)]
104
+
105
+ def _val_func(engine, batch):
106
+ pass
107
+
108
+ # define variations in batch/base_dims/spatial_dims
109
+ batch_dims = [1, 2, 4, 16]
110
+ base_dims = [16, 32, 64]
111
+ spatial_dims = [2, 3, 4]
112
+
113
+ # iterate over all variations and check shapes for different reduction functions
114
+ for mt_fn_name, mt_fn, mt_fn_np in zip(metrics_names, metrics, metrics_np):
115
+ for batch in batch_dims:
116
+ for spatial in spatial_dims:
117
+ for base in base_dims:
118
+ mt_fn_obj = mt_fn() # 'save_details' == True
119
+ engine = Engine(_val_func)
120
+ mt_fn_obj.attach(engine, mt_fn_name)
121
+
122
+ # create random tensor
123
+ in_tensor_a1 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)
124
+ in_tensor_b1 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)
125
+ mt_fn_obj.update([in_tensor_a1, in_tensor_b1])
126
+ out_tensor_np1 = mt_fn_np(y_pred=in_tensor_a1.cpu().numpy(), y=in_tensor_b1.cpu().numpy())
127
+
128
+ in_tensor_a2 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)
129
+ in_tensor_b2 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)
130
+ mt_fn_obj.update([in_tensor_a2, in_tensor_b2])
131
+ out_tensor_np2 = mt_fn_np(y_pred=in_tensor_a2.cpu().numpy(), y=in_tensor_b2.cpu().numpy())
132
+
133
+ out_tensor = mt_fn_obj.compute()
134
+ out_tensor_np = (out_tensor_np1 + out_tensor_np2) / 2.0
135
+
136
+ np.testing.assert_allclose(out_tensor, out_tensor_np, atol=1e-4)
137
+
138
+ def test_ill_shape(self):
139
+ set_determinism(seed=123)
140
+ device = "cuda" if torch.cuda.is_available() else "cpu"
141
+
142
+ # regression metrics to check
143
+ metrics = [
144
+ MeanSquaredError,
145
+ MeanAbsoluteError,
146
+ RootMeanSquaredError,
147
+ partial(PeakSignalToNoiseRatio, max_val=1.0),
148
+ ]
149
+ basedim = 10
150
+
151
+ # different shape for pred/target
152
+ with self.assertRaises((AssertionError, ValueError)):
153
+ in_tensor_a = torch.rand((basedim,)).to(device)
154
+ in_tensor_b = torch.rand((basedim, basedim)).to(device)
155
+ for mt_fn in metrics:
156
+ mt_fn().update([in_tensor_a, in_tensor_b])
157
+
158
+
159
+ if __name__ == "__main__":
160
+ unittest.main()
@@ -0,0 +1,245 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.distributed as dist
19
+ from ignite.engine import Engine
20
+
21
+ from monai.handlers import MeanAbsoluteError, MeanSquaredError, PeakSignalToNoiseRatio, RootMeanSquaredError
22
+ from monai.utils import set_determinism
23
+ from tests.test_utils import DistCall, DistTestCase
24
+
25
+
26
+ # define a numpy flatten function that only preserves batch dimension
27
+ def flatten(data):
28
+ return np.reshape(data, [data.shape[0], -1])
29
+
30
+
31
+ # define metrics computation truth functions to check our monai metrics against
32
+ def msemetric_np(y_pred, y):
33
+ return np.mean((flatten(y_pred) - flatten(y)) ** 2)
34
+
35
+
36
+ def maemetric_np(y_pred, y):
37
+ return np.mean(np.abs(flatten(y_pred) - flatten(y)))
38
+
39
+
40
+ def rmsemetric_np(y_pred, y):
41
+ return np.mean(np.sqrt(np.mean((flatten(y_pred) - flatten(y)) ** 2, axis=1)))
42
+
43
+
44
+ def psnrmetric_np(max_val, y_pred, y):
45
+ mse = np.mean((flatten(y_pred) - flatten(y)) ** 2, axis=1)
46
+ return np.mean(20 * np.log10(max_val) - 10 * np.log10(mse))
47
+
48
+
49
+ # define tensor size as (BATCH_SIZE,) + (BASE_DIM_SIZE,) * (SPATIAL_DIM - 1), matching the torch.rand calls below
50
+ # With the values below each tensor has shape (4, 32, 32), i.e. 4 * 32 * 32 = 4096 elements
51
+ # Each rank moves two of the four tensors created per test (prediction and target) to its own device
52
+ # The memory footprint grows with BASE_DIM_SIZE ** (SPATIAL_DIM - 1), so lowering BASE_DIM_SIZE
53
+ # (e.g. BASE_DIM_SIZE=16) reduces the required memory further
54
+ BATCH_SIZE = 4
55
+ BASE_DIM_SIZE = 32
56
+ SPATIAL_DIM = 3
57
+
58
+
59
+ class DistributedMeanSquaredError(DistTestCase):
60
+ @DistCall(nnodes=1, nproc_per_node=2)
61
+ def test_compute(self):
62
+ set_determinism(123)
63
+ self._compute()
64
+
65
+ def _compute(self):
66
+ device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
67
+ metric = MeanSquaredError()
68
+
69
+ def _val_func(engine, batch):
70
+ pass
71
+
72
+ engine = Engine(_val_func)
73
+ metric.attach(engine, "MSE")
74
+
75
+ # get testing data
76
+ batch = BATCH_SIZE
77
+ base = BASE_DIM_SIZE
78
+ spatial = SPATIAL_DIM
79
+ in_tensor_a1 = torch.rand((batch,) + (base,) * (spatial - 1))
80
+ in_tensor_b1 = torch.rand((batch,) + (base,) * (spatial - 1))
81
+
82
+ in_tensor_a2 = torch.rand((batch,) + (base,) * (spatial - 1))
83
+ in_tensor_b2 = torch.rand((batch,) + (base,) * (spatial - 1))
84
+
85
+ if dist.get_rank() == 0:
86
+ y_pred = in_tensor_a1.to(device)
87
+ y = in_tensor_b1.to(device)
88
+ metric.update([y_pred, y])
89
+
90
+ if dist.get_rank() == 1:
91
+ y_pred = in_tensor_a2.to(device)
92
+ y = in_tensor_b2.to(device)
93
+ metric.update([y_pred, y])
94
+
95
+ out_tensor = metric.compute()
96
+
97
+ # use numpy functions to get the ground truth reference
98
+ out_tensor_np1 = msemetric_np(y_pred=in_tensor_a1.cpu().numpy(), y=in_tensor_b1.cpu().numpy())
99
+ out_tensor_np2 = msemetric_np(y_pred=in_tensor_a2.cpu().numpy(), y=in_tensor_b2.cpu().numpy())
100
+ out_tensor_np = (out_tensor_np1 + out_tensor_np2) / 2.0
101
+
102
+ np.testing.assert_allclose(out_tensor, out_tensor_np, rtol=1e-04, atol=1e-04)
103
+
104
+
105
+ class DistributedMeanAbsoluteError(DistTestCase):
106
+ @DistCall(nnodes=1, nproc_per_node=2)
107
+ def test_compute(self):
108
+ set_determinism(123)
109
+ self._compute()
110
+
111
+ def _compute(self):
112
+ device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
113
+ metric = MeanAbsoluteError()
114
+
115
+ def _val_func(engine, batch):
116
+ pass
117
+
118
+ engine = Engine(_val_func)
119
+ metric.attach(engine, "MAE")
120
+
121
+ # get testing data
122
+ batch = BATCH_SIZE
123
+ base = BASE_DIM_SIZE
124
+ spatial = SPATIAL_DIM
125
+ in_tensor_a1 = torch.rand((batch,) + (base,) * (spatial - 1))
126
+ in_tensor_b1 = torch.rand((batch,) + (base,) * (spatial - 1))
127
+
128
+ in_tensor_a2 = torch.rand((batch,) + (base,) * (spatial - 1))
129
+ in_tensor_b2 = torch.rand((batch,) + (base,) * (spatial - 1))
130
+
131
+ if dist.get_rank() == 0:
132
+ y_pred = in_tensor_a1.to(device)
133
+ y = in_tensor_b1.to(device)
134
+ metric.update([y_pred, y])
135
+
136
+ if dist.get_rank() == 1:
137
+ y_pred = in_tensor_a2.to(device)
138
+ y = in_tensor_b2.to(device)
139
+ metric.update([y_pred, y])
140
+
141
+ out_tensor = metric.compute()
142
+
143
+ # use numpy functions to get the ground truth reference
144
+ out_tensor_np1 = maemetric_np(y_pred=in_tensor_a1.cpu().numpy(), y=in_tensor_b1.cpu().numpy())
145
+ out_tensor_np2 = maemetric_np(y_pred=in_tensor_a2.cpu().numpy(), y=in_tensor_b2.cpu().numpy())
146
+ out_tensor_np = (out_tensor_np1 + out_tensor_np2) / 2.0
147
+
148
+ np.testing.assert_allclose(out_tensor, out_tensor_np, rtol=1e-04, atol=1e-04)
149
+
150
+
151
+ class DistributedRootMeanSquaredError(DistTestCase):
152
+ @DistCall(nnodes=1, nproc_per_node=2)
153
+ def test_compute(self):
154
+ set_determinism(123)
155
+ self._compute()
156
+
157
+ def _compute(self):
158
+ device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
159
+ metric = RootMeanSquaredError()
160
+
161
+ def _val_func(engine, batch):
162
+ pass
163
+
164
+ engine = Engine(_val_func)
165
+ metric.attach(engine, "RMSE")
166
+
167
+ # get testing data
168
+ batch = BATCH_SIZE
169
+ base = BASE_DIM_SIZE
170
+ spatial = SPATIAL_DIM
171
+ in_tensor_a1 = torch.rand((batch,) + (base,) * (spatial - 1))
172
+ in_tensor_b1 = torch.rand((batch,) + (base,) * (spatial - 1))
173
+
174
+ in_tensor_a2 = torch.rand((batch,) + (base,) * (spatial - 1))
175
+ in_tensor_b2 = torch.rand((batch,) + (base,) * (spatial - 1))
176
+
177
+ if dist.get_rank() == 0:
178
+ y_pred = in_tensor_a1.to(device)
179
+ y = in_tensor_b1.to(device)
180
+ metric.update([y_pred, y])
181
+
182
+ if dist.get_rank() == 1:
183
+ y_pred = in_tensor_a2.to(device)
184
+ y = in_tensor_b2.to(device)
185
+ metric.update([y_pred, y])
186
+
187
+ out_tensor = metric.compute()
188
+
189
+ # use numpy functions to get the ground truth reference
190
+ out_tensor_np1 = rmsemetric_np(y_pred=in_tensor_a1.cpu().numpy(), y=in_tensor_b1.cpu().numpy())
191
+ out_tensor_np2 = rmsemetric_np(y_pred=in_tensor_a2.cpu().numpy(), y=in_tensor_b2.cpu().numpy())
192
+ out_tensor_np = (out_tensor_np1 + out_tensor_np2) / 2.0
193
+
194
+ np.testing.assert_allclose(out_tensor, out_tensor_np, rtol=1e-04, atol=1e-04)
195
+
196
+
197
+ class DistributedPeakSignalToNoiseRatio(DistTestCase):
198
+ @DistCall(nnodes=1, nproc_per_node=2)
199
+ def test_compute(self):
200
+ set_determinism(123)
201
+ self._compute()
202
+
203
+ def _compute(self):
204
+ device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
205
+ max_val = 1.0
206
+ metric = PeakSignalToNoiseRatio(max_val=max_val)
207
+
208
+ def _val_func(engine, batch):
209
+ pass
210
+
211
+ engine = Engine(_val_func)
212
+ metric.attach(engine, "PSNR")
213
+
214
+ # get testing data
215
+ batch = BATCH_SIZE
216
+ base = BASE_DIM_SIZE
217
+ spatial = SPATIAL_DIM
218
+ in_tensor_a1 = torch.rand((batch,) + (base,) * (spatial - 1))
219
+ in_tensor_b1 = torch.rand((batch,) + (base,) * (spatial - 1))
220
+
221
+ in_tensor_a2 = torch.rand((batch,) + (base,) * (spatial - 1))
222
+ in_tensor_b2 = torch.rand((batch,) + (base,) * (spatial - 1))
223
+
224
+ if dist.get_rank() == 0:
225
+ y_pred = in_tensor_a1.to(device)
226
+ y = in_tensor_b1.to(device)
227
+ metric.update([y_pred, y])
228
+
229
+ if dist.get_rank() == 1:
230
+ y_pred = in_tensor_a2.to(device)
231
+ y = in_tensor_b2.to(device)
232
+ metric.update([y_pred, y])
233
+
234
+ out_tensor = metric.compute()
235
+
236
+ # use numpy functions to get the ground truth reference
237
+ out_tensor_np1 = psnrmetric_np(max_val=max_val, y_pred=in_tensor_a1.cpu().numpy(), y=in_tensor_b1.cpu().numpy())
238
+ out_tensor_np2 = psnrmetric_np(max_val=max_val, y_pred=in_tensor_a2.cpu().numpy(), y=in_tensor_b2.cpu().numpy())
239
+ out_tensor_np = (out_tensor_np1 + out_tensor_np2) / 2.0
240
+
241
+ np.testing.assert_allclose(out_tensor, out_tensor_np, rtol=1e-04, atol=1e-04)
242
+
243
+
244
+ if __name__ == "__main__":
245
+ unittest.main()
@@ -0,0 +1,48 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import numpy as np
17
+ import torch
18
+
19
+ from monai.handlers import ROCAUC
20
+ from monai.transforms import Activations, AsDiscrete
21
+
22
+
23
class TestHandlerROCAUC(unittest.TestCase):
    """Check the ROCAUC handler after two consecutive updates."""

    def test_compute(self):
        auc_metric = ROCAUC()
        act = Activations(softmax=True)
        to_onehot = AsDiscrete(to_onehot=2)

        # Each entry is one update: raw two-class scores and their class labels.
        batches = (
            ([[0.1, 0.9], [0.3, 1.4]], [[0], [1]]),
            ([[0.2, 0.1], [0.1, 0.5]], [[0], [1]]),
        )
        for scores, labels in batches:
            # Softmax the scores and one-hot the labels before handing them over.
            y_pred = [act(torch.Tensor(row)) for row in scores]
            y = [to_onehot(torch.Tensor(label)) for label in labels]
            auc_metric.update([y_pred, y])

        auc = auc_metric.compute()
        np.testing.assert_allclose(0.75, auc)
45
+
46
+
47
+ if __name__ == "__main__":
48
+ unittest.main()
@@ -0,0 +1,54 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.distributed as dist
19
+
20
+ from monai.handlers import ROCAUC
21
+ from monai.transforms import Activations, AsDiscrete
22
+ from tests.test_utils import DistCall, DistTestCase
23
+
24
+
25
class DistributedROCAUC(DistTestCase):
    """Check the ROCAUC handler when each of two ranks supplies its own samples."""

    @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
    def test_compute(self):
        auc_metric = ROCAUC()
        act = Activations(softmax=True)
        to_onehot = AsDiscrete(to_onehot=2)

        device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"

        # Raw (scores, labels) per rank; rank 1 deliberately holds one more
        # sample than rank 0, so the two ranks contribute uneven amounts.
        samples_by_rank = {
            0: ([[0.1, 0.9], [0.3, 1.4]], [[0], [1]]),
            1: ([[0.2, 0.1], [0.1, 0.5], [0.3, 0.4]], [[0], [1], [1]]),
        }
        scores, labels = samples_by_rank[dist.get_rank()]

        # Tensors are created directly on this rank's device, then transformed.
        y_pred = [act(torch.tensor(row, device=device)) for row in scores]
        y = [to_onehot(torch.tensor(label, device=device)) for label in labels]
        auc_metric.update([y_pred, y])

        result = auc_metric.compute()
        np.testing.assert_allclose(0.66667, result, rtol=1e-4)
51
+
52
+
53
+ if __name__ == "__main__":
54
+ unittest.main()