monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2508__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (787)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/transforms.py +1 -4
  4. monai/data/utils.py +6 -13
  5. monai/handlers/__init__.py +1 -0
  6. monai/handlers/average_precision.py +53 -0
  7. monai/inferers/inferer.py +10 -7
  8. monai/inferers/utils.py +1 -2
  9. monai/losses/dice.py +2 -14
  10. monai/losses/ds_loss.py +1 -3
  11. monai/metrics/__init__.py +1 -0
  12. monai/metrics/average_precision.py +187 -0
  13. monai/networks/layers/simplelayers.py +2 -14
  14. monai/networks/utils.py +4 -16
  15. monai/transforms/compose.py +28 -11
  16. monai/transforms/croppad/array.py +1 -6
  17. monai/transforms/io/array.py +0 -1
  18. monai/transforms/transform.py +15 -6
  19. monai/transforms/utility/array.py +2 -12
  20. monai/transforms/utils.py +1 -2
  21. monai/transforms/utils_pytorch_numpy_unification.py +2 -4
  22. monai/utils/enums.py +3 -2
  23. monai/utils/module.py +6 -6
  24. monai/utils/tf32.py +0 -10
  25. monai/visualize/class_activation_maps.py +5 -8
  26. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/METADATA +21 -17
  27. monai_weekly-1.5.dev2508.dist-info/RECORD +1185 -0
  28. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/top_level.txt +1 -0
  29. tests/apps/__init__.py +10 -0
  30. tests/apps/deepedit/__init__.py +10 -0
  31. tests/apps/deepedit/test_deepedit_transforms.py +314 -0
  32. tests/apps/deepgrow/__init__.py +10 -0
  33. tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
  34. tests/apps/deepgrow/transforms/__init__.py +10 -0
  35. tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
  36. tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
  37. tests/apps/detection/__init__.py +10 -0
  38. tests/apps/detection/metrics/__init__.py +10 -0
  39. tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
  40. tests/apps/detection/networks/__init__.py +10 -0
  41. tests/apps/detection/networks/test_retinanet.py +210 -0
  42. tests/apps/detection/networks/test_retinanet_detector.py +203 -0
  43. tests/apps/detection/test_box_transform.py +370 -0
  44. tests/apps/detection/utils/__init__.py +10 -0
  45. tests/apps/detection/utils/test_anchor_box.py +88 -0
  46. tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
  47. tests/apps/detection/utils/test_box_coder.py +43 -0
  48. tests/apps/detection/utils/test_detector_boxselector.py +67 -0
  49. tests/apps/detection/utils/test_detector_utils.py +96 -0
  50. tests/apps/detection/utils/test_hardnegsampler.py +54 -0
  51. tests/apps/nuclick/__init__.py +10 -0
  52. tests/apps/nuclick/test_nuclick_transforms.py +259 -0
  53. tests/apps/pathology/__init__.py +10 -0
  54. tests/apps/pathology/handlers/__init__.py +10 -0
  55. tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
  56. tests/apps/pathology/test_lesion_froc.py +333 -0
  57. tests/apps/pathology/test_pathology_prob_nms.py +55 -0
  58. tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
  59. tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
  60. tests/apps/pathology/transforms/__init__.py +10 -0
  61. tests/apps/pathology/transforms/post/__init__.py +10 -0
  62. tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
  63. tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
  64. tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
  65. tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
  66. tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
  67. tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
  68. tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
  69. tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
  70. tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
  71. tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
  72. tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
  73. tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
  74. tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
  75. tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
  76. tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
  77. tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
  78. tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
  79. tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
  80. tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
  81. tests/apps/pathology/transforms/post/test_watershed.py +60 -0
  82. tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
  83. tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
  84. tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
  85. tests/apps/reconstruction/__init__.py +10 -0
  86. tests/apps/reconstruction/nets/__init__.py +10 -0
  87. tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
  88. tests/apps/reconstruction/test_complex_utils.py +77 -0
  89. tests/apps/reconstruction/test_fastmri_reader.py +82 -0
  90. tests/apps/reconstruction/test_mri_utils.py +37 -0
  91. tests/apps/reconstruction/transforms/__init__.py +10 -0
  92. tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
  93. tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
  94. tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
  95. tests/apps/test_auto3dseg_bundlegen.py +156 -0
  96. tests/apps/test_check_hash.py +53 -0
  97. tests/apps/test_cross_validation.py +74 -0
  98. tests/apps/test_decathlondataset.py +93 -0
  99. tests/apps/test_download_and_extract.py +70 -0
  100. tests/apps/test_download_url_yandex.py +45 -0
  101. tests/apps/test_mednistdataset.py +72 -0
  102. tests/apps/test_mmar_download.py +154 -0
  103. tests/apps/test_tciadataset.py +123 -0
  104. tests/apps/vista3d/__init__.py +10 -0
  105. tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
  106. tests/apps/vista3d/test_vista3d_sampler.py +100 -0
  107. tests/apps/vista3d/test_vista3d_transforms.py +94 -0
  108. tests/bundle/__init__.py +10 -0
  109. tests/bundle/test_bundle_ckpt_export.py +107 -0
  110. tests/bundle/test_bundle_download.py +435 -0
  111. tests/bundle/test_bundle_get_data.py +94 -0
  112. tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
  113. tests/bundle/test_bundle_trt_export.py +147 -0
  114. tests/bundle/test_bundle_utils.py +149 -0
  115. tests/bundle/test_bundle_verify_metadata.py +66 -0
  116. tests/bundle/test_bundle_verify_net.py +76 -0
  117. tests/bundle/test_bundle_workflow.py +272 -0
  118. tests/bundle/test_component_locator.py +38 -0
  119. tests/bundle/test_config_item.py +138 -0
  120. tests/bundle/test_config_parser.py +392 -0
  121. tests/bundle/test_reference_resolver.py +114 -0
  122. tests/config/__init__.py +10 -0
  123. tests/config/test_cv2_dist.py +53 -0
  124. tests/engines/__init__.py +10 -0
  125. tests/engines/test_ensemble_evaluator.py +94 -0
  126. tests/engines/test_prepare_batch_default.py +76 -0
  127. tests/engines/test_prepare_batch_default_dist.py +76 -0
  128. tests/engines/test_prepare_batch_diffusion.py +104 -0
  129. tests/engines/test_prepare_batch_extra_input.py +80 -0
  130. tests/fl/__init__.py +10 -0
  131. tests/fl/monai_algo/__init__.py +10 -0
  132. tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
  133. tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
  134. tests/fl/test_fl_monai_algo_stats.py +81 -0
  135. tests/fl/utils/__init__.py +10 -0
  136. tests/fl/utils/test_fl_exchange_object.py +63 -0
  137. tests/handlers/__init__.py +10 -0
  138. tests/handlers/test_handler_average_precision.py +79 -0
  139. tests/handlers/test_handler_checkpoint_loader.py +182 -0
  140. tests/handlers/test_handler_checkpoint_saver.py +233 -0
  141. tests/handlers/test_handler_classification_saver.py +64 -0
  142. tests/handlers/test_handler_classification_saver_dist.py +77 -0
  143. tests/handlers/test_handler_clearml_image.py +65 -0
  144. tests/handlers/test_handler_clearml_stats.py +65 -0
  145. tests/handlers/test_handler_confusion_matrix.py +104 -0
  146. tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
  147. tests/handlers/test_handler_decollate_batch.py +66 -0
  148. tests/handlers/test_handler_early_stop.py +68 -0
  149. tests/handlers/test_handler_garbage_collector.py +73 -0
  150. tests/handlers/test_handler_hausdorff_distance.py +111 -0
  151. tests/handlers/test_handler_ignite_metric.py +191 -0
  152. tests/handlers/test_handler_lr_scheduler.py +94 -0
  153. tests/handlers/test_handler_mean_dice.py +98 -0
  154. tests/handlers/test_handler_mean_iou.py +76 -0
  155. tests/handlers/test_handler_metrics_reloaded.py +149 -0
  156. tests/handlers/test_handler_metrics_saver.py +89 -0
  157. tests/handlers/test_handler_metrics_saver_dist.py +120 -0
  158. tests/handlers/test_handler_mlflow.py +296 -0
  159. tests/handlers/test_handler_nvtx.py +93 -0
  160. tests/handlers/test_handler_panoptic_quality.py +89 -0
  161. tests/handlers/test_handler_parameter_scheduler.py +136 -0
  162. tests/handlers/test_handler_post_processing.py +74 -0
  163. tests/handlers/test_handler_prob_map_producer.py +111 -0
  164. tests/handlers/test_handler_regression_metrics.py +160 -0
  165. tests/handlers/test_handler_regression_metrics_dist.py +245 -0
  166. tests/handlers/test_handler_rocauc.py +48 -0
  167. tests/handlers/test_handler_rocauc_dist.py +54 -0
  168. tests/handlers/test_handler_stats.py +281 -0
  169. tests/handlers/test_handler_surface_distance.py +113 -0
  170. tests/handlers/test_handler_tb_image.py +61 -0
  171. tests/handlers/test_handler_tb_stats.py +166 -0
  172. tests/handlers/test_handler_validation.py +59 -0
  173. tests/handlers/test_trt_compile.py +145 -0
  174. tests/handlers/test_write_metrics_reports.py +68 -0
  175. tests/inferers/__init__.py +10 -0
  176. tests/inferers/test_avg_merger.py +179 -0
  177. tests/inferers/test_controlnet_inferers.py +1388 -0
  178. tests/inferers/test_diffusion_inferer.py +236 -0
  179. tests/inferers/test_latent_diffusion_inferer.py +884 -0
  180. tests/inferers/test_patch_inferer.py +309 -0
  181. tests/inferers/test_saliency_inferer.py +55 -0
  182. tests/inferers/test_slice_inferer.py +57 -0
  183. tests/inferers/test_sliding_window_inference.py +377 -0
  184. tests/inferers/test_sliding_window_splitter.py +284 -0
  185. tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
  186. tests/inferers/test_zarr_avg_merger.py +326 -0
  187. tests/integration/__init__.py +10 -0
  188. tests/integration/test_auto3dseg_ensemble.py +211 -0
  189. tests/integration/test_auto3dseg_hpo.py +189 -0
  190. tests/integration/test_deepedit_interaction.py +122 -0
  191. tests/integration/test_downsample_block.py +50 -0
  192. tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
  193. tests/integration/test_integration_autorunner.py +201 -0
  194. tests/integration/test_integration_bundle_run.py +240 -0
  195. tests/integration/test_integration_classification_2d.py +282 -0
  196. tests/integration/test_integration_determinism.py +95 -0
  197. tests/integration/test_integration_fast_train.py +231 -0
  198. tests/integration/test_integration_gpu_customization.py +159 -0
  199. tests/integration/test_integration_lazy_samples.py +219 -0
  200. tests/integration/test_integration_nnunetv2_runner.py +96 -0
  201. tests/integration/test_integration_segmentation_3d.py +304 -0
  202. tests/integration/test_integration_sliding_window.py +100 -0
  203. tests/integration/test_integration_stn.py +133 -0
  204. tests/integration/test_integration_unet_2d.py +67 -0
  205. tests/integration/test_integration_workers.py +61 -0
  206. tests/integration/test_integration_workflows.py +365 -0
  207. tests/integration/test_integration_workflows_adversarial.py +173 -0
  208. tests/integration/test_integration_workflows_gan.py +158 -0
  209. tests/integration/test_loader_semaphore.py +48 -0
  210. tests/integration/test_mapping_filed.py +122 -0
  211. tests/integration/test_meta_affine.py +183 -0
  212. tests/integration/test_metatensor_integration.py +114 -0
  213. tests/integration/test_module_list.py +76 -0
  214. tests/integration/test_one_of.py +283 -0
  215. tests/integration/test_pad_collation.py +124 -0
  216. tests/integration/test_reg_loss_integration.py +107 -0
  217. tests/integration/test_retinanet_predict_utils.py +154 -0
  218. tests/integration/test_seg_loss_integration.py +159 -0
  219. tests/integration/test_spatial_combine_transforms.py +185 -0
  220. tests/integration/test_testtimeaugmentation.py +186 -0
  221. tests/integration/test_vis_gradbased.py +69 -0
  222. tests/integration/test_vista3d_utils.py +159 -0
  223. tests/losses/__init__.py +10 -0
  224. tests/losses/deform/__init__.py +10 -0
  225. tests/losses/deform/test_bending_energy.py +88 -0
  226. tests/losses/deform/test_diffusion_loss.py +117 -0
  227. tests/losses/image_dissimilarity/__init__.py +10 -0
  228. tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
  229. tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
  230. tests/losses/test_adversarial_loss.py +94 -0
  231. tests/losses/test_barlow_twins_loss.py +109 -0
  232. tests/losses/test_cldice_loss.py +51 -0
  233. tests/losses/test_contrastive_loss.py +86 -0
  234. tests/losses/test_dice_ce_loss.py +123 -0
  235. tests/losses/test_dice_focal_loss.py +124 -0
  236. tests/losses/test_dice_loss.py +227 -0
  237. tests/losses/test_ds_loss.py +189 -0
  238. tests/losses/test_focal_loss.py +379 -0
  239. tests/losses/test_generalized_dice_focal_loss.py +85 -0
  240. tests/losses/test_generalized_dice_loss.py +221 -0
  241. tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
  242. tests/losses/test_giou_loss.py +62 -0
  243. tests/losses/test_hausdorff_loss.py +264 -0
  244. tests/losses/test_masked_dice_loss.py +152 -0
  245. tests/losses/test_masked_loss.py +87 -0
  246. tests/losses/test_multi_scale.py +86 -0
  247. tests/losses/test_nacl_loss.py +167 -0
  248. tests/losses/test_perceptual_loss.py +122 -0
  249. tests/losses/test_spectral_loss.py +86 -0
  250. tests/losses/test_ssim_loss.py +59 -0
  251. tests/losses/test_sure_loss.py +72 -0
  252. tests/losses/test_tversky_loss.py +198 -0
  253. tests/losses/test_unified_focal_loss.py +66 -0
  254. tests/metrics/__init__.py +10 -0
  255. tests/metrics/test_compute_average_precision.py +162 -0
  256. tests/metrics/test_compute_confusion_matrix.py +294 -0
  257. tests/metrics/test_compute_f_beta.py +80 -0
  258. tests/metrics/test_compute_fid_metric.py +40 -0
  259. tests/metrics/test_compute_froc.py +143 -0
  260. tests/metrics/test_compute_generalized_dice.py +240 -0
  261. tests/metrics/test_compute_meandice.py +306 -0
  262. tests/metrics/test_compute_meaniou.py +223 -0
  263. tests/metrics/test_compute_mmd_metric.py +56 -0
  264. tests/metrics/test_compute_multiscalessim_metric.py +83 -0
  265. tests/metrics/test_compute_panoptic_quality.py +113 -0
  266. tests/metrics/test_compute_regression_metrics.py +196 -0
  267. tests/metrics/test_compute_roc_auc.py +155 -0
  268. tests/metrics/test_compute_variance.py +147 -0
  269. tests/metrics/test_cumulative.py +63 -0
  270. tests/metrics/test_cumulative_average.py +74 -0
  271. tests/metrics/test_cumulative_average_dist.py +48 -0
  272. tests/metrics/test_hausdorff_distance.py +209 -0
  273. tests/metrics/test_label_quality_score.py +134 -0
  274. tests/metrics/test_loss_metric.py +57 -0
  275. tests/metrics/test_metrics_reloaded.py +96 -0
  276. tests/metrics/test_ssim_metric.py +78 -0
  277. tests/metrics/test_surface_dice.py +416 -0
  278. tests/metrics/test_surface_distance.py +186 -0
  279. tests/networks/__init__.py +10 -0
  280. tests/networks/blocks/__init__.py +10 -0
  281. tests/networks/blocks/dints_block/__init__.py +10 -0
  282. tests/networks/blocks/dints_block/test_acn_block.py +41 -0
  283. tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
  284. tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
  285. tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
  286. tests/networks/blocks/test_adn.py +86 -0
  287. tests/networks/blocks/test_convolutions.py +156 -0
  288. tests/networks/blocks/test_crf_cpu.py +513 -0
  289. tests/networks/blocks/test_crf_cuda.py +528 -0
  290. tests/networks/blocks/test_crossattention.py +185 -0
  291. tests/networks/blocks/test_denseblock.py +105 -0
  292. tests/networks/blocks/test_dynunet_block.py +116 -0
  293. tests/networks/blocks/test_fpn_block.py +88 -0
  294. tests/networks/blocks/test_localnet_block.py +121 -0
  295. tests/networks/blocks/test_mlp.py +78 -0
  296. tests/networks/blocks/test_patchembedding.py +212 -0
  297. tests/networks/blocks/test_regunet_block.py +103 -0
  298. tests/networks/blocks/test_se_block.py +85 -0
  299. tests/networks/blocks/test_se_blocks.py +78 -0
  300. tests/networks/blocks/test_segresnet_block.py +57 -0
  301. tests/networks/blocks/test_selfattention.py +232 -0
  302. tests/networks/blocks/test_simple_aspp.py +87 -0
  303. tests/networks/blocks/test_spatialattention.py +55 -0
  304. tests/networks/blocks/test_subpixel_upsample.py +87 -0
  305. tests/networks/blocks/test_text_encoding.py +49 -0
  306. tests/networks/blocks/test_transformerblock.py +90 -0
  307. tests/networks/blocks/test_unetr_block.py +158 -0
  308. tests/networks/blocks/test_upsample_block.py +134 -0
  309. tests/networks/blocks/warp/__init__.py +10 -0
  310. tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
  311. tests/networks/blocks/warp/test_warp.py +250 -0
  312. tests/networks/layers/__init__.py +10 -0
  313. tests/networks/layers/filtering/__init__.py +10 -0
  314. tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
  315. tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
  316. tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
  317. tests/networks/layers/filtering/test_phl_cpu.py +259 -0
  318. tests/networks/layers/filtering/test_phl_cuda.py +167 -0
  319. tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
  320. tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
  321. tests/networks/layers/test_affine_transform.py +385 -0
  322. tests/networks/layers/test_apply_filter.py +89 -0
  323. tests/networks/layers/test_channel_pad.py +51 -0
  324. tests/networks/layers/test_conjugate_gradient.py +56 -0
  325. tests/networks/layers/test_drop_path.py +46 -0
  326. tests/networks/layers/test_gaussian.py +317 -0
  327. tests/networks/layers/test_gaussian_filter.py +206 -0
  328. tests/networks/layers/test_get_layers.py +65 -0
  329. tests/networks/layers/test_gmm.py +314 -0
  330. tests/networks/layers/test_grid_pull.py +93 -0
  331. tests/networks/layers/test_hilbert_transform.py +131 -0
  332. tests/networks/layers/test_lltm.py +62 -0
  333. tests/networks/layers/test_median_filter.py +52 -0
  334. tests/networks/layers/test_polyval.py +55 -0
  335. tests/networks/layers/test_preset_filters.py +136 -0
  336. tests/networks/layers/test_savitzky_golay_filter.py +141 -0
  337. tests/networks/layers/test_separable_filter.py +87 -0
  338. tests/networks/layers/test_skip_connection.py +48 -0
  339. tests/networks/layers/test_vector_quantizer.py +89 -0
  340. tests/networks/layers/test_weight_init.py +50 -0
  341. tests/networks/nets/__init__.py +10 -0
  342. tests/networks/nets/dints/__init__.py +10 -0
  343. tests/networks/nets/dints/test_dints_cell.py +110 -0
  344. tests/networks/nets/dints/test_dints_mixop.py +84 -0
  345. tests/networks/nets/regunet/__init__.py +10 -0
  346. tests/networks/nets/regunet/test_localnet.py +86 -0
  347. tests/networks/nets/regunet/test_regunet.py +88 -0
  348. tests/networks/nets/test_ahnet.py +224 -0
  349. tests/networks/nets/test_attentionunet.py +88 -0
  350. tests/networks/nets/test_autoencoder.py +95 -0
  351. tests/networks/nets/test_autoencoderkl.py +337 -0
  352. tests/networks/nets/test_basic_unet.py +102 -0
  353. tests/networks/nets/test_basic_unetplusplus.py +109 -0
  354. tests/networks/nets/test_bundle_init_bundle.py +55 -0
  355. tests/networks/nets/test_cell_sam_wrapper.py +58 -0
  356. tests/networks/nets/test_controlnet.py +215 -0
  357. tests/networks/nets/test_daf3d.py +62 -0
  358. tests/networks/nets/test_densenet.py +121 -0
  359. tests/networks/nets/test_diffusion_model_unet.py +585 -0
  360. tests/networks/nets/test_dints_network.py +168 -0
  361. tests/networks/nets/test_discriminator.py +59 -0
  362. tests/networks/nets/test_dynunet.py +181 -0
  363. tests/networks/nets/test_efficientnet.py +400 -0
  364. tests/networks/nets/test_flexible_unet.py +341 -0
  365. tests/networks/nets/test_fullyconnectednet.py +69 -0
  366. tests/networks/nets/test_generator.py +59 -0
  367. tests/networks/nets/test_globalnet.py +103 -0
  368. tests/networks/nets/test_highresnet.py +67 -0
  369. tests/networks/nets/test_hovernet.py +218 -0
  370. tests/networks/nets/test_mednext.py +122 -0
  371. tests/networks/nets/test_milmodel.py +92 -0
  372. tests/networks/nets/test_net_adapter.py +68 -0
  373. tests/networks/nets/test_network_consistency.py +86 -0
  374. tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
  375. tests/networks/nets/test_quicknat.py +57 -0
  376. tests/networks/nets/test_resnet.py +340 -0
  377. tests/networks/nets/test_segresnet.py +120 -0
  378. tests/networks/nets/test_segresnet_ds.py +156 -0
  379. tests/networks/nets/test_senet.py +151 -0
  380. tests/networks/nets/test_spade_autoencoderkl.py +295 -0
  381. tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
  382. tests/networks/nets/test_spade_vaegan.py +140 -0
  383. tests/networks/nets/test_swin_unetr.py +139 -0
  384. tests/networks/nets/test_torchvision_fc_model.py +201 -0
  385. tests/networks/nets/test_transchex.py +84 -0
  386. tests/networks/nets/test_transformer.py +108 -0
  387. tests/networks/nets/test_unet.py +208 -0
  388. tests/networks/nets/test_unetr.py +137 -0
  389. tests/networks/nets/test_varautoencoder.py +127 -0
  390. tests/networks/nets/test_vista3d.py +84 -0
  391. tests/networks/nets/test_vit.py +139 -0
  392. tests/networks/nets/test_vitautoenc.py +112 -0
  393. tests/networks/nets/test_vnet.py +81 -0
  394. tests/networks/nets/test_voxelmorph.py +280 -0
  395. tests/networks/nets/test_vqvae.py +274 -0
  396. tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
  397. tests/networks/schedulers/__init__.py +10 -0
  398. tests/networks/schedulers/test_scheduler_ddim.py +83 -0
  399. tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
  400. tests/networks/schedulers/test_scheduler_pndm.py +108 -0
  401. tests/networks/test_bundle_onnx_export.py +71 -0
  402. tests/networks/test_convert_to_onnx.py +106 -0
  403. tests/networks/test_convert_to_torchscript.py +46 -0
  404. tests/networks/test_convert_to_trt.py +79 -0
  405. tests/networks/test_save_state.py +73 -0
  406. tests/networks/test_to_onehot.py +63 -0
  407. tests/networks/test_varnet.py +63 -0
  408. tests/networks/utils/__init__.py +10 -0
  409. tests/networks/utils/test_copy_model_state.py +187 -0
  410. tests/networks/utils/test_eval_mode.py +34 -0
  411. tests/networks/utils/test_freeze_layers.py +61 -0
  412. tests/networks/utils/test_replace_module.py +98 -0
  413. tests/networks/utils/test_train_mode.py +34 -0
  414. tests/optimizers/__init__.py +10 -0
  415. tests/optimizers/test_generate_param_groups.py +105 -0
  416. tests/optimizers/test_lr_finder.py +108 -0
  417. tests/optimizers/test_lr_scheduler.py +71 -0
  418. tests/optimizers/test_optim_novograd.py +100 -0
  419. tests/profile_subclass/__init__.py +10 -0
  420. tests/profile_subclass/cprofile_profiling.py +29 -0
  421. tests/profile_subclass/min_classes.py +30 -0
  422. tests/profile_subclass/profiling.py +73 -0
  423. tests/profile_subclass/pyspy_profiling.py +41 -0
  424. tests/transforms/__init__.py +10 -0
  425. tests/transforms/compose/__init__.py +10 -0
  426. tests/transforms/compose/test_compose.py +758 -0
  427. tests/transforms/compose/test_some_of.py +258 -0
  428. tests/transforms/croppad/__init__.py +10 -0
  429. tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
  430. tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
  431. tests/transforms/functional/__init__.py +10 -0
  432. tests/transforms/functional/test_apply.py +75 -0
  433. tests/transforms/functional/test_resample.py +50 -0
  434. tests/transforms/intensity/__init__.py +10 -0
  435. tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
  436. tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
  437. tests/transforms/intensity/test_foreground_mask.py +98 -0
  438. tests/transforms/intensity/test_foreground_maskd.py +106 -0
  439. tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
  440. tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
  441. tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
  442. tests/transforms/inverse/__init__.py +10 -0
  443. tests/transforms/inverse/test_inverse_array.py +76 -0
  444. tests/transforms/inverse/test_traceable_transform.py +59 -0
  445. tests/transforms/post/__init__.py +10 -0
  446. tests/transforms/post/test_label_filterd.py +78 -0
  447. tests/transforms/post/test_probnms.py +72 -0
  448. tests/transforms/post/test_probnmsd.py +79 -0
  449. tests/transforms/post/test_remove_small_objects.py +102 -0
  450. tests/transforms/spatial/__init__.py +10 -0
  451. tests/transforms/spatial/test_convert_box_points.py +119 -0
  452. tests/transforms/spatial/test_grid_patch.py +134 -0
  453. tests/transforms/spatial/test_grid_patchd.py +102 -0
  454. tests/transforms/spatial/test_rand_grid_patch.py +150 -0
  455. tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
  456. tests/transforms/spatial/test_spatial_resampled.py +124 -0
  457. tests/transforms/test_activations.py +120 -0
  458. tests/transforms/test_activationsd.py +64 -0
  459. tests/transforms/test_adaptors.py +160 -0
  460. tests/transforms/test_add_coordinate_channels.py +53 -0
  461. tests/transforms/test_add_coordinate_channelsd.py +67 -0
  462. tests/transforms/test_add_extreme_points_channel.py +80 -0
  463. tests/transforms/test_add_extreme_points_channeld.py +77 -0
  464. tests/transforms/test_adjust_contrast.py +70 -0
  465. tests/transforms/test_adjust_contrastd.py +64 -0
  466. tests/transforms/test_affine.py +245 -0
  467. tests/transforms/test_affine_grid.py +152 -0
  468. tests/transforms/test_affined.py +190 -0
  469. tests/transforms/test_as_channel_last.py +38 -0
  470. tests/transforms/test_as_channel_lastd.py +44 -0
  471. tests/transforms/test_as_discrete.py +81 -0
  472. tests/transforms/test_as_discreted.py +82 -0
  473. tests/transforms/test_border_pad.py +49 -0
  474. tests/transforms/test_border_padd.py +45 -0
  475. tests/transforms/test_bounding_rect.py +54 -0
  476. tests/transforms/test_bounding_rectd.py +53 -0
  477. tests/transforms/test_cast_to_type.py +63 -0
  478. tests/transforms/test_cast_to_typed.py +74 -0
  479. tests/transforms/test_center_scale_crop.py +55 -0
  480. tests/transforms/test_center_scale_cropd.py +56 -0
  481. tests/transforms/test_center_spatial_crop.py +56 -0
  482. tests/transforms/test_center_spatial_cropd.py +63 -0
  483. tests/transforms/test_classes_to_indices.py +93 -0
  484. tests/transforms/test_classes_to_indicesd.py +110 -0
  485. tests/transforms/test_clip_intensity_percentiles.py +196 -0
  486. tests/transforms/test_clip_intensity_percentilesd.py +193 -0
  487. tests/transforms/test_compose_get_number_conversions.py +127 -0
  488. tests/transforms/test_concat_itemsd.py +82 -0
  489. tests/transforms/test_convert_to_multi_channel.py +59 -0
  490. tests/transforms/test_convert_to_multi_channeld.py +37 -0
  491. tests/transforms/test_copy_itemsd.py +86 -0
  492. tests/transforms/test_create_grid_and_affine.py +274 -0
  493. tests/transforms/test_crop_foreground.py +164 -0
  494. tests/transforms/test_crop_foregroundd.py +205 -0
  495. tests/transforms/test_cucim_dict_transform.py +142 -0
  496. tests/transforms/test_cucim_transform.py +141 -0
  497. tests/transforms/test_data_stats.py +221 -0
  498. tests/transforms/test_data_statsd.py +249 -0
  499. tests/transforms/test_delete_itemsd.py +58 -0
  500. tests/transforms/test_detect_envelope.py +159 -0
  501. tests/transforms/test_distance_transform_edt.py +202 -0
  502. tests/transforms/test_divisible_pad.py +49 -0
  503. tests/transforms/test_divisible_padd.py +42 -0
  504. tests/transforms/test_ensure_channel_first.py +113 -0
  505. tests/transforms/test_ensure_channel_firstd.py +85 -0
  506. tests/transforms/test_ensure_type.py +94 -0
  507. tests/transforms/test_ensure_typed.py +110 -0
  508. tests/transforms/test_fg_bg_to_indices.py +83 -0
  509. tests/transforms/test_fg_bg_to_indicesd.py +78 -0
  510. tests/transforms/test_fill_holes.py +207 -0
  511. tests/transforms/test_fill_holesd.py +209 -0
  512. tests/transforms/test_flatten_sub_keysd.py +64 -0
  513. tests/transforms/test_flip.py +83 -0
  514. tests/transforms/test_flipd.py +90 -0
  515. tests/transforms/test_fourier.py +70 -0
  516. tests/transforms/test_gaussian_sharpen.py +92 -0
  517. tests/transforms/test_gaussian_sharpend.py +92 -0
  518. tests/transforms/test_gaussian_smooth.py +96 -0
  519. tests/transforms/test_gaussian_smoothd.py +96 -0
  520. tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
  521. tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
  522. tests/transforms/test_generate_spatial_bounding_box.py +114 -0
  523. tests/transforms/test_get_extreme_points.py +57 -0
  524. tests/transforms/test_gibbs_noise.py +73 -0
  525. tests/transforms/test_gibbs_noised.py +88 -0
  526. tests/transforms/test_grid_distortion.py +113 -0
  527. tests/transforms/test_grid_distortiond.py +87 -0
  528. tests/transforms/test_grid_split.py +88 -0
  529. tests/transforms/test_grid_splitd.py +96 -0
  530. tests/transforms/test_histogram_normalize.py +59 -0
  531. tests/transforms/test_histogram_normalized.py +59 -0
  532. tests/transforms/test_image_filter.py +259 -0
  533. tests/transforms/test_intensity_stats.py +73 -0
  534. tests/transforms/test_intensity_statsd.py +90 -0
  535. tests/transforms/test_inverse.py +521 -0
  536. tests/transforms/test_inverse_collation.py +147 -0
  537. tests/transforms/test_invert.py +105 -0
  538. tests/transforms/test_invertd.py +142 -0
  539. tests/transforms/test_k_space_spike_noise.py +81 -0
  540. tests/transforms/test_k_space_spike_noised.py +98 -0
  541. tests/transforms/test_keep_largest_connected_component.py +419 -0
  542. tests/transforms/test_keep_largest_connected_componentd.py +348 -0
  543. tests/transforms/test_label_filter.py +78 -0
  544. tests/transforms/test_label_to_contour.py +179 -0
  545. tests/transforms/test_label_to_contourd.py +182 -0
  546. tests/transforms/test_label_to_mask.py +69 -0
  547. tests/transforms/test_label_to_maskd.py +70 -0
  548. tests/transforms/test_load_image.py +502 -0
  549. tests/transforms/test_load_imaged.py +198 -0
  550. tests/transforms/test_load_spacing_orientation.py +149 -0
  551. tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
  552. tests/transforms/test_map_binary_to_indices.py +75 -0
  553. tests/transforms/test_map_classes_to_indices.py +135 -0
  554. tests/transforms/test_map_label_value.py +89 -0
  555. tests/transforms/test_map_label_valued.py +85 -0
  556. tests/transforms/test_map_transform.py +45 -0
  557. tests/transforms/test_mask_intensity.py +74 -0
  558. tests/transforms/test_mask_intensityd.py +68 -0
  559. tests/transforms/test_mean_ensemble.py +77 -0
  560. tests/transforms/test_mean_ensembled.py +91 -0
  561. tests/transforms/test_median_smooth.py +41 -0
  562. tests/transforms/test_median_smoothd.py +65 -0
  563. tests/transforms/test_morphological_ops.py +101 -0
  564. tests/transforms/test_nifti_endianness.py +107 -0
  565. tests/transforms/test_normalize_intensity.py +143 -0
  566. tests/transforms/test_normalize_intensityd.py +81 -0
  567. tests/transforms/test_nvtx_decorator.py +289 -0
  568. tests/transforms/test_nvtx_transform.py +143 -0
  569. tests/transforms/test_orientation.py +247 -0
  570. tests/transforms/test_orientationd.py +112 -0
  571. tests/transforms/test_rand_adjust_contrast.py +45 -0
  572. tests/transforms/test_rand_adjust_contrastd.py +44 -0
  573. tests/transforms/test_rand_affine.py +201 -0
  574. tests/transforms/test_rand_affine_grid.py +212 -0
  575. tests/transforms/test_rand_affined.py +281 -0
  576. tests/transforms/test_rand_axis_flip.py +50 -0
  577. tests/transforms/test_rand_axis_flipd.py +50 -0
  578. tests/transforms/test_rand_bias_field.py +69 -0
  579. tests/transforms/test_rand_bias_fieldd.py +65 -0
  580. tests/transforms/test_rand_coarse_dropout.py +110 -0
  581. tests/transforms/test_rand_coarse_dropoutd.py +107 -0
  582. tests/transforms/test_rand_coarse_shuffle.py +65 -0
  583. tests/transforms/test_rand_coarse_shuffled.py +59 -0
  584. tests/transforms/test_rand_crop_by_label_classes.py +170 -0
  585. tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
  586. tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
  587. tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
  588. tests/transforms/test_rand_cucim_dict_transform.py +162 -0
  589. tests/transforms/test_rand_cucim_transform.py +162 -0
  590. tests/transforms/test_rand_deform_grid.py +138 -0
  591. tests/transforms/test_rand_elastic_2d.py +127 -0
  592. tests/transforms/test_rand_elastic_3d.py +104 -0
  593. tests/transforms/test_rand_elasticd_2d.py +177 -0
  594. tests/transforms/test_rand_elasticd_3d.py +156 -0
  595. tests/transforms/test_rand_flip.py +60 -0
  596. tests/transforms/test_rand_flipd.py +55 -0
  597. tests/transforms/test_rand_gaussian_noise.py +48 -0
  598. tests/transforms/test_rand_gaussian_noised.py +54 -0
  599. tests/transforms/test_rand_gaussian_sharpen.py +140 -0
  600. tests/transforms/test_rand_gaussian_sharpend.py +143 -0
  601. tests/transforms/test_rand_gaussian_smooth.py +98 -0
  602. tests/transforms/test_rand_gaussian_smoothd.py +98 -0
  603. tests/transforms/test_rand_gibbs_noise.py +103 -0
  604. tests/transforms/test_rand_gibbs_noised.py +117 -0
  605. tests/transforms/test_rand_grid_distortion.py +99 -0
  606. tests/transforms/test_rand_grid_distortiond.py +90 -0
  607. tests/transforms/test_rand_histogram_shift.py +92 -0
  608. tests/transforms/test_rand_k_space_spike_noise.py +92 -0
  609. tests/transforms/test_rand_k_space_spike_noised.py +76 -0
  610. tests/transforms/test_rand_rician_noise.py +52 -0
  611. tests/transforms/test_rand_rician_noised.py +52 -0
  612. tests/transforms/test_rand_rotate.py +166 -0
  613. tests/transforms/test_rand_rotate90.py +100 -0
  614. tests/transforms/test_rand_rotate90d.py +112 -0
  615. tests/transforms/test_rand_rotated.py +187 -0
  616. tests/transforms/test_rand_scale_crop.py +78 -0
  617. tests/transforms/test_rand_scale_cropd.py +98 -0
  618. tests/transforms/test_rand_scale_intensity.py +54 -0
  619. tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
  620. tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
  621. tests/transforms/test_rand_scale_intensityd.py +53 -0
  622. tests/transforms/test_rand_shift_intensity.py +52 -0
  623. tests/transforms/test_rand_shift_intensityd.py +67 -0
  624. tests/transforms/test_rand_simulate_low_resolution.py +83 -0
  625. tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
  626. tests/transforms/test_rand_spatial_crop.py +107 -0
  627. tests/transforms/test_rand_spatial_crop_samples.py +128 -0
  628. tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
  629. tests/transforms/test_rand_spatial_cropd.py +112 -0
  630. tests/transforms/test_rand_std_shift_intensity.py +43 -0
  631. tests/transforms/test_rand_std_shift_intensityd.py +38 -0
  632. tests/transforms/test_rand_zoom.py +105 -0
  633. tests/transforms/test_rand_zoomd.py +108 -0
  634. tests/transforms/test_randidentity.py +49 -0
  635. tests/transforms/test_random_order.py +144 -0
  636. tests/transforms/test_randtorchvisiond.py +65 -0
  637. tests/transforms/test_regularization.py +139 -0
  638. tests/transforms/test_remove_repeated_channel.py +34 -0
  639. tests/transforms/test_remove_repeated_channeld.py +44 -0
  640. tests/transforms/test_repeat_channel.py +34 -0
  641. tests/transforms/test_repeat_channeld.py +41 -0
  642. tests/transforms/test_resample_backends.py +65 -0
  643. tests/transforms/test_resample_to_match.py +110 -0
  644. tests/transforms/test_resample_to_matchd.py +93 -0
  645. tests/transforms/test_resampler.py +165 -0
  646. tests/transforms/test_resize.py +140 -0
  647. tests/transforms/test_resize_with_pad_or_crop.py +91 -0
  648. tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
  649. tests/transforms/test_resized.py +163 -0
  650. tests/transforms/test_rotate.py +160 -0
  651. tests/transforms/test_rotate90.py +212 -0
  652. tests/transforms/test_rotate90d.py +106 -0
  653. tests/transforms/test_rotated.py +179 -0
  654. tests/transforms/test_save_classificationd.py +109 -0
  655. tests/transforms/test_save_image.py +80 -0
  656. tests/transforms/test_save_imaged.py +130 -0
  657. tests/transforms/test_savitzky_golay_smooth.py +73 -0
  658. tests/transforms/test_savitzky_golay_smoothd.py +73 -0
  659. tests/transforms/test_scale_intensity.py +76 -0
  660. tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
  661. tests/transforms/test_scale_intensity_range.py +41 -0
  662. tests/transforms/test_scale_intensity_ranged.py +40 -0
  663. tests/transforms/test_scale_intensityd.py +57 -0
  664. tests/transforms/test_select_itemsd.py +41 -0
  665. tests/transforms/test_shift_intensity.py +31 -0
  666. tests/transforms/test_shift_intensityd.py +44 -0
  667. tests/transforms/test_signal_continuouswavelet.py +44 -0
  668. tests/transforms/test_signal_fillempty.py +52 -0
  669. tests/transforms/test_signal_fillemptyd.py +60 -0
  670. tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
  671. tests/transforms/test_signal_rand_add_sine.py +52 -0
  672. tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
  673. tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
  674. tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
  675. tests/transforms/test_signal_rand_drop.py +50 -0
  676. tests/transforms/test_signal_rand_scale.py +52 -0
  677. tests/transforms/test_signal_rand_shift.py +55 -0
  678. tests/transforms/test_signal_remove_frequency.py +71 -0
  679. tests/transforms/test_smooth_field.py +177 -0
  680. tests/transforms/test_sobel_gradient.py +189 -0
  681. tests/transforms/test_sobel_gradientd.py +212 -0
  682. tests/transforms/test_spacing.py +381 -0
  683. tests/transforms/test_spacingd.py +178 -0
  684. tests/transforms/test_spatial_crop.py +82 -0
  685. tests/transforms/test_spatial_cropd.py +74 -0
  686. tests/transforms/test_spatial_pad.py +57 -0
  687. tests/transforms/test_spatial_padd.py +43 -0
  688. tests/transforms/test_spatial_resample.py +235 -0
  689. tests/transforms/test_squeezedim.py +62 -0
  690. tests/transforms/test_squeezedimd.py +98 -0
  691. tests/transforms/test_std_shift_intensity.py +76 -0
  692. tests/transforms/test_std_shift_intensityd.py +74 -0
  693. tests/transforms/test_threshold_intensity.py +38 -0
  694. tests/transforms/test_threshold_intensityd.py +58 -0
  695. tests/transforms/test_to_contiguous.py +47 -0
  696. tests/transforms/test_to_cupy.py +112 -0
  697. tests/transforms/test_to_cupyd.py +76 -0
  698. tests/transforms/test_to_device.py +42 -0
  699. tests/transforms/test_to_deviced.py +37 -0
  700. tests/transforms/test_to_numpy.py +85 -0
  701. tests/transforms/test_to_numpyd.py +68 -0
  702. tests/transforms/test_to_pil.py +52 -0
  703. tests/transforms/test_to_pild.py +50 -0
  704. tests/transforms/test_to_tensor.py +60 -0
  705. tests/transforms/test_to_tensord.py +71 -0
  706. tests/transforms/test_torchvision.py +66 -0
  707. tests/transforms/test_torchvisiond.py +63 -0
  708. tests/transforms/test_transform.py +62 -0
  709. tests/transforms/test_transpose.py +41 -0
  710. tests/transforms/test_transposed.py +52 -0
  711. tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
  712. tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
  713. tests/transforms/test_vote_ensemble.py +84 -0
  714. tests/transforms/test_vote_ensembled.py +107 -0
  715. tests/transforms/test_with_allow_missing_keys.py +76 -0
  716. tests/transforms/test_zoom.py +120 -0
  717. tests/transforms/test_zoomd.py +94 -0
  718. tests/transforms/transform/__init__.py +10 -0
  719. tests/transforms/transform/test_randomizable.py +52 -0
  720. tests/transforms/transform/test_randomizable_transform_type.py +37 -0
  721. tests/transforms/utility/__init__.py +10 -0
  722. tests/transforms/utility/test_apply_transform_to_points.py +81 -0
  723. tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
  724. tests/transforms/utility/test_identity.py +29 -0
  725. tests/transforms/utility/test_identityd.py +30 -0
  726. tests/transforms/utility/test_lambda.py +71 -0
  727. tests/transforms/utility/test_lambdad.py +83 -0
  728. tests/transforms/utility/test_rand_lambda.py +87 -0
  729. tests/transforms/utility/test_rand_lambdad.py +77 -0
  730. tests/transforms/utility/test_simulatedelay.py +36 -0
  731. tests/transforms/utility/test_simulatedelayd.py +36 -0
  732. tests/transforms/utility/test_splitdim.py +52 -0
  733. tests/transforms/utility/test_splitdimd.py +96 -0
  734. tests/transforms/utils/__init__.py +10 -0
  735. tests/transforms/utils/test_correct_crop_centers.py +36 -0
  736. tests/transforms/utils/test_get_unique_labels.py +45 -0
  737. tests/transforms/utils/test_print_transform_backends.py +29 -0
  738. tests/transforms/utils/test_soft_clip.py +125 -0
  739. tests/utils/__init__.py +10 -0
  740. tests/utils/enums/__init__.py +10 -0
  741. tests/utils/enums/test_hovernet_loss.py +190 -0
  742. tests/utils/enums/test_ordering.py +289 -0
  743. tests/utils/enums/test_wsireader.py +663 -0
  744. tests/utils/misc/__init__.py +10 -0
  745. tests/utils/misc/test_ensure_tuple.py +53 -0
  746. tests/utils/misc/test_monai_env_vars.py +44 -0
  747. tests/utils/misc/test_monai_utils_misc.py +103 -0
  748. tests/utils/misc/test_str2bool.py +34 -0
  749. tests/utils/misc/test_str2list.py +33 -0
  750. tests/utils/test_alias.py +44 -0
  751. tests/utils/test_component_store.py +73 -0
  752. tests/utils/test_deprecated.py +455 -0
  753. tests/utils/test_enum_bound_interp.py +75 -0
  754. tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
  755. tests/utils/test_get_package_version.py +34 -0
  756. tests/utils/test_handler_logfile.py +84 -0
  757. tests/utils/test_handler_metric_logger.py +62 -0
  758. tests/utils/test_list_to_dict.py +43 -0
  759. tests/utils/test_look_up_option.py +87 -0
  760. tests/utils/test_optional_import.py +80 -0
  761. tests/utils/test_pad_mode.py +39 -0
  762. tests/utils/test_profiling.py +208 -0
  763. tests/utils/test_rankfilter_dist.py +77 -0
  764. tests/utils/test_require_pkg.py +83 -0
  765. tests/utils/test_sample_slices.py +43 -0
  766. tests/utils/test_set_determinism.py +74 -0
  767. tests/utils/test_squeeze_unsqueeze.py +71 -0
  768. tests/utils/test_state_cacher.py +67 -0
  769. tests/utils/test_torchscript_utils.py +113 -0
  770. tests/utils/test_version.py +91 -0
  771. tests/utils/test_version_after.py +65 -0
  772. tests/utils/type_conversion/__init__.py +10 -0
  773. tests/utils/type_conversion/test_convert_data_type.py +152 -0
  774. tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
  775. tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
  776. tests/visualize/__init__.py +10 -0
  777. tests/visualize/test_img2tensorboard.py +46 -0
  778. tests/visualize/test_occlusion_sensitivity.py +128 -0
  779. tests/visualize/test_plot_2d_or_3d_image.py +74 -0
  780. tests/visualize/test_vis_cam.py +98 -0
  781. tests/visualize/test_vis_gradcam.py +211 -0
  782. tests/visualize/utils/__init__.py +10 -0
  783. tests/visualize/utils/test_blend_images.py +63 -0
  784. tests/visualize/utils/test_matshow3d.py +133 -0
  785. monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
  786. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/LICENSE +0 -0
  787. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/WHEEL +0 -0
