monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2507__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (776):
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/transforms.py +1 -4
  4. monai/data/utils.py +6 -13
  5. monai/inferers/utils.py +1 -2
  6. monai/losses/dice.py +2 -14
  7. monai/losses/ds_loss.py +1 -3
  8. monai/networks/layers/simplelayers.py +2 -14
  9. monai/networks/utils.py +4 -16
  10. monai/transforms/compose.py +28 -11
  11. monai/transforms/croppad/array.py +1 -6
  12. monai/transforms/io/array.py +0 -1
  13. monai/transforms/transform.py +15 -6
  14. monai/transforms/utils.py +1 -2
  15. monai/utils/tf32.py +0 -10
  16. monai/visualize/class_activation_maps.py +5 -8
  17. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/METADATA +2 -2
  18. monai_weekly-1.5.dev2507.dist-info/RECORD +1181 -0
  19. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/top_level.txt +1 -0
  20. tests/apps/__init__.py +10 -0
  21. tests/apps/deepedit/__init__.py +10 -0
  22. tests/apps/deepedit/test_deepedit_transforms.py +314 -0
  23. tests/apps/deepgrow/__init__.py +10 -0
  24. tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
  25. tests/apps/deepgrow/transforms/__init__.py +10 -0
  26. tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
  27. tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
  28. tests/apps/detection/__init__.py +10 -0
  29. tests/apps/detection/metrics/__init__.py +10 -0
  30. tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
  31. tests/apps/detection/networks/__init__.py +10 -0
  32. tests/apps/detection/networks/test_retinanet.py +210 -0
  33. tests/apps/detection/networks/test_retinanet_detector.py +203 -0
  34. tests/apps/detection/test_box_transform.py +370 -0
  35. tests/apps/detection/utils/__init__.py +10 -0
  36. tests/apps/detection/utils/test_anchor_box.py +88 -0
  37. tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
  38. tests/apps/detection/utils/test_box_coder.py +43 -0
  39. tests/apps/detection/utils/test_detector_boxselector.py +67 -0
  40. tests/apps/detection/utils/test_detector_utils.py +96 -0
  41. tests/apps/detection/utils/test_hardnegsampler.py +54 -0
  42. tests/apps/nuclick/__init__.py +10 -0
  43. tests/apps/nuclick/test_nuclick_transforms.py +259 -0
  44. tests/apps/pathology/__init__.py +10 -0
  45. tests/apps/pathology/handlers/__init__.py +10 -0
  46. tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
  47. tests/apps/pathology/test_lesion_froc.py +333 -0
  48. tests/apps/pathology/test_pathology_prob_nms.py +55 -0
  49. tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
  50. tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
  51. tests/apps/pathology/transforms/__init__.py +10 -0
  52. tests/apps/pathology/transforms/post/__init__.py +10 -0
  53. tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
  54. tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
  55. tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
  56. tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
  57. tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
  58. tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
  59. tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
  60. tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
  61. tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
  62. tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
  63. tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
  64. tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
  65. tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
  66. tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
  67. tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
  68. tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
  69. tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
  70. tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
  71. tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
  72. tests/apps/pathology/transforms/post/test_watershed.py +60 -0
  73. tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
  74. tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
  75. tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
  76. tests/apps/reconstruction/__init__.py +10 -0
  77. tests/apps/reconstruction/nets/__init__.py +10 -0
  78. tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
  79. tests/apps/reconstruction/test_complex_utils.py +77 -0
  80. tests/apps/reconstruction/test_fastmri_reader.py +82 -0
  81. tests/apps/reconstruction/test_mri_utils.py +37 -0
  82. tests/apps/reconstruction/transforms/__init__.py +10 -0
  83. tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
  84. tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
  85. tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
  86. tests/apps/test_auto3dseg_bundlegen.py +156 -0
  87. tests/apps/test_check_hash.py +53 -0
  88. tests/apps/test_cross_validation.py +74 -0
  89. tests/apps/test_decathlondataset.py +93 -0
  90. tests/apps/test_download_and_extract.py +70 -0
  91. tests/apps/test_download_url_yandex.py +45 -0
  92. tests/apps/test_mednistdataset.py +72 -0
  93. tests/apps/test_mmar_download.py +154 -0
  94. tests/apps/test_tciadataset.py +123 -0
  95. tests/apps/vista3d/__init__.py +10 -0
  96. tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
  97. tests/apps/vista3d/test_vista3d_sampler.py +100 -0
  98. tests/apps/vista3d/test_vista3d_transforms.py +94 -0
  99. tests/bundle/__init__.py +10 -0
  100. tests/bundle/test_bundle_ckpt_export.py +107 -0
  101. tests/bundle/test_bundle_download.py +435 -0
  102. tests/bundle/test_bundle_get_data.py +94 -0
  103. tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
  104. tests/bundle/test_bundle_trt_export.py +147 -0
  105. tests/bundle/test_bundle_utils.py +149 -0
  106. tests/bundle/test_bundle_verify_metadata.py +66 -0
  107. tests/bundle/test_bundle_verify_net.py +76 -0
  108. tests/bundle/test_bundle_workflow.py +272 -0
  109. tests/bundle/test_component_locator.py +38 -0
  110. tests/bundle/test_config_item.py +138 -0
  111. tests/bundle/test_config_parser.py +392 -0
  112. tests/bundle/test_reference_resolver.py +114 -0
  113. tests/config/__init__.py +10 -0
  114. tests/config/test_cv2_dist.py +53 -0
  115. tests/engines/__init__.py +10 -0
  116. tests/engines/test_ensemble_evaluator.py +94 -0
  117. tests/engines/test_prepare_batch_default.py +76 -0
  118. tests/engines/test_prepare_batch_default_dist.py +76 -0
  119. tests/engines/test_prepare_batch_diffusion.py +104 -0
  120. tests/engines/test_prepare_batch_extra_input.py +80 -0
  121. tests/fl/__init__.py +10 -0
  122. tests/fl/monai_algo/__init__.py +10 -0
  123. tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
  124. tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
  125. tests/fl/test_fl_monai_algo_stats.py +81 -0
  126. tests/fl/utils/__init__.py +10 -0
  127. tests/fl/utils/test_fl_exchange_object.py +63 -0
  128. tests/handlers/__init__.py +10 -0
  129. tests/handlers/test_handler_checkpoint_loader.py +182 -0
  130. tests/handlers/test_handler_checkpoint_saver.py +233 -0
  131. tests/handlers/test_handler_classification_saver.py +64 -0
  132. tests/handlers/test_handler_classification_saver_dist.py +77 -0
  133. tests/handlers/test_handler_clearml_image.py +65 -0
  134. tests/handlers/test_handler_clearml_stats.py +65 -0
  135. tests/handlers/test_handler_confusion_matrix.py +104 -0
  136. tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
  137. tests/handlers/test_handler_decollate_batch.py +66 -0
  138. tests/handlers/test_handler_early_stop.py +68 -0
  139. tests/handlers/test_handler_garbage_collector.py +73 -0
  140. tests/handlers/test_handler_hausdorff_distance.py +111 -0
  141. tests/handlers/test_handler_ignite_metric.py +191 -0
  142. tests/handlers/test_handler_lr_scheduler.py +94 -0
  143. tests/handlers/test_handler_mean_dice.py +98 -0
  144. tests/handlers/test_handler_mean_iou.py +76 -0
  145. tests/handlers/test_handler_metrics_reloaded.py +149 -0
  146. tests/handlers/test_handler_metrics_saver.py +89 -0
  147. tests/handlers/test_handler_metrics_saver_dist.py +120 -0
  148. tests/handlers/test_handler_mlflow.py +296 -0
  149. tests/handlers/test_handler_nvtx.py +93 -0
  150. tests/handlers/test_handler_panoptic_quality.py +89 -0
  151. tests/handlers/test_handler_parameter_scheduler.py +136 -0
  152. tests/handlers/test_handler_post_processing.py +74 -0
  153. tests/handlers/test_handler_prob_map_producer.py +111 -0
  154. tests/handlers/test_handler_regression_metrics.py +160 -0
  155. tests/handlers/test_handler_regression_metrics_dist.py +245 -0
  156. tests/handlers/test_handler_rocauc.py +48 -0
  157. tests/handlers/test_handler_rocauc_dist.py +54 -0
  158. tests/handlers/test_handler_stats.py +281 -0
  159. tests/handlers/test_handler_surface_distance.py +113 -0
  160. tests/handlers/test_handler_tb_image.py +61 -0
  161. tests/handlers/test_handler_tb_stats.py +166 -0
  162. tests/handlers/test_handler_validation.py +59 -0
  163. tests/handlers/test_trt_compile.py +145 -0
  164. tests/handlers/test_write_metrics_reports.py +68 -0
  165. tests/inferers/__init__.py +10 -0
  166. tests/inferers/test_avg_merger.py +179 -0
  167. tests/inferers/test_controlnet_inferers.py +1310 -0
  168. tests/inferers/test_diffusion_inferer.py +236 -0
  169. tests/inferers/test_latent_diffusion_inferer.py +824 -0
  170. tests/inferers/test_patch_inferer.py +309 -0
  171. tests/inferers/test_saliency_inferer.py +55 -0
  172. tests/inferers/test_slice_inferer.py +57 -0
  173. tests/inferers/test_sliding_window_inference.py +377 -0
  174. tests/inferers/test_sliding_window_splitter.py +284 -0
  175. tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
  176. tests/inferers/test_zarr_avg_merger.py +326 -0
  177. tests/integration/__init__.py +10 -0
  178. tests/integration/test_auto3dseg_ensemble.py +211 -0
  179. tests/integration/test_auto3dseg_hpo.py +189 -0
  180. tests/integration/test_deepedit_interaction.py +122 -0
  181. tests/integration/test_downsample_block.py +50 -0
  182. tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
  183. tests/integration/test_integration_autorunner.py +201 -0
  184. tests/integration/test_integration_bundle_run.py +240 -0
  185. tests/integration/test_integration_classification_2d.py +282 -0
  186. tests/integration/test_integration_determinism.py +95 -0
  187. tests/integration/test_integration_fast_train.py +231 -0
  188. tests/integration/test_integration_gpu_customization.py +159 -0
  189. tests/integration/test_integration_lazy_samples.py +219 -0
  190. tests/integration/test_integration_nnunetv2_runner.py +96 -0
  191. tests/integration/test_integration_segmentation_3d.py +304 -0
  192. tests/integration/test_integration_sliding_window.py +100 -0
  193. tests/integration/test_integration_stn.py +133 -0
  194. tests/integration/test_integration_unet_2d.py +67 -0
  195. tests/integration/test_integration_workers.py +61 -0
  196. tests/integration/test_integration_workflows.py +365 -0
  197. tests/integration/test_integration_workflows_adversarial.py +173 -0
  198. tests/integration/test_integration_workflows_gan.py +158 -0
  199. tests/integration/test_loader_semaphore.py +48 -0
  200. tests/integration/test_mapping_filed.py +122 -0
  201. tests/integration/test_meta_affine.py +183 -0
  202. tests/integration/test_metatensor_integration.py +114 -0
  203. tests/integration/test_module_list.py +76 -0
  204. tests/integration/test_one_of.py +283 -0
  205. tests/integration/test_pad_collation.py +124 -0
  206. tests/integration/test_reg_loss_integration.py +107 -0
  207. tests/integration/test_retinanet_predict_utils.py +154 -0
  208. tests/integration/test_seg_loss_integration.py +159 -0
  209. tests/integration/test_spatial_combine_transforms.py +185 -0
  210. tests/integration/test_testtimeaugmentation.py +186 -0
  211. tests/integration/test_vis_gradbased.py +69 -0
  212. tests/integration/test_vista3d_utils.py +159 -0
  213. tests/losses/__init__.py +10 -0
  214. tests/losses/deform/__init__.py +10 -0
  215. tests/losses/deform/test_bending_energy.py +88 -0
  216. tests/losses/deform/test_diffusion_loss.py +117 -0
  217. tests/losses/image_dissimilarity/__init__.py +10 -0
  218. tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
  219. tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
  220. tests/losses/test_adversarial_loss.py +94 -0
  221. tests/losses/test_barlow_twins_loss.py +109 -0
  222. tests/losses/test_cldice_loss.py +51 -0
  223. tests/losses/test_contrastive_loss.py +86 -0
  224. tests/losses/test_dice_ce_loss.py +123 -0
  225. tests/losses/test_dice_focal_loss.py +124 -0
  226. tests/losses/test_dice_loss.py +227 -0
  227. tests/losses/test_ds_loss.py +189 -0
  228. tests/losses/test_focal_loss.py +379 -0
  229. tests/losses/test_generalized_dice_focal_loss.py +85 -0
  230. tests/losses/test_generalized_dice_loss.py +221 -0
  231. tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
  232. tests/losses/test_giou_loss.py +62 -0
  233. tests/losses/test_hausdorff_loss.py +264 -0
  234. tests/losses/test_masked_dice_loss.py +152 -0
  235. tests/losses/test_masked_loss.py +87 -0
  236. tests/losses/test_multi_scale.py +86 -0
  237. tests/losses/test_nacl_loss.py +167 -0
  238. tests/losses/test_perceptual_loss.py +122 -0
  239. tests/losses/test_spectral_loss.py +86 -0
  240. tests/losses/test_ssim_loss.py +59 -0
  241. tests/losses/test_sure_loss.py +72 -0
  242. tests/losses/test_tversky_loss.py +198 -0
  243. tests/losses/test_unified_focal_loss.py +66 -0
  244. tests/metrics/__init__.py +10 -0
  245. tests/metrics/test_compute_confusion_matrix.py +294 -0
  246. tests/metrics/test_compute_f_beta.py +80 -0
  247. tests/metrics/test_compute_fid_metric.py +40 -0
  248. tests/metrics/test_compute_froc.py +143 -0
  249. tests/metrics/test_compute_generalized_dice.py +240 -0
  250. tests/metrics/test_compute_meandice.py +306 -0
  251. tests/metrics/test_compute_meaniou.py +223 -0
  252. tests/metrics/test_compute_mmd_metric.py +56 -0
  253. tests/metrics/test_compute_multiscalessim_metric.py +83 -0
  254. tests/metrics/test_compute_panoptic_quality.py +113 -0
  255. tests/metrics/test_compute_regression_metrics.py +196 -0
  256. tests/metrics/test_compute_roc_auc.py +155 -0
  257. tests/metrics/test_compute_variance.py +147 -0
  258. tests/metrics/test_cumulative.py +63 -0
  259. tests/metrics/test_cumulative_average.py +74 -0
  260. tests/metrics/test_cumulative_average_dist.py +48 -0
  261. tests/metrics/test_hausdorff_distance.py +209 -0
  262. tests/metrics/test_label_quality_score.py +134 -0
  263. tests/metrics/test_loss_metric.py +57 -0
  264. tests/metrics/test_metrics_reloaded.py +96 -0
  265. tests/metrics/test_ssim_metric.py +78 -0
  266. tests/metrics/test_surface_dice.py +416 -0
  267. tests/metrics/test_surface_distance.py +186 -0
  268. tests/networks/__init__.py +10 -0
  269. tests/networks/blocks/__init__.py +10 -0
  270. tests/networks/blocks/dints_block/__init__.py +10 -0
  271. tests/networks/blocks/dints_block/test_acn_block.py +41 -0
  272. tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
  273. tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
  274. tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
  275. tests/networks/blocks/test_adn.py +86 -0
  276. tests/networks/blocks/test_convolutions.py +156 -0
  277. tests/networks/blocks/test_crf_cpu.py +513 -0
  278. tests/networks/blocks/test_crf_cuda.py +528 -0
  279. tests/networks/blocks/test_crossattention.py +185 -0
  280. tests/networks/blocks/test_denseblock.py +105 -0
  281. tests/networks/blocks/test_dynunet_block.py +116 -0
  282. tests/networks/blocks/test_fpn_block.py +88 -0
  283. tests/networks/blocks/test_localnet_block.py +121 -0
  284. tests/networks/blocks/test_mlp.py +78 -0
  285. tests/networks/blocks/test_patchembedding.py +212 -0
  286. tests/networks/blocks/test_regunet_block.py +103 -0
  287. tests/networks/blocks/test_se_block.py +85 -0
  288. tests/networks/blocks/test_se_blocks.py +78 -0
  289. tests/networks/blocks/test_segresnet_block.py +57 -0
  290. tests/networks/blocks/test_selfattention.py +232 -0
  291. tests/networks/blocks/test_simple_aspp.py +87 -0
  292. tests/networks/blocks/test_spatialattention.py +55 -0
  293. tests/networks/blocks/test_subpixel_upsample.py +87 -0
  294. tests/networks/blocks/test_text_encoding.py +49 -0
  295. tests/networks/blocks/test_transformerblock.py +90 -0
  296. tests/networks/blocks/test_unetr_block.py +158 -0
  297. tests/networks/blocks/test_upsample_block.py +134 -0
  298. tests/networks/blocks/warp/__init__.py +10 -0
  299. tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
  300. tests/networks/blocks/warp/test_warp.py +250 -0
  301. tests/networks/layers/__init__.py +10 -0
  302. tests/networks/layers/filtering/__init__.py +10 -0
  303. tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
  304. tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
  305. tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
  306. tests/networks/layers/filtering/test_phl_cpu.py +259 -0
  307. tests/networks/layers/filtering/test_phl_cuda.py +167 -0
  308. tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
  309. tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
  310. tests/networks/layers/test_affine_transform.py +385 -0
  311. tests/networks/layers/test_apply_filter.py +89 -0
  312. tests/networks/layers/test_channel_pad.py +51 -0
  313. tests/networks/layers/test_conjugate_gradient.py +56 -0
  314. tests/networks/layers/test_drop_path.py +46 -0
  315. tests/networks/layers/test_gaussian.py +317 -0
  316. tests/networks/layers/test_gaussian_filter.py +206 -0
  317. tests/networks/layers/test_get_layers.py +65 -0
  318. tests/networks/layers/test_gmm.py +314 -0
  319. tests/networks/layers/test_grid_pull.py +93 -0
  320. tests/networks/layers/test_hilbert_transform.py +131 -0
  321. tests/networks/layers/test_lltm.py +62 -0
  322. tests/networks/layers/test_median_filter.py +52 -0
  323. tests/networks/layers/test_polyval.py +55 -0
  324. tests/networks/layers/test_preset_filters.py +136 -0
  325. tests/networks/layers/test_savitzky_golay_filter.py +141 -0
  326. tests/networks/layers/test_separable_filter.py +87 -0
  327. tests/networks/layers/test_skip_connection.py +48 -0
  328. tests/networks/layers/test_vector_quantizer.py +89 -0
  329. tests/networks/layers/test_weight_init.py +50 -0
  330. tests/networks/nets/__init__.py +10 -0
  331. tests/networks/nets/dints/__init__.py +10 -0
  332. tests/networks/nets/dints/test_dints_cell.py +110 -0
  333. tests/networks/nets/dints/test_dints_mixop.py +84 -0
  334. tests/networks/nets/regunet/__init__.py +10 -0
  335. tests/networks/nets/regunet/test_localnet.py +86 -0
  336. tests/networks/nets/regunet/test_regunet.py +88 -0
  337. tests/networks/nets/test_ahnet.py +224 -0
  338. tests/networks/nets/test_attentionunet.py +88 -0
  339. tests/networks/nets/test_autoencoder.py +95 -0
  340. tests/networks/nets/test_autoencoderkl.py +337 -0
  341. tests/networks/nets/test_basic_unet.py +102 -0
  342. tests/networks/nets/test_basic_unetplusplus.py +109 -0
  343. tests/networks/nets/test_bundle_init_bundle.py +55 -0
  344. tests/networks/nets/test_cell_sam_wrapper.py +58 -0
  345. tests/networks/nets/test_controlnet.py +215 -0
  346. tests/networks/nets/test_daf3d.py +62 -0
  347. tests/networks/nets/test_densenet.py +121 -0
  348. tests/networks/nets/test_diffusion_model_unet.py +585 -0
  349. tests/networks/nets/test_dints_network.py +168 -0
  350. tests/networks/nets/test_discriminator.py +59 -0
  351. tests/networks/nets/test_dynunet.py +181 -0
  352. tests/networks/nets/test_efficientnet.py +400 -0
  353. tests/networks/nets/test_flexible_unet.py +341 -0
  354. tests/networks/nets/test_fullyconnectednet.py +69 -0
  355. tests/networks/nets/test_generator.py +59 -0
  356. tests/networks/nets/test_globalnet.py +103 -0
  357. tests/networks/nets/test_highresnet.py +67 -0
  358. tests/networks/nets/test_hovernet.py +218 -0
  359. tests/networks/nets/test_mednext.py +122 -0
  360. tests/networks/nets/test_milmodel.py +92 -0
  361. tests/networks/nets/test_net_adapter.py +68 -0
  362. tests/networks/nets/test_network_consistency.py +86 -0
  363. tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
  364. tests/networks/nets/test_quicknat.py +57 -0
  365. tests/networks/nets/test_resnet.py +340 -0
  366. tests/networks/nets/test_segresnet.py +120 -0
  367. tests/networks/nets/test_segresnet_ds.py +156 -0
  368. tests/networks/nets/test_senet.py +151 -0
  369. tests/networks/nets/test_spade_autoencoderkl.py +295 -0
  370. tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
  371. tests/networks/nets/test_spade_vaegan.py +140 -0
  372. tests/networks/nets/test_swin_unetr.py +139 -0
  373. tests/networks/nets/test_torchvision_fc_model.py +201 -0
  374. tests/networks/nets/test_transchex.py +84 -0
  375. tests/networks/nets/test_transformer.py +108 -0
  376. tests/networks/nets/test_unet.py +208 -0
  377. tests/networks/nets/test_unetr.py +137 -0
  378. tests/networks/nets/test_varautoencoder.py +127 -0
  379. tests/networks/nets/test_vista3d.py +84 -0
  380. tests/networks/nets/test_vit.py +139 -0
  381. tests/networks/nets/test_vitautoenc.py +112 -0
  382. tests/networks/nets/test_vnet.py +81 -0
  383. tests/networks/nets/test_voxelmorph.py +280 -0
  384. tests/networks/nets/test_vqvae.py +274 -0
  385. tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
  386. tests/networks/schedulers/__init__.py +10 -0
  387. tests/networks/schedulers/test_scheduler_ddim.py +83 -0
  388. tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
  389. tests/networks/schedulers/test_scheduler_pndm.py +108 -0
  390. tests/networks/test_bundle_onnx_export.py +71 -0
  391. tests/networks/test_convert_to_onnx.py +106 -0
  392. tests/networks/test_convert_to_torchscript.py +46 -0
  393. tests/networks/test_convert_to_trt.py +79 -0
  394. tests/networks/test_save_state.py +73 -0
  395. tests/networks/test_to_onehot.py +63 -0
  396. tests/networks/test_varnet.py +63 -0
  397. tests/networks/utils/__init__.py +10 -0
  398. tests/networks/utils/test_copy_model_state.py +187 -0
  399. tests/networks/utils/test_eval_mode.py +34 -0
  400. tests/networks/utils/test_freeze_layers.py +61 -0
  401. tests/networks/utils/test_replace_module.py +98 -0
  402. tests/networks/utils/test_train_mode.py +34 -0
  403. tests/optimizers/__init__.py +10 -0
  404. tests/optimizers/test_generate_param_groups.py +105 -0
  405. tests/optimizers/test_lr_finder.py +108 -0
  406. tests/optimizers/test_lr_scheduler.py +71 -0
  407. tests/optimizers/test_optim_novograd.py +100 -0
  408. tests/profile_subclass/__init__.py +10 -0
  409. tests/profile_subclass/cprofile_profiling.py +29 -0
  410. tests/profile_subclass/min_classes.py +30 -0
  411. tests/profile_subclass/profiling.py +73 -0
  412. tests/profile_subclass/pyspy_profiling.py +41 -0
  413. tests/transforms/__init__.py +10 -0
  414. tests/transforms/compose/__init__.py +10 -0
  415. tests/transforms/compose/test_compose.py +758 -0
  416. tests/transforms/compose/test_some_of.py +258 -0
  417. tests/transforms/croppad/__init__.py +10 -0
  418. tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
  419. tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
  420. tests/transforms/functional/__init__.py +10 -0
  421. tests/transforms/functional/test_apply.py +75 -0
  422. tests/transforms/functional/test_resample.py +50 -0
  423. tests/transforms/intensity/__init__.py +10 -0
  424. tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
  425. tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
  426. tests/transforms/intensity/test_foreground_mask.py +98 -0
  427. tests/transforms/intensity/test_foreground_maskd.py +106 -0
  428. tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
  429. tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
  430. tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
  431. tests/transforms/inverse/__init__.py +10 -0
  432. tests/transforms/inverse/test_inverse_array.py +76 -0
  433. tests/transforms/inverse/test_traceable_transform.py +59 -0
  434. tests/transforms/post/__init__.py +10 -0
  435. tests/transforms/post/test_label_filterd.py +78 -0
  436. tests/transforms/post/test_probnms.py +72 -0
  437. tests/transforms/post/test_probnmsd.py +79 -0
  438. tests/transforms/post/test_remove_small_objects.py +102 -0
  439. tests/transforms/spatial/__init__.py +10 -0
  440. tests/transforms/spatial/test_convert_box_points.py +119 -0
  441. tests/transforms/spatial/test_grid_patch.py +134 -0
  442. tests/transforms/spatial/test_grid_patchd.py +102 -0
  443. tests/transforms/spatial/test_rand_grid_patch.py +150 -0
  444. tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
  445. tests/transforms/spatial/test_spatial_resampled.py +124 -0
  446. tests/transforms/test_activations.py +120 -0
  447. tests/transforms/test_activationsd.py +64 -0
  448. tests/transforms/test_adaptors.py +160 -0
  449. tests/transforms/test_add_coordinate_channels.py +53 -0
  450. tests/transforms/test_add_coordinate_channelsd.py +67 -0
  451. tests/transforms/test_add_extreme_points_channel.py +80 -0
  452. tests/transforms/test_add_extreme_points_channeld.py +77 -0
  453. tests/transforms/test_adjust_contrast.py +70 -0
  454. tests/transforms/test_adjust_contrastd.py +64 -0
  455. tests/transforms/test_affine.py +245 -0
  456. tests/transforms/test_affine_grid.py +152 -0
  457. tests/transforms/test_affined.py +190 -0
  458. tests/transforms/test_as_channel_last.py +38 -0
  459. tests/transforms/test_as_channel_lastd.py +44 -0
  460. tests/transforms/test_as_discrete.py +81 -0
  461. tests/transforms/test_as_discreted.py +82 -0
  462. tests/transforms/test_border_pad.py +49 -0
  463. tests/transforms/test_border_padd.py +45 -0
  464. tests/transforms/test_bounding_rect.py +54 -0
  465. tests/transforms/test_bounding_rectd.py +53 -0
  466. tests/transforms/test_cast_to_type.py +63 -0
  467. tests/transforms/test_cast_to_typed.py +74 -0
  468. tests/transforms/test_center_scale_crop.py +55 -0
  469. tests/transforms/test_center_scale_cropd.py +56 -0
  470. tests/transforms/test_center_spatial_crop.py +56 -0
  471. tests/transforms/test_center_spatial_cropd.py +63 -0
  472. tests/transforms/test_classes_to_indices.py +93 -0
  473. tests/transforms/test_classes_to_indicesd.py +110 -0
  474. tests/transforms/test_clip_intensity_percentiles.py +196 -0
  475. tests/transforms/test_clip_intensity_percentilesd.py +193 -0
  476. tests/transforms/test_compose_get_number_conversions.py +127 -0
  477. tests/transforms/test_concat_itemsd.py +82 -0
  478. tests/transforms/test_convert_to_multi_channel.py +59 -0
  479. tests/transforms/test_convert_to_multi_channeld.py +37 -0
  480. tests/transforms/test_copy_itemsd.py +86 -0
  481. tests/transforms/test_create_grid_and_affine.py +274 -0
  482. tests/transforms/test_crop_foreground.py +164 -0
  483. tests/transforms/test_crop_foregroundd.py +205 -0
  484. tests/transforms/test_cucim_dict_transform.py +142 -0
  485. tests/transforms/test_cucim_transform.py +141 -0
  486. tests/transforms/test_data_stats.py +221 -0
  487. tests/transforms/test_data_statsd.py +249 -0
  488. tests/transforms/test_delete_itemsd.py +58 -0
  489. tests/transforms/test_detect_envelope.py +159 -0
  490. tests/transforms/test_distance_transform_edt.py +202 -0
  491. tests/transforms/test_divisible_pad.py +49 -0
  492. tests/transforms/test_divisible_padd.py +42 -0
  493. tests/transforms/test_ensure_channel_first.py +113 -0
  494. tests/transforms/test_ensure_channel_firstd.py +85 -0
  495. tests/transforms/test_ensure_type.py +94 -0
  496. tests/transforms/test_ensure_typed.py +110 -0
  497. tests/transforms/test_fg_bg_to_indices.py +83 -0
  498. tests/transforms/test_fg_bg_to_indicesd.py +78 -0
  499. tests/transforms/test_fill_holes.py +207 -0
  500. tests/transforms/test_fill_holesd.py +209 -0
  501. tests/transforms/test_flatten_sub_keysd.py +64 -0
  502. tests/transforms/test_flip.py +83 -0
  503. tests/transforms/test_flipd.py +90 -0
  504. tests/transforms/test_fourier.py +70 -0
  505. tests/transforms/test_gaussian_sharpen.py +92 -0
  506. tests/transforms/test_gaussian_sharpend.py +92 -0
  507. tests/transforms/test_gaussian_smooth.py +96 -0
  508. tests/transforms/test_gaussian_smoothd.py +96 -0
  509. tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
  510. tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
  511. tests/transforms/test_generate_spatial_bounding_box.py +114 -0
  512. tests/transforms/test_get_extreme_points.py +57 -0
  513. tests/transforms/test_gibbs_noise.py +75 -0
  514. tests/transforms/test_gibbs_noised.py +88 -0
  515. tests/transforms/test_grid_distortion.py +113 -0
  516. tests/transforms/test_grid_distortiond.py +87 -0
  517. tests/transforms/test_grid_split.py +88 -0
  518. tests/transforms/test_grid_splitd.py +96 -0
  519. tests/transforms/test_histogram_normalize.py +59 -0
  520. tests/transforms/test_histogram_normalized.py +59 -0
  521. tests/transforms/test_image_filter.py +259 -0
  522. tests/transforms/test_intensity_stats.py +73 -0
  523. tests/transforms/test_intensity_statsd.py +90 -0
  524. tests/transforms/test_inverse.py +521 -0
  525. tests/transforms/test_inverse_collation.py +147 -0
  526. tests/transforms/test_invert.py +105 -0
  527. tests/transforms/test_invertd.py +142 -0
  528. tests/transforms/test_k_space_spike_noise.py +81 -0
  529. tests/transforms/test_k_space_spike_noised.py +98 -0
  530. tests/transforms/test_keep_largest_connected_component.py +419 -0
  531. tests/transforms/test_keep_largest_connected_componentd.py +348 -0
  532. tests/transforms/test_label_filter.py +78 -0
  533. tests/transforms/test_label_to_contour.py +179 -0
  534. tests/transforms/test_label_to_contourd.py +182 -0
  535. tests/transforms/test_label_to_mask.py +69 -0
  536. tests/transforms/test_label_to_maskd.py +70 -0
  537. tests/transforms/test_load_image.py +502 -0
  538. tests/transforms/test_load_imaged.py +198 -0
  539. tests/transforms/test_load_spacing_orientation.py +149 -0
  540. tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
  541. tests/transforms/test_map_binary_to_indices.py +75 -0
  542. tests/transforms/test_map_classes_to_indices.py +135 -0
  543. tests/transforms/test_map_label_value.py +89 -0
  544. tests/transforms/test_map_label_valued.py +85 -0
  545. tests/transforms/test_map_transform.py +45 -0
  546. tests/transforms/test_mask_intensity.py +74 -0
  547. tests/transforms/test_mask_intensityd.py +68 -0
  548. tests/transforms/test_mean_ensemble.py +77 -0
  549. tests/transforms/test_mean_ensembled.py +91 -0
  550. tests/transforms/test_median_smooth.py +41 -0
  551. tests/transforms/test_median_smoothd.py +65 -0
  552. tests/transforms/test_morphological_ops.py +101 -0
  553. tests/transforms/test_nifti_endianness.py +107 -0
  554. tests/transforms/test_normalize_intensity.py +143 -0
  555. tests/transforms/test_normalize_intensityd.py +81 -0
  556. tests/transforms/test_nvtx_decorator.py +289 -0
  557. tests/transforms/test_nvtx_transform.py +143 -0
  558. tests/transforms/test_orientation.py +247 -0
  559. tests/transforms/test_orientationd.py +112 -0
  560. tests/transforms/test_rand_adjust_contrast.py +45 -0
  561. tests/transforms/test_rand_adjust_contrastd.py +44 -0
  562. tests/transforms/test_rand_affine.py +201 -0
  563. tests/transforms/test_rand_affine_grid.py +212 -0
  564. tests/transforms/test_rand_affined.py +281 -0
  565. tests/transforms/test_rand_axis_flip.py +50 -0
  566. tests/transforms/test_rand_axis_flipd.py +50 -0
  567. tests/transforms/test_rand_bias_field.py +69 -0
  568. tests/transforms/test_rand_bias_fieldd.py +65 -0
  569. tests/transforms/test_rand_coarse_dropout.py +110 -0
  570. tests/transforms/test_rand_coarse_dropoutd.py +107 -0
  571. tests/transforms/test_rand_coarse_shuffle.py +65 -0
  572. tests/transforms/test_rand_coarse_shuffled.py +59 -0
  573. tests/transforms/test_rand_crop_by_label_classes.py +170 -0
  574. tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
  575. tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
  576. tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
  577. tests/transforms/test_rand_cucim_dict_transform.py +162 -0
  578. tests/transforms/test_rand_cucim_transform.py +162 -0
  579. tests/transforms/test_rand_deform_grid.py +138 -0
  580. tests/transforms/test_rand_elastic_2d.py +127 -0
  581. tests/transforms/test_rand_elastic_3d.py +104 -0
  582. tests/transforms/test_rand_elasticd_2d.py +177 -0
  583. tests/transforms/test_rand_elasticd_3d.py +156 -0
  584. tests/transforms/test_rand_flip.py +60 -0
  585. tests/transforms/test_rand_flipd.py +55 -0
  586. tests/transforms/test_rand_gaussian_noise.py +48 -0
  587. tests/transforms/test_rand_gaussian_noised.py +54 -0
  588. tests/transforms/test_rand_gaussian_sharpen.py +140 -0
  589. tests/transforms/test_rand_gaussian_sharpend.py +143 -0
  590. tests/transforms/test_rand_gaussian_smooth.py +98 -0
  591. tests/transforms/test_rand_gaussian_smoothd.py +98 -0
  592. tests/transforms/test_rand_gibbs_noise.py +103 -0
  593. tests/transforms/test_rand_gibbs_noised.py +117 -0
  594. tests/transforms/test_rand_grid_distortion.py +99 -0
  595. tests/transforms/test_rand_grid_distortiond.py +90 -0
  596. tests/transforms/test_rand_histogram_shift.py +92 -0
  597. tests/transforms/test_rand_k_space_spike_noise.py +92 -0
  598. tests/transforms/test_rand_k_space_spike_noised.py +76 -0
  599. tests/transforms/test_rand_rician_noise.py +52 -0
  600. tests/transforms/test_rand_rician_noised.py +52 -0
  601. tests/transforms/test_rand_rotate.py +166 -0
  602. tests/transforms/test_rand_rotate90.py +100 -0
  603. tests/transforms/test_rand_rotate90d.py +112 -0
  604. tests/transforms/test_rand_rotated.py +187 -0
  605. tests/transforms/test_rand_scale_crop.py +78 -0
  606. tests/transforms/test_rand_scale_cropd.py +98 -0
  607. tests/transforms/test_rand_scale_intensity.py +54 -0
  608. tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
  609. tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
  610. tests/transforms/test_rand_scale_intensityd.py +53 -0
  611. tests/transforms/test_rand_shift_intensity.py +52 -0
  612. tests/transforms/test_rand_shift_intensityd.py +67 -0
  613. tests/transforms/test_rand_simulate_low_resolution.py +83 -0
  614. tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
  615. tests/transforms/test_rand_spatial_crop.py +107 -0
  616. tests/transforms/test_rand_spatial_crop_samples.py +128 -0
  617. tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
  618. tests/transforms/test_rand_spatial_cropd.py +112 -0
  619. tests/transforms/test_rand_std_shift_intensity.py +43 -0
  620. tests/transforms/test_rand_std_shift_intensityd.py +38 -0
  621. tests/transforms/test_rand_zoom.py +105 -0
  622. tests/transforms/test_rand_zoomd.py +108 -0
  623. tests/transforms/test_randidentity.py +49 -0
  624. tests/transforms/test_random_order.py +144 -0
  625. tests/transforms/test_randtorchvisiond.py +65 -0
  626. tests/transforms/test_regularization.py +139 -0
  627. tests/transforms/test_remove_repeated_channel.py +34 -0
  628. tests/transforms/test_remove_repeated_channeld.py +44 -0
  629. tests/transforms/test_repeat_channel.py +34 -0
  630. tests/transforms/test_repeat_channeld.py +41 -0
  631. tests/transforms/test_resample_backends.py +65 -0
  632. tests/transforms/test_resample_to_match.py +110 -0
  633. tests/transforms/test_resample_to_matchd.py +93 -0
  634. tests/transforms/test_resampler.py +165 -0
  635. tests/transforms/test_resize.py +140 -0
  636. tests/transforms/test_resize_with_pad_or_crop.py +91 -0
  637. tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
  638. tests/transforms/test_resized.py +163 -0
  639. tests/transforms/test_rotate.py +160 -0
  640. tests/transforms/test_rotate90.py +212 -0
  641. tests/transforms/test_rotate90d.py +106 -0
  642. tests/transforms/test_rotated.py +179 -0
  643. tests/transforms/test_save_classificationd.py +109 -0
  644. tests/transforms/test_save_image.py +80 -0
  645. tests/transforms/test_save_imaged.py +130 -0
  646. tests/transforms/test_savitzky_golay_smooth.py +73 -0
  647. tests/transforms/test_savitzky_golay_smoothd.py +73 -0
  648. tests/transforms/test_scale_intensity.py +76 -0
  649. tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
  650. tests/transforms/test_scale_intensity_range.py +41 -0
  651. tests/transforms/test_scale_intensity_ranged.py +40 -0
  652. tests/transforms/test_scale_intensityd.py +57 -0
  653. tests/transforms/test_select_itemsd.py +41 -0
  654. tests/transforms/test_shift_intensity.py +31 -0
  655. tests/transforms/test_shift_intensityd.py +44 -0
  656. tests/transforms/test_signal_continuouswavelet.py +44 -0
  657. tests/transforms/test_signal_fillempty.py +52 -0
  658. tests/transforms/test_signal_fillemptyd.py +60 -0
  659. tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
  660. tests/transforms/test_signal_rand_add_sine.py +52 -0
  661. tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
  662. tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
  663. tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
  664. tests/transforms/test_signal_rand_drop.py +50 -0
  665. tests/transforms/test_signal_rand_scale.py +52 -0
  666. tests/transforms/test_signal_rand_shift.py +55 -0
  667. tests/transforms/test_signal_remove_frequency.py +71 -0
  668. tests/transforms/test_smooth_field.py +177 -0
  669. tests/transforms/test_sobel_gradient.py +189 -0
  670. tests/transforms/test_sobel_gradientd.py +212 -0
  671. tests/transforms/test_spacing.py +381 -0
  672. tests/transforms/test_spacingd.py +178 -0
  673. tests/transforms/test_spatial_crop.py +82 -0
  674. tests/transforms/test_spatial_cropd.py +74 -0
  675. tests/transforms/test_spatial_pad.py +57 -0
  676. tests/transforms/test_spatial_padd.py +43 -0
  677. tests/transforms/test_spatial_resample.py +235 -0
  678. tests/transforms/test_squeezedim.py +62 -0
  679. tests/transforms/test_squeezedimd.py +98 -0
  680. tests/transforms/test_std_shift_intensity.py +76 -0
  681. tests/transforms/test_std_shift_intensityd.py +74 -0
  682. tests/transforms/test_threshold_intensity.py +38 -0
  683. tests/transforms/test_threshold_intensityd.py +58 -0
  684. tests/transforms/test_to_contiguous.py +47 -0
  685. tests/transforms/test_to_cupy.py +112 -0
  686. tests/transforms/test_to_cupyd.py +76 -0
  687. tests/transforms/test_to_device.py +42 -0
  688. tests/transforms/test_to_deviced.py +37 -0
  689. tests/transforms/test_to_numpy.py +85 -0
  690. tests/transforms/test_to_numpyd.py +68 -0
  691. tests/transforms/test_to_pil.py +52 -0
  692. tests/transforms/test_to_pild.py +50 -0
  693. tests/transforms/test_to_tensor.py +60 -0
  694. tests/transforms/test_to_tensord.py +71 -0
  695. tests/transforms/test_torchvision.py +66 -0
  696. tests/transforms/test_torchvisiond.py +63 -0
  697. tests/transforms/test_transform.py +62 -0
  698. tests/transforms/test_transpose.py +41 -0
  699. tests/transforms/test_transposed.py +52 -0
  700. tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
  701. tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
  702. tests/transforms/test_vote_ensemble.py +84 -0
  703. tests/transforms/test_vote_ensembled.py +107 -0
  704. tests/transforms/test_with_allow_missing_keys.py +76 -0
  705. tests/transforms/test_zoom.py +120 -0
  706. tests/transforms/test_zoomd.py +94 -0
  707. tests/transforms/transform/__init__.py +10 -0
  708. tests/transforms/transform/test_randomizable.py +52 -0
  709. tests/transforms/transform/test_randomizable_transform_type.py +37 -0
  710. tests/transforms/utility/__init__.py +10 -0
  711. tests/transforms/utility/test_apply_transform_to_points.py +81 -0
  712. tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
  713. tests/transforms/utility/test_identity.py +29 -0
  714. tests/transforms/utility/test_identityd.py +30 -0
  715. tests/transforms/utility/test_lambda.py +71 -0
  716. tests/transforms/utility/test_lambdad.py +83 -0
  717. tests/transforms/utility/test_rand_lambda.py +87 -0
  718. tests/transforms/utility/test_rand_lambdad.py +77 -0
  719. tests/transforms/utility/test_simulatedelay.py +36 -0
  720. tests/transforms/utility/test_simulatedelayd.py +36 -0
  721. tests/transforms/utility/test_splitdim.py +52 -0
  722. tests/transforms/utility/test_splitdimd.py +96 -0
  723. tests/transforms/utils/__init__.py +10 -0
  724. tests/transforms/utils/test_correct_crop_centers.py +36 -0
  725. tests/transforms/utils/test_get_unique_labels.py +45 -0
  726. tests/transforms/utils/test_print_transform_backends.py +29 -0
  727. tests/transforms/utils/test_soft_clip.py +125 -0
  728. tests/utils/__init__.py +10 -0
  729. tests/utils/enums/__init__.py +10 -0
  730. tests/utils/enums/test_hovernet_loss.py +190 -0
  731. tests/utils/enums/test_ordering.py +289 -0
  732. tests/utils/enums/test_wsireader.py +663 -0
  733. tests/utils/misc/__init__.py +10 -0
  734. tests/utils/misc/test_ensure_tuple.py +53 -0
  735. tests/utils/misc/test_monai_env_vars.py +44 -0
  736. tests/utils/misc/test_monai_utils_misc.py +103 -0
  737. tests/utils/misc/test_str2bool.py +34 -0
  738. tests/utils/misc/test_str2list.py +33 -0
  739. tests/utils/test_alias.py +44 -0
  740. tests/utils/test_component_store.py +73 -0
  741. tests/utils/test_deprecated.py +455 -0
  742. tests/utils/test_enum_bound_interp.py +75 -0
  743. tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
  744. tests/utils/test_get_package_version.py +34 -0
  745. tests/utils/test_handler_logfile.py +84 -0
  746. tests/utils/test_handler_metric_logger.py +62 -0
  747. tests/utils/test_list_to_dict.py +43 -0
  748. tests/utils/test_look_up_option.py +87 -0
  749. tests/utils/test_optional_import.py +80 -0
  750. tests/utils/test_pad_mode.py +39 -0
  751. tests/utils/test_profiling.py +208 -0
  752. tests/utils/test_rankfilter_dist.py +77 -0
  753. tests/utils/test_require_pkg.py +83 -0
  754. tests/utils/test_sample_slices.py +43 -0
  755. tests/utils/test_set_determinism.py +74 -0
  756. tests/utils/test_squeeze_unsqueeze.py +71 -0
  757. tests/utils/test_state_cacher.py +67 -0
  758. tests/utils/test_torchscript_utils.py +113 -0
  759. tests/utils/test_version.py +91 -0
  760. tests/utils/test_version_after.py +65 -0
  761. tests/utils/type_conversion/__init__.py +10 -0
  762. tests/utils/type_conversion/test_convert_data_type.py +152 -0
  763. tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
  764. tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
  765. tests/visualize/__init__.py +10 -0
  766. tests/visualize/test_img2tensorboard.py +46 -0
  767. tests/visualize/test_occlusion_sensitivity.py +128 -0
  768. tests/visualize/test_plot_2d_or_3d_image.py +74 -0
  769. tests/visualize/test_vis_cam.py +98 -0
  770. tests/visualize/test_vis_gradcam.py +211 -0
  771. tests/visualize/utils/__init__.py +10 -0
  772. tests/visualize/utils/test_blend_images.py +63 -0
  773. tests/visualize/utils/test_matshow3d.py +133 -0
  774. monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
  775. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/LICENSE +0 -0
  776. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/WHEEL +0 -0