@@ -0,0 +1,296 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import glob
15
+ import os
16
+ import shutil
17
+ import tempfile
18
+ import unittest
19
+ from concurrent.futures import ThreadPoolExecutor
20
+ from unittest.mock import MagicMock
21
+
22
+ import numpy as np
23
+ from ignite.engine import Engine, Events
24
+ from parameterized import parameterized
25
+
26
+ from monai.apps import download_and_extract
27
+ from monai.bundle import ConfigWorkflow, download
28
+ from monai.handlers import MLFlowHandler
29
+ from monai.utils import optional_import, path_to_uri
30
+ from tests.test_utils import skip_if_downloading_fails, skip_if_quick
31
+
32
# ``optional_import`` returns a pair; only the second item (an availability flag) is kept. The flag
# gates ``test_dataset_tracking``, whose skip reason states that mlflow >= 2.4.0 is required.
_, has_dataset_tracking = optional_import("mlflow", "2.4.0")
33
+
34
+
35
def get_event_filter(e):
    """Build an event-filter callable that accepts only the listed event numbers.

    The returned callable is passed as ``epoch_log`` / ``iteration_log`` to ``MLFlowHandler`` in the
    tests of this module.

    Args:
        e: collection of event numbers (e.g. epoch or iteration indices) that should pass the filter.

    Returns:
        A ``(engine, event) -> bool`` callable; the engine argument is ignored.
    """

    def event_filter(_, event):
        # ``in`` already yields a bool, so no explicit True/False branches are needed
        return event in e

    return event_filter