@@ -0,0 +1,36 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import time
15
+ import unittest
16
+
17
+ import numpy as np
18
+ from parameterized import parameterized
19
+
20
+ from monai.transforms.utility.dictionary import SimulateDelayd
21
+ from tests.test_utils import NumpyImageTestCase2D
22
+
23
+
24
class TestSimulateDelay(NumpyImageTestCase2D):
    """Check that one call of ``SimulateDelayd`` takes roughly ``delay_time`` seconds."""

    @parameterized.expand([(0.45,), (1,)])
    def test_value(self, delay_test_time: float):
        """Time a single call of the dictionary transform and compare it with ``delay_test_time``.

        The tolerance is loose (``rtol=0.5``) because measured durations are affected by
        scheduler jitter and machine load.
        """
        delay = SimulateDelayd(keys="imgd", delay_time=delay_test_time)
        # perf_counter is monotonic and high-resolution; time.time() follows the system
        # clock and can jump (e.g. NTP adjustments) while the transform is running.
        start: float = time.perf_counter()
        _ = delay({"imgd": self.imt[0]})
        stop: float = time.perf_counter()
        measured_approximate: float = stop - start
        np.testing.assert_allclose(delay_test_time, measured_approximate, rtol=0.5)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,52 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import numpy as np
17
+ from parameterized import parameterized
18
+
19
+ from monai.transforms.utility.array import SplitDim
20
+ from tests.test_utils import TEST_NDARRAYS
21
+
22
# One case per (array constructor, keepdim) pair; the input shape is fixed at (2, 10, 8, 7).
TESTS = [((2, 10, 8, 7), keepdim, p) for p in TEST_NDARRAYS for keepdim in (True, False)]
26
+
27
+
28
class TestSplitDim(unittest.TestCase):
    """Tests for the array-level ``SplitDim`` transform."""

    @parameterized.expand(TESTS)
    def test_correct_shape(self, shape, keepdim, im_type):
        """Split along every axis; check piece count, type, rank and memory sharing."""
        data = im_type(np.random.rand(*shape))
        for axis in range(data.ndim):
            pieces = SplitDim(axis, keepdim)(data)
            self.assertIsInstance(pieces, (list, tuple))
            first = pieces[0]
            self.assertEqual(type(first), type(data))
            self.assertEqual(len(pieces), data.shape[axis])
            self.assertEqual(first.ndim, data.ndim if keepdim else data.ndim - 1)
            # the first piece must share memory with the input (shallow copy)
            data[0, 0, 0, 0] *= 2
            self.assertEqual(data.flatten()[0], first.flatten()[0])

    def test_singleton(self):
        """Splitting a size-1 axis with the default ``keepdim`` leaves the shape unchanged."""
        shape = (2, 1, 8, 7)
        for to_type in TEST_NDARRAYS:
            pieces = SplitDim(dim=1)(to_type(np.random.rand(*shape)))
            self.assertEqual(pieces[0].shape, shape)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,96 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+ from copy import deepcopy
16
+
17
+ import numpy as np
18
+ import torch
19
+ from parameterized import parameterized
20
+
21
+ from monai.data.meta_tensor import MetaTensor
22
+ from monai.transforms import LoadImaged
23
+ from monai.transforms.utility.dictionary import SplitDimd
24
+ from tests.test_utils import TEST_NDARRAYS, assert_allclose, make_nifti_image, make_rand_affine
25
+
26
# Full grid of array constructor x keepdim x update_meta x list_output.
TESTS = [
    (keepdim, p, update_meta, list_output)
    for p in TEST_NDARRAYS
    for keepdim in (True, False)
    for update_meta in (True, False)
    for list_output in (True, False)
]
32
+
33
+
34
class TestSplitDimd(unittest.TestCase):
    """Tests for the dictionary-level ``SplitDimd`` transform."""

    # The output of ``LoadImaged`` applied to ``{"i": ...}``: a dict (it is indexed as
    # ``data["i"]`` below), not a MetaTensor.
    data: dict

    @classmethod
    def setUpClass(cls) -> None:
        """Build a random 4D image with a random affine (via make_nifti_image) and load it once."""
        arr = np.random.rand(2, 10, 8, 7)
        affine = make_rand_affine()
        data = {"i": make_nifti_image(arr, affine)}

        loader = LoadImaged("i", image_only=True)
        cls.data = loader(data)

    @parameterized.expand(TESTS)
    def test_correct(self, keepdim, im_type, update_meta, list_output):
        """Split along every axis; check output structure, rank, metadata and memory sharing."""
        data = deepcopy(self.data)
        data["i"] = im_type(data["i"])
        arr = data["i"]
        for dim in range(arr.ndim):
            out = SplitDimd("i", dim=dim, keepdim=keepdim, update_meta=update_meta, list_output=list_output)(data)
            if list_output:
                self.assertIsInstance(out, list)
                self.assertEqual(len(out), arr.shape[dim])
            else:
                self.assertIsInstance(out, dict)
                # one extra key per split item on top of the input keys
                self.assertEqual(len(out.keys()), len(data.keys()) + arr.shape[dim])
            # if updating metadata, pick some random points and
            # check same world coordinates between input and output
            if update_meta:
                for _ in range(10):
                    idx = [np.random.choice(i) for i in arr.shape]
                    split_im_idx = idx[dim]
                    split_idx = deepcopy(idx)
                    split_idx[dim] = 0
                    if list_output:
                        split_im = out[split_im_idx]["i"]
                    else:
                        split_im = out[f"i_{split_im_idx}"]
                    # NOTE(review): ``data`` is a dict here, so this condition is always False and
                    # the world-coordinate comparison never runs. ``arr`` may have been intended;
                    # confirm before switching, since that would start exercising the affine check.
                    if isinstance(data, MetaTensor) and isinstance(split_im, MetaTensor):
                        # idx[1:] to remove channel and then add 1 for 4th element
                        real_world = data.affine @ torch.tensor(idx[1:] + [1]).double()
                        real_world2 = split_im.affine @ torch.tensor(split_idx[1:] + [1]).double()
                        assert_allclose(real_world, real_world2)

            if list_output:
                out = out[0]["i"]
            else:
                out = out["i_0"]
            expected_ndim = arr.ndim if keepdim else arr.ndim - 1
            self.assertEqual(out.ndim, expected_ndim)
            # assert is a shallow copy
            arr[0, 0, 0, 0] *= 2
            self.assertEqual(arr.flatten()[0], out.flatten()[0])

    def test_singleton(self):
        """Apply ``SplitDimd`` to a size-1 axis and check the shape stored under ``"i"``."""
        shape = (2, 1, 8, 7)
        for p in TEST_NDARRAYS:
            arr = p(np.random.rand(*shape))
            out = SplitDimd("i", dim=1)({"i": arr})
            self.assertEqual(out["i"].shape, shape)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,10 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