42
+
43
+
44
def dummy_train(tracking_folder):
    """Run a tiny two-epoch ignite job tracked by an ``MLFlowHandler`` and return its tracking path.

    Every call creates a fresh directory with ``tempfile.mkdtemp()``. The returned path is
    ``<that directory>/<tracking_folder>``. Removing the created directory is left to the caller.
    """
    base_dir = tempfile.mkdtemp()
    tracking_path = os.path.join(base_dir, tracking_folder)

    def _step(engine, batch):
        # trivial process function: only the engine events matter for the handler under test
        return [batch + 1.0]

    trainer = Engine(_step)

    tracker = MLFlowHandler(
        iteration_log=False,
        epoch_log=True,
        tracking_uri=path_to_uri(tracking_path),
        state_attributes=["test"],
        close_on_complete=True,
    )
    tracker.attach(trainer)
    trainer.run(range(3), max_epochs=2)
    return tracking_path
65
+
66
+
67
class TestHandlerMLFlow(unittest.TestCase):
    """Tests for ``MLFlowHandler``: runs, metric logging, event filters, threading and dataset tracking."""

    def setUp(self):
        # directories created outside a ``TemporaryDirectory`` context; removed in ``tearDown``
        self.tmpdir_list = []

    def tearDown(self):
        for tmpdir in self.tmpdir_list:
            if tmpdir and os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)

    def test_multi_run(self):
        """Each new engine/handler pair should add one run to the same experiment."""
        with tempfile.TemporaryDirectory() as tempdir:
            # set up the train function for engine
            def _train_func(engine, batch):
                return [batch + 1.0]

            # create and run an engine several times to get several runs
            create_engine_times = 3
            for _ in range(create_engine_times):
                engine = Engine(_train_func)

                @engine.on(Events.EPOCH_COMPLETED)
                def _update_metric(engine):
                    current_metric = engine.state.metrics.get("acc", 0.1)
                    engine.state.metrics["acc"] = current_metric + 0.1
                    engine.state.test = current_metric

                # set up testing handler
                test_path = os.path.join(tempdir, "mlflow_test")
                handler = MLFlowHandler(
                    iteration_log=False,
                    epoch_log=True,
                    tracking_uri=path_to_uri(test_path),
                    state_attributes=["test"],
                    close_on_complete=True,
                )
                handler.attach(engine)
                engine.run(range(3), max_epochs=2)
                run_cnt = len(handler.client.search_runs(handler.experiment.experiment_id))
                handler.close()
            # the run count should equal to the times of creating engine
            self.assertEqual(create_engine_times, run_cnt)

    def test_metrics_track(self):
        """Scalar and nested-dict metrics are logged; nested keys show up as individual run metrics."""
        experiment_param = {"backbone": "efficientnet_b0"}
        with tempfile.TemporaryDirectory() as tempdir:
            # set up engine
            def _train_func(engine, batch):
                return [batch + 1.0]

            engine = Engine(_train_func)

            # set up dummy metric
            @engine.on(Events.EPOCH_COMPLETED)
            def _update_metric(engine):
                current_metric = engine.state.metrics.get("acc", 0.1)
                engine.state.metrics["acc"] = current_metric + 0.1
                # log nested metrics
                engine.state.metrics["acc_per_label"] = {
                    "label_0": current_metric + 0.1,
                    "label_1": current_metric + 0.2,
                }
                engine.state.test = current_metric

            # set up testing handler
            test_path = os.path.join(tempdir, "mlflow_test")
            artifact_path = os.path.join(tempdir, "artifacts")
            os.makedirs(artifact_path, exist_ok=True)
            dummy_numpy = np.zeros((64, 64, 3))
            dummy_path = os.path.join(artifact_path, "tmp.npy")
            np.save(dummy_path, dummy_numpy)
            handler = MLFlowHandler(
                iteration_log=False,
                epoch_log=True,
                tracking_uri=path_to_uri(test_path),
                state_attributes=["test"],
                experiment_param=experiment_param,
                artifacts=[artifact_path],
                close_on_complete=False,
            )
            handler.attach(engine)
            engine.run(range(3), max_epochs=2)
            # ``close_on_complete=False`` presumably keeps the run open so it can be queried here.
            # Close it even if the assertion fails, before the temporary directory is removed.
            try:
                cur_run = handler.client.get_run(handler.cur_run.info.run_id)
                self.assertIn("label_0", cur_run.data.metrics.keys())
            finally:
                handler.close()
            # check logging output
            self.assertGreater(len(glob.glob(test_path)), 0)

    @parameterized.expand([[True], [get_event_filter([1, 2])]])
    def test_metrics_track_mock(self, epoch_log):
        """``epoch_log`` as ``True`` logs every epoch; as an event filter it logs only accepted epochs."""
        experiment_param = {"backbone": "efficientnet_b0"}
        with tempfile.TemporaryDirectory() as tempdir:
            # set up engine
            def _train_func(engine, batch):
                return [batch + 1.0]

            engine = Engine(_train_func)

            # set up dummy metric
            @engine.on(Events.EPOCH_COMPLETED)
            def _update_metric(engine):
                current_metric = engine.state.metrics.get("acc", 0.1)
                engine.state.metrics["acc"] = current_metric + 0.1
                engine.state.test = current_metric

            # set up testing handler
            test_path = os.path.join(tempdir, "mlflow_test")
            handler = MLFlowHandler(
                iteration_log=False,
                epoch_log=epoch_log,
                tracking_uri=path_to_uri(test_path),
                state_attributes=["test"],
                experiment_param=experiment_param,
                close_on_complete=True,
            )
            # replace the logging callback with a mock so its calls can be counted
            handler._default_epoch_log = MagicMock()
            handler.attach(engine)

            max_epochs = 4
            engine.run(range(3), max_epochs=max_epochs)
            handler.close()
            # check logging output
            if epoch_log is True:
                self.assertEqual(handler._default_epoch_log.call_count, max_epochs)
            else:
                self.assertEqual(handler._default_epoch_log.call_count, 2)  # 2 = len([1, 2]) from event_filter

    @parameterized.expand([[True], [get_event_filter([1, 3])]])
    def test_metrics_track_iters_mock(self, iteration_log):
        """``iteration_log`` as ``True`` logs every iteration; as a filter it logs only accepted ones."""
        experiment_param = {"backbone": "efficientnet_b0"}
        with tempfile.TemporaryDirectory() as tempdir:
            # set up engine
            def _train_func(engine, batch):
                return [batch + 1.0]

            engine = Engine(_train_func)

            # set up dummy metric
            @engine.on(Events.EPOCH_COMPLETED)
            def _update_metric(engine):
                current_metric = engine.state.metrics.get("acc", 0.1)
                engine.state.metrics["acc"] = current_metric + 0.1
                engine.state.test = current_metric

            # set up testing handler
            test_path = os.path.join(tempdir, "mlflow_test")
            handler = MLFlowHandler(
                iteration_log=iteration_log,
                epoch_log=False,
                tracking_uri=path_to_uri(test_path),
                state_attributes=["test"],
                experiment_param=experiment_param,
                close_on_complete=True,
            )
            # replace the logging callback with a mock so its calls can be counted
            handler._default_iteration_log = MagicMock()
            handler.attach(engine)

            num_iters = 3
            max_epochs = 2
            engine.run(range(num_iters), max_epochs=max_epochs)
            handler.close()
            # check logging output
            if iteration_log is True:
                self.assertEqual(handler._default_iteration_log.call_count, num_iters * max_epochs)
            else:
                self.assertEqual(handler._default_iteration_log.call_count, 2)  # 2 = len([1, 3]) from event_filter

    def test_multi_thread(self):
        """Two handlers running in parallel threads each write to their own tracking folder."""
        test_uri_list = ["monai_mlflow_test1", "monai_mlflow_test2"]
        with ThreadPoolExecutor(2, "Training") as executor:
            futures = {t: executor.submit(dummy_train, t) for t in test_uri_list}

            for future in futures.values():
                res = future.result()
                # ``dummy_train`` returns ``<mkdtemp dir>/<tracking_folder>``. Register the parent
                # mkdtemp directory so the temporary directory itself is removed, not just its child.
                self.tmpdir_list.append(os.path.dirname(res))
                self.assertGreater(len(glob.glob(res)), 0)

    @skip_if_quick
    @unittest.skipUnless(has_dataset_tracking, reason="Requires mlflow version >= 2.4.0.")
    def test_dataset_tracking(self):
        """A dataset passed through ``dataset_dict`` is recorded as a run input of the evaluation run."""
        test_bundle_name = "endoscopic_tool_segmentation"
        with tempfile.TemporaryDirectory() as tempdir:
            resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/endoscopic_tool_dataset.zip"
            md5 = "f82da47259c0a617202fb54624798a55"
            compressed_file = os.path.join(tempdir, "endoscopic_tool_segmentation.zip")
            data_dir = os.path.join(tempdir, "endoscopic_tool_dataset")
            with skip_if_downloading_fails():
                if not os.path.exists(data_dir):
                    download_and_extract(resource, compressed_file, tempdir, md5)

                download(test_bundle_name, bundle_dir=tempdir)

            bundle_root = os.path.join(tempdir, test_bundle_name)
            config_file = os.path.join(bundle_root, "configs/inference.json")
            meta_file = os.path.join(bundle_root, "configs/metadata.json")
            logging_file = os.path.join(bundle_root, "configs/logging.conf")
            workflow = ConfigWorkflow(
                workflow_type="infer",
                config_file=config_file,
                meta_file=meta_file,
                logging_file=logging_file,
                init_id="initialize",
                run_id="run",
                final_id="finalize",
            )

            tracking_path = os.path.join(bundle_root, "eval")
            workflow.bundle_root = bundle_root
            workflow.dataset_dir = data_dir
            workflow.initialize()
            infer_dataset = workflow.dataset
            mlflow_handler = MLFlowHandler(
                iteration_log=False,
                epoch_log=False,
                dataset_dict={"test": infer_dataset},
                tracking_uri=path_to_uri(tracking_path),
            )
            mlflow_handler.attach(workflow.evaluator)
            workflow.run()
            workflow.finalize()

            # always close the handler, even if the assertion fails, before the temp dir is removed
            try:
                cur_run = mlflow_handler.client.get_run(mlflow_handler.cur_run.info.run_id)
                logged_nontrain_set = [x for x in cur_run.inputs.dataset_inputs if x.dataset.name.startswith("test")]
                self.assertEqual(len(logged_nontrain_set), 1)
            finally:
                mlflow_handler.close()