@@ -0,0 +1,36 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import torch
17
+ from parameterized import parameterized
18
+
19
+ from monai.transforms.utils import correct_crop_centers
20
+ from tests.test_utils import assert_allclose
21
+
22
# Each case is (spatial_size, centers, label_spatial_shape), matching test_torch's parameters.
TESTS = [
    [[1, 5, 0], [2, 2, 2], [10, 10, 10]],
    [[4, 4, 4], [2, 2, 1], [10, 10, 10]],
]
23
+
24
+
25
class TestCorrectCropCenters(unittest.TestCase):
    """``correct_crop_centers`` must agree for plain-int and tensor centers."""

    @parameterized.expand(TESTS)
    def test_torch(self, spatial_size, centers, label_spatial_shape):
        """Compare the int-center result with the tensor-center result, values and element type."""
        from_ints = correct_crop_centers(centers, spatial_size, label_spatial_shape)
        tensor_centers = [torch.tensor(c) for c in centers]
        from_tensors = correct_crop_centers(tensor_centers, spatial_size, label_spatial_shape)
        assert_allclose(from_ints, from_tensors)
        self.assertEqual(type(from_ints[0]), type(from_tensors[0]))


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,45 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from parameterized import parameterized
19
+
20
+ from monai.transforms.utils import get_unique_labels
21
+ from monai.transforms.utils_pytorch_numpy_unification import moveaxis
22
+ from tests.test_utils import TEST_NDARRAYS
23
+
24
# A 6x3 label map containing the labels {0, 1, 2, 3, 5, 6} (label 4 is absent).
grid_raw = [[0, 0, 0], [0, 0, 1], [2, 2, 3], [5, 5, 6], [3, 6, 2], [5, 6, 6]]
# Add a leading channel axis -> shape (1, 6, 3), dtype int64.
grid = torch.tensor(grid_raw, dtype=torch.int64).unsqueeze(0)
# One-hot encode the single channel and move the class axis to the front.
grid_onehot = moveaxis(F.one_hot(grid)[0], -1, 0)

# For each array type and each label encoding: all labels, then two `discard` variants.
TESTS = []
for p in TEST_NDARRAYS:
    for o_h in (False, True):
        im = grid_onehot if o_h else grid
        for extra, expected in (
            ({}, {0, 1, 2, 3, 5, 6}),
            ({"discard": 0}, {1, 2, 3, 5, 6}),
            ({"discard": [1, 2]}, {0, 3, 5, 6}),
        ):
            TESTS.append([dict(img=p(im), is_onehot=o_h, **extra), expected])
35
+
36
+
37
class TestGetUniqueLabels(unittest.TestCase):
    """``get_unique_labels`` on label maps given as class indices or one-hot channels."""

    @parameterized.expand(TESTS)
    def test_correct_results(self, args, expected):
        """The returned labels must equal ``expected``."""
        self.assertEqual(get_unique_labels(**args), expected)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,29 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ from monai.transforms.utils import get_transform_backends, print_transform_backends
17
+
18
+
19
class TestPrintTransformBackends(unittest.TestCase):
    """Smoke test for the transform-backend reporting helpers."""

    def test_get_number_of_conversions(self):
        """The first collection from ``get_transform_backends`` is non-empty, and printing works."""
        tr_t_or_np, *_ = get_transform_backends()
        self.assertGreater(len(tr_t_or_np), 0)
        print_transform_backends()