293
+
294
+
295
if __name__ == "__main__":
    # allow running this test module directly as a script
    unittest.main()
@@ -0,0 +1,93 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import torch
17
+ from ignite.engine import Events
18
+ from parameterized import parameterized
19
+
20
+ from monai.engines import SupervisedEvaluator
21
+ from monai.handlers import StatsHandler, from_engine
22
+ from monai.handlers.nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler
23
+ from monai.utils import optional_import
24
+ from tests.test_utils import assert_allclose
25
+
26
# Only the availability flag from ``optional_import`` is kept; it gates ``test_compute`` below,
# which is skipped when the NVTX bindings cannot be imported.
_, has_nvtx = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?")
27
+
28
# Each TEST_CASE is [val_data_loader (a list holding one {"image": tensor} batch), expected prediction].
# The evaluator in this module applies ``torch.nn.PReLU()`` and then adds 1.0 in post-processing.
TENSOR_0 = torch.tensor([1.0, 2.0, 3.0, 4.0]).reshape(1, 2, 2, 1)

TENSOR_1 = torch.tensor([0.0, -2.0, -3.0, 4.0]).reshape(1, 2, 2, 1)

# Negative inputs are scaled by 0.25 (-2 -> -0.5, -3 -> -0.75) before the +1.0; this is consistent
# with PReLU's default negative slope.
TENSOR_1_EXPECTED = torch.tensor([1.0, 0.5, 0.25, 5.0]).reshape(2, 2, 1)