if __name__ == "__main__":
    # Use the standard unittest runner, consistent with the other test modules, rather
    # than instantiating the TestCase by hand and calling a single method directly.
    unittest.main()
@@ -0,0 +1,125 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import numpy as np
17
+ import torch
18
+ from parameterized import parameterized
19
+
20
+ from monai.transforms.utils import soft_clip
21
+
22
# (soft_clip kwargs, expected output for the input 0, 1, ..., 9); shared by the torch and numpy cases.
_SOFT_CLIP_EXPECTATIONS = [
    (
        {"minv": 2, "maxv": 8, "sharpness_factor": 10},
        [2.0000, 2.0000, 2.0693, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 7.9307, 8.0000],
    ),
    (
        {"minv": 2, "maxv": None, "sharpness_factor": 10},
        [2.0000, 2.0000, 2.0693, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000],
    ),
    (
        {"minv": None, "maxv": 7, "sharpness_factor": 10},
        [0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 6.9307, 7.0000, 7.0000],
    ),
    (
        {"minv": 2, "maxv": 8, "sharpness_factor": 1.0},
        [2.1266, 2.3124, 2.6907, 3.3065, 4.1088, 5.0000, 5.8912, 6.6935, 7.3093, 7.6877],
    ),
    (
        {"minv": 2, "maxv": 8, "sharpness_factor": 3.0},
        [2.0008, 2.0162, 2.2310, 3.0162, 4.0008, 5.0000, 5.9992, 6.9838, 7.7690, 7.9838],
    ),
    (
        {"minv": 2, "maxv": 8, "sharpness_factor": 5.0},
        [2.0000, 2.0013, 2.1386, 3.0013, 4.0000, 5.0000, 6.0000, 6.9987, 7.8614, 7.9987],
    ),
]