TEST_CASE_0 = [[{"image": TENSOR_0}], TENSOR_0[0] + 1.0]
TEST_CASE_1 = [[{"image": TENSOR_1}], TENSOR_1_EXPECTED]
36
+
37
+
38
class TestHandlerDecollateBatch(unittest.TestCase):
    """Smoke test for the NVTX handlers (``MarkHandler``, ``RangeHandler``, ``RangePushHandler``, ``RangePopHandler``).

    NOTE(review): the class name looks copied from the decollate-batch handler test; this module
    exercises ``monai.handlers.nvtx_handlers``. Consider renaming.
    """

    @parameterized.expand([TEST_CASE_0, TEST_CASE_1])
    @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!")
    def test_compute(self, data, expected):
        """Run a one-iteration evaluator with many NVTX handlers attached and check the prediction.

        Args:
            data: used as ``val_data_loader``; each case supplies a list with one ``{"image": tensor}`` batch.
            expected: value that ``engine.state.output[0]["pred"]`` must match.
        """
        # Set up handlers
        # NOTE: the order of push/pop handlers in this list is part of what is being exercised.
        handlers = [
            # Mark with Ignite Event
            MarkHandler(Events.STARTED),
            # Mark with literal
            MarkHandler("EPOCH_STARTED"),
            # Mark with literal and providing the message
            MarkHandler("EPOCH_STARTED", "Start of the epoch"),
            # Define a range using one prefix (between BATCH_STARTED and BATCH_COMPLETED)
            RangeHandler("Batch"),
            # Define a range using a pair of events
            RangeHandler((Events.STARTED, Events.COMPLETED)),
            # Define a range using a pair of literals
            RangeHandler(("GET_BATCH_STARTED", "GET_BATCH_COMPLETED"), msg="Batching!"),
            # Define a range using a pair of literal and events
            RangeHandler(("GET_BATCH_STARTED", Events.COMPLETED)),
            # Define the start of range using literal
            RangePushHandler("ITERATION_STARTED"),
            # Define the start of range using event
            RangePushHandler(Events.ITERATION_STARTED, "Iteration 2"),
            # Define the start of range using literals and providing message
            RangePushHandler("EPOCH_STARTED", "Epoch 2"),
            # Define the end of range using Ignite Event
            RangePopHandler(Events.ITERATION_COMPLETED),
            RangePopHandler(Events.EPOCH_COMPLETED),
            # Define the end of range using literal
            RangePopHandler("ITERATION_COMPLETED"),
            # Other handlers
            StatsHandler(tag_name="train", output_transform=from_engine(["label"], first=True)),
        ]

        # Set up an engine: PReLU network on CPU, then post-processing adds 1.0 to the prediction
        engine = SupervisedEvaluator(
            device=torch.device("cpu:0"),
            val_data_loader=data,
            epoch_length=1,
            network=torch.nn.PReLU(),
            postprocessing=lambda x: dict(pred=x["pred"] + 1.0),
            decollate=True,
            val_handlers=handlers,
        )
        # Run the engine
        engine.run()

        # Get the output from the engine
        # presumably, with decollate=True, the output is a list of per-sample dicts; take the first
        output = engine.state.output[0]

        assert_allclose(output["pred"], expected)