# Torch cases first (float32 input and expectation). Then the same expectations with
# numpy input: the first numpy case uses float32 input, the rest use float64.
TEST_CASES = [
    [dict(kwargs), {"input": torch.arange(10).float(), "clipped": torch.tensor(expected)}]
    for kwargs, expected in _SOFT_CLIP_EXPECTATIONS
] + [
    [
        dict(kwargs),
        {"input": np.arange(10).astype(np.float32 if i == 0 else float), "clipped": np.array(expected)},
    ]
    for i, (kwargs, expected) in enumerate(_SOFT_CLIP_EXPECTATIONS)
]
108
+
109
+
110
class TestSoftClip(unittest.TestCase):
    """Compare ``soft_clip`` with precomputed values for torch and numpy inputs."""

    @parameterized.expand(TEST_CASES)
    def test_result(self, input_param, input_data):
        """The output must match ``input_data["clipped"]`` within atol=rtol=1e-4."""
        result = soft_clip(input_data["input"], **input_param)
        expected = input_data["clipped"]
        if not isinstance(result, torch.Tensor):
            np.testing.assert_allclose(result, expected, atol=1e-4, rtol=1e-4)
            return
        # tensors are compared on the host as numpy arrays
        np.testing.assert_allclose(
            result.detach().cpu().numpy(), expected.detach().cpu().numpy(), atol=1e-4, rtol=1e-4
        )


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,10 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
@@ -0,0 +1,10 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
@@ -0,0 +1,190 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import random
15
+ import unittest
16
+
17
+ import numpy as np
18
+ import torch
19
+ from parameterized import parameterized
20
+ from torch.nn import functional as F
21
+
22
+ from monai.apps.pathology.losses import HoVerNetLoss
23
+ from monai.transforms import GaussianSmooth, Rotate
24
+ from monai.transforms.intensity.array import ComputeHoVerMaps
25
+ from monai.utils.enums import HoVerNetBranch
26
+
27
device = "cpu" if not torch.cuda.is_available() else "cuda"

# Probability clamp bounds applied before ``log`` in ``test_shape_generator`` (1e-7 is the same value as 10e-8).
s = 1e-7
t = 1.0 - s
# Spatial size of the synthetic images.
H = W = 40
# Number of nuclear types for the typed cases, and batch size for the multi-sample cases.
N = 5
B = 2
35
+
36
+
37
class PrepareTestInputs:
    """Split a ``test_shape_generator`` result tuple into per-branch prediction and target dicts.

    Tuple positions read here: 0 -> NP target, 1 -> NP prediction, 2 -> HV target, 3 -> HV prediction,
    and, when more than four items are present, 4 -> NC target and 5 -> NC prediction.
    """

    def __init__(self, inputs):
        has_type_branch = len(inputs) > 4
        self.targets = {HoVerNetBranch.NP: inputs[0], HoVerNetBranch.HV: inputs[2]}
        self.inputs = {HoVerNetBranch.NP: inputs[1], HoVerNetBranch.HV: inputs[3]}
        if has_type_branch:
            self.inputs[HoVerNetBranch.NC] = inputs[5]
            self.targets[HoVerNetBranch.NC] = inputs[4]