90
+
91
+
92
if __name__ == "__main__":
    # allow running this test module directly as a script
    unittest.main()
@@ -0,0 +1,89 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import torch
17
+ from ignite.engine import Engine, Events
18
+ from parameterized import parameterized
19
+
20
+ from monai.handlers import PanopticQuality, from_engine
21
+ from tests.test_utils import SkipIfNoModule, assert_allclose
22
+
23
# Two prediction/label pairs. Each is a channel-first integer tensor of shape (2, 4, 4),
# written one channel per line.
sample_1_pred = torch.as_tensor(
    [
        [[0, 1, 1, 1], [0, 0, 5, 5], [2, 0, 3, 3], [2, 2, 2, 0]],
        [[0, 1, 1, 1], [0, 0, 0, 0], [2, 0, 3, 3], [4, 2, 2, 0]],
    ]
)

sample_1_gt = torch.as_tensor(
    [
        [[0, 6, 6, 6], [1, 0, 5, 5], [1, 0, 3, 3], [1, 3, 2, 0]],
        [[0, 1, 1, 1], [0, 0, 1, 1], [2, 0, 3, 3], [4, 4, 4, 3]],
    ]
)

sample_2_pred = torch.as_tensor(
    [
        [[3, 1, 1, 1], [3, 1, 1, 4], [3, 1, 4, 4], [3, 2, 2, 4]],
        [[0, 1, 1, 1], [2, 2, 2, 2], [2, 0, 0, 3], [4, 2, 2, 3]],
    ]
)