47
+
48
def _one_hot_nchw(labels):
    """One-hot encode a (B, H, W) label tensor to float32 (B, C, H, W); C is inferred by ``F.one_hot``."""
    return F.one_hot(labels.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)


def test_shape_generator(num_classes=1, num_objects=3, batch_size=1, height=5, width=5, rotation=0.0, smoothing=False):
    """Generate synthetic HoVerNet targets and predictions for loss tests.

    For each sample, ``num_objects`` filled circles are drawn with a generator seeded by ``10 + sample_index``.
    These produce nuclear-pixel (NP) and horizontal/vertical (HV) targets and, when ``num_classes > 1``,
    nuclear-type (NC) targets. Predictions are derived from the targets.

    Args:
        num_classes: number of nuclear types; values > 1 also produce the NC target/prediction pair.
        num_objects: number of circles per sample.
        batch_size: number of samples.
        height: image height.
        width: image width.
        rotation: when non-zero, predictions are the targets rotated by this angle with ``Rotate``;
            when zero, the HV prediction is ``0.99 * hv_target``.
        smoothing: if True, NP/NC predictions are Gaussian-smoothed and the HV prediction is scaled by 0.1;
            otherwise NP/NC predictions are clamped to ``[s, t]``.

    Returns:
        ``(np_target, np_logits, hv_target, hv_pred)``, plus ``(nc_target, nc_logits)`` when ``num_classes > 1``.
        NP/NC targets are one-hot float tensors of shape (B, C, H, W); HV tensors have shape (B, 2, H, W);
        the "logits" are the ``log`` of the clamped or smoothed one-hot predictions.
    """
    t_g = torch.zeros((batch_size, height, width), dtype=torch.int64)
    t_p = None
    hv_g = torch.zeros((batch_size, 2, height, width))
    hv_p = torch.zeros((batch_size, 2, height, width))

    rad_min = 2
    # Bound the radius by the smaller side so the centre ranges drawn below are valid on both axes
    # (same value as bounding by the larger side when height == width).
    rad_max = min(max(min(height, width) // 3, rad_min), 5)
    # Negative angles are rotations too; only an exact zero means "no rotation".
    rotate_pred = rotation != 0.0

    for b in range(batch_size):
        # A private generator seeded per sample draws the same sequence as ``random.seed(10 + b)``
        # without clobbering the global ``random`` state.
        rng = random.Random(10 + b)
        inst_map = torch.zeros((height, width), dtype=torch.int64)
        for inst_id in range(1, num_objects + 1):
            # The centre row is drawn from the height range and the centre column from the width range.
            row = rng.randint(rad_max, height - rad_max)
            col = rng.randint(rad_max, width - rad_max)
            rad = rng.randint(rad_min, rad_max)
            spy, spx = np.ogrid[-row : height - row, -col : width - col]
            circle = torch.tensor((spx * spx + spy * spy) <= rad * rad)

            if num_classes > 1:
                t_g[b, circle] = np.ceil(rng.random() * num_classes)
            else:
                t_g[b, circle] = 1

            inst_map[circle] = inst_id

        hv_g[b] = ComputeHoVerMaps()(inst_map[None])
        hv_g[b] = hv_g[b].squeeze(0)
        if rotate_pred:
            hv_p[b] = Rotate(angle=rotation, keep_size=True, mode="bilinear")(hv_g[b])

    n_g = t_g > 0
    if not rotate_pred:
        hv_p = hv_g * 0.99

    # Rotation of the predictions must happen on the label maps, i.e. before one-hot encoding.
    if rotate_pred:
        n_p = _one_hot_nchw(Rotate(angle=rotation, keep_size=True, mode="nearest")(n_g))
        if num_classes > 1:
            t_p = _one_hot_nchw(Rotate(angle=rotation, keep_size=True, mode="nearest")(t_g))
    else:
        n_p = _one_hot_nchw(n_g)
        if num_classes > 1:
            t_p = _one_hot_nchw(t_g)

    # NOTE(review): ``F.one_hot`` infers the channel count from the largest label, so a rotation that
    # removes the highest label would leave the prediction with fewer channels than the target.
    t_g = _one_hot_nchw(t_g) if num_classes > 1 else None
    n_g = _one_hot_nchw(n_g)

    if smoothing:
        # NOTE(review): no clamp on this path; any exact zeros left by smoothing become -inf after ``log``.
        n_p = GaussianSmooth()(n_p)
        if num_classes > 1:
            t_p = GaussianSmooth()(t_p)
        hv_p = hv_p * 0.1
    else:
        n_p = torch.clamp(n_p, s, t)
        if num_classes > 1:
            t_p = torch.clamp(t_p, s, t)

    # Apply log to emulate logits
    if t_p is not None:
        return n_g, n_p.log(), hv_g, hv_p, t_g, t_p.log()
    return n_g, n_p.log(), hv_g, hv_p


# The ``test_`` prefix would otherwise make pytest collect and run this data helper as a test.
test_shape_generator.__test__ = False
118
+
119
+
120
# Keyword arguments for ``test_shape_generator``; every configuration uses an H x W canvas.
_GENERATOR_CONFIGS = (
    {},
    {"num_classes": N},
    {"num_classes": N, "batch_size": B},
    {"num_classes": N, "batch_size": B, "rotation": 0.15},
    {"num_classes": N, "batch_size": B, "rotation": 0.2},
    {"num_classes": N, "batch_size": B, "rotation": 0.25},
)

inputs_test = [PrepareTestInputs(test_shape_generator(height=H, width=W, **cfg)) for cfg in _GENERATOR_CONFIGS]
128
+
129
TEST_CASE_0 = [  # batch size of 1, no type prediction
    {"prediction": inputs_test[0].inputs, "target": inputs_test[0].targets},
    0.003,
]

TEST_CASE_1 = [  # batch size of 1, N nuclear types with type prediction
    {"prediction": inputs_test[1].inputs, "target": inputs_test[1].targets},
    0.2762,
]

TEST_CASE_2 = [  # batch size of B, N nuclear types with type prediction
    {"prediction": inputs_test[2].inputs, "target": inputs_test[2].targets},
    0.4852,
]

TEST_CASE_3 = [  # batch size of B, N nuclear types, predictions rotated by 0.15 (smallest rotation)
    {"prediction": inputs_test[3].inputs, "target": inputs_test[3].targets},
    3.6348,
]

TEST_CASE_4 = [  # batch size of B, N nuclear types, predictions rotated by 0.2 (medium rotation)
    {"prediction": inputs_test[4].inputs, "target": inputs_test[4].targets},
    4.5312,
]

TEST_CASE_5 = [  # batch size of B, N nuclear types, predictions rotated by 0.25 (largest rotation)
    {"prediction": inputs_test[5].inputs, "target": inputs_test[5].targets},
    5.4929,
]

# Each case: [HoVerNetLoss keyword arguments, expected loss value].
CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]

# Malformed input: the prediction dict has only a plain-string "np" entry (no HoVerNetBranch.HV entry),
# and the target mixes the "np" string key with HoVerNetBranch.HV; the loss must reject it with ValueError.
ILL_CASES = [
    [
        {
            "prediction": {"np": inputs_test[0].inputs[HoVerNetBranch.NP]},
            "target": {
                "np": inputs_test[0].targets[HoVerNetBranch.NP],
                HoVerNetBranch.HV: inputs_test[0].targets[HoVerNetBranch.HV],
            },
        }
    ]
]
172
+
173
+
174
class TestHoverNetLoss(unittest.TestCase):
    """Check ``HoVerNetLoss`` values on synthetic inputs and its rejection of malformed inputs."""

    @parameterized.expand(CASES)
    def test_shape(self, input_param, expected_loss):
        """The loss matches the recorded value to two decimal places."""
        loss = HoVerNetLoss()
        result = loss(**input_param)
        self.assertAlmostEqual(float(result), expected_loss, places=2)

    @parameterized.expand(ILL_CASES)
    def test_ill_input_hyper_params(self, input_param):
        """The malformed inputs in ``ILL_CASES`` raise ``ValueError`` from the forward call."""
        # Build the loss outside the context manager so only the forward call can satisfy assertRaises.
        loss = HoVerNetLoss()
        with self.assertRaises(ValueError):
            loss(**input_param)
187
+
188
+
189
if __name__ == "__main__":
    # Use the real command line and let unittest set the process exit status, so a failing
    # run is reported as a failure to the shell/CI when this file is executed directly.
    unittest.main()