sample_2_gt = torch.as_tensor(
    [
        [[0, 6, 6, 6], [1, 0, 5, 5], [1, 0, 3, 3], [1, 3, 2, 0]],
        [[0, 1, 1, 1], [2, 1, 1, 3], [2, 0, 0, 3], [4, 2, 2, 3]],
    ]
)
38
+
39
# Each case is [PanopticQuality kwargs, expected metric value(s)]. Without ``reduction`` the expected
# value is a per-class list whose length matches ``num_classes``; with ``reduction="mean"`` it is a scalar.
TEST_CASE_1 = [
    dict(num_classes=4, output_transform=from_engine(["pred", "label"])),
    [0.6667, 0.1538, 0.6667, 0.5714],
]
TEST_CASE_2 = [
    dict(
        num_classes=5,
        output_transform=from_engine(["pred", "label"]),
        metric_name="rq",
        match_iou_threshold=0.3,
    ),
    [0.6667, 0.7692, 0.8889, 0.5714, 0.0000],
]
TEST_CASE_3 = [
    dict(
        num_classes=5,
        reduction="mean",
        output_transform=from_engine(["pred", "label"]),
        metric_name="SQ",
        match_iou_threshold=0.2,
    ),
    0.8235,
]
59
+
60
+
61
@SkipIfNoModule("scipy.optimize")
class TestHandlerPanopticQuality(unittest.TestCase):
    """Check that the ``PanopticQuality`` ignite metric accumulates over two manually fired iterations."""

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
    def test_compute(self, input_params, expected_avg):
        metric = PanopticQuality(**input_params)

        # a no-op process function: ``engine.state.output`` is filled in by hand below
        engine = Engine(lambda engine, batch: None)
        metric.attach(engine=engine, name="panoptic_quality")

        # each iteration feeds lists of channel-first tensors for predictions and labels
        iterations = (
            ([sample_1_pred, sample_2_pred], [sample_1_gt, sample_2_gt]),
            ([sample_1_pred, sample_1_pred], [sample_1_gt, sample_1_gt]),
        )
        for preds, labels in iterations:
            engine.state.output = {"pred": preds, "label": labels}
            engine.fire_event(Events.ITERATION_COMPLETED)

        engine.fire_event(Events.EPOCH_COMPLETED)
        assert_allclose(engine.state.metrics["panoptic_quality"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False)
86
+
87
+
88
if __name__ == "__main__":
    # allow running this test module directly as a script
    unittest.main()
@@ -0,0 +1,136 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ from ignite.engine import Engine, Events
17
+ from torch.nn import Module
18
+
19
+ from monai.handlers.parameter_scheduler import ParamSchedulerHandler
20
+ from tests.test_utils import assert_allclose
21
+
22
+
23
class ToyNet(Module):
    """Identity module carrying a plain ``value`` attribute that scheduler tests read and overwrite.

    ``set_value`` is handed to the handler under test as its ``parameter_setter``;
    ``get_value`` reads back whatever was stored last.
    """

    def __init__(self, value):
        super().__init__()
        self.set_value(value)

    def forward(self, input):
        # Pass-through: the network itself is irrelevant to the scheduler tests.
        return input

    def get_value(self):
        return getattr(self, "value")

    def set_value(self, value):
        self.value = value
33
+
34
+ def set_value(self, value):
35
+ self.value = value
36
+
37
+
38
class TestHandlerParameterScheduler(unittest.TestCase):
    """Tests for ParamSchedulerHandler with each built-in value calculator and a custom callable.

    Every case drives a fresh ``ToyNet`` through an epoch-level handler fired on
    ``Events.EPOCH_COMPLETED`` and checks the last value passed to ``set_value``.
    """

    @staticmethod
    def _final_value(value_calculator, vc_kwargs, max_epochs):
        """Run one scheduler to completion and return the value it left on the network.

        Args:
            value_calculator: calculator name (e.g. ``"linear"``) or a callable, forwarded
                unchanged to ``ParamSchedulerHandler``.
            vc_kwargs: keyword arguments for the calculator.
            max_epochs: number of epochs to run; each epoch is 8 no-op iterations.

        Returns:
            The value held by the ``ToyNet`` after the run (``-1`` if the setter never fired).
        """
        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator=value_calculator,
            vc_kwargs=vc_kwargs,
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=max_epochs)
        return net.get_value()

    def test_linear_scheduler(self):
        linear_kwargs = {"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10}

        # Testing step_constant: still at initial_value after 2 epochs.
        # A fresh copy of the kwargs is passed to every run, as the original separate literals were.
        assert_allclose(self._final_value("linear", dict(linear_kwargs), max_epochs=2), 0)

        # Testing linear increase: one epoch past step_constant.
        assert_allclose(self._final_value("linear", dict(linear_kwargs), max_epochs=3), 3.333333, atol=0.001, rtol=0.0)

        # Testing max_value: well past step_max_value the value sits at max_value.
        assert_allclose(self._final_value("linear", dict(linear_kwargs), max_epochs=10), 10)

    def test_exponential_scheduler(self):
        value = self._final_value("exponential", {"initial_value": 10, "gamma": 0.99}, max_epochs=2)
        assert_allclose(value, 10 * 0.99 * 0.99)

    def test_step_scheduler(self):
        value = self._final_value("step", {"initial_value": 10, "gamma": 0.99, "step_size": 5}, max_epochs=10)
        assert_allclose(value, 10 * 0.99 * 0.99)

    def test_multistep_scheduler(self):
        value = self._final_value("multistep", {"initial_value": 10, "gamma": 0.99, "milestones": [3, 6]}, max_epochs=10)
        assert_allclose(value, 10 * 0.99 * 0.99)

    def test_custom_scheduler(self):
        def custom_logic(initial_value, gamma, current_step):
            return initial_value * gamma ** (current_step % 9)

        value = self._final_value(custom_logic, {"initial_value": 10, "gamma": 0.99}, max_epochs=2)
        assert_allclose(value, 10 * 0.99 * 0.99)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,74 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import torch
17
+ from parameterized import parameterized
18
+
19
+ from monai.engines import SupervisedEvaluator
20
+ from monai.handlers import PostProcessing
21
+ from monai.transforms import Activationsd, AsDiscreted, Compose, CopyItemsd
22
+ from tests.test_utils import assert_allclose
23
+
24
# test lambda function as `transform`.
# After the run, engine.state.output holds the last batch (image [6, 8]). PReLU is the identity
# on these positive inputs, so the expected values are sigmoid(x) + 1.
TEST_CASE_1 = [{"transform": lambda x: {"pred": x["pred"] + 1.0}}, False, torch.tensor([[[[1.9975], [1.9997]]]])]
# test composed postprocessing transforms as `transform`
TEST_CASE_2 = [
    dict(
        transform=Compose(
            [
                CopyItemsd(keys="filename", times=1, names="filename_bak"),
                AsDiscreted(keys="pred", threshold=0.5, to_onehot=2),
            ]
        ),
        event="iteration_completed",
    ),
    True,
    torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]]),
]
40
+
41
+
42
class TestHandlerPostProcessing(unittest.TestCase):
    """Check that the PostProcessing handler runs its transform on the evaluator output."""

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
    def test_compute(self, input_params, decollate, expected):
        """Run a two-iteration evaluator and compare the final ``engine.state.output`` with ``expected``.

        Args:
            input_params: keyword arguments for ``PostProcessing``.
            decollate: forwarded to ``SupervisedEvaluator``; when the output ends up as a list,
                it is compared item by item.
            expected: expected ``"pred"`` tensor for the last iteration.
        """
        data = [
            {"image": torch.tensor([[[[2.0], [3.0]]]]), "filename": ["test1"]},
            {"image": torch.tensor([[[[6.0], [8.0]]]]), "filename": ["test2"]},
        ]
        # set up engine, PostProcessing handler works together with postprocessing transforms of engine
        engine = SupervisedEvaluator(
            device=torch.device("cpu:0"),
            val_data_loader=data,
            epoch_length=2,
            network=torch.nn.PReLU(),
            postprocessing=Compose([Activationsd(keys="pred", sigmoid=True)]),
            val_handlers=[PostProcessing(**input_params)],
            decollate=decollate,
        )
        engine.run()

        output = engine.state.output
        if isinstance(output, list):
            # test decollated list items. zip() stops at the shorter input, so pin the item
            # count first; otherwise a missing (or empty) output would pass without any check.
            self.assertEqual(len(output), len(expected))
            for o, e in zip(output, expected):
                assert_allclose(o["pred"], e, atol=1e-4, rtol=1e-4, type_test=False)
                filename = o.get("filename_bak")
                if filename is not None:
                    self.assertEqual(filename, "test2")
        else:
            # test batch data
            assert_allclose(output["pred"], expected, atol=1e-4, rtol=1e-4, type_test=False)


if __name__ == "__main__":
    unittest.main()