monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2508__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (787)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/transforms.py +1 -4
  4. monai/data/utils.py +6 -13
  5. monai/handlers/__init__.py +1 -0
  6. monai/handlers/average_precision.py +53 -0
  7. monai/inferers/inferer.py +10 -7
  8. monai/inferers/utils.py +1 -2
  9. monai/losses/dice.py +2 -14
  10. monai/losses/ds_loss.py +1 -3
  11. monai/metrics/__init__.py +1 -0
  12. monai/metrics/average_precision.py +187 -0
  13. monai/networks/layers/simplelayers.py +2 -14
  14. monai/networks/utils.py +4 -16
  15. monai/transforms/compose.py +28 -11
  16. monai/transforms/croppad/array.py +1 -6
  17. monai/transforms/io/array.py +0 -1
  18. monai/transforms/transform.py +15 -6
  19. monai/transforms/utility/array.py +2 -12
  20. monai/transforms/utils.py +1 -2
  21. monai/transforms/utils_pytorch_numpy_unification.py +2 -4
  22. monai/utils/enums.py +3 -2
  23. monai/utils/module.py +6 -6
  24. monai/utils/tf32.py +0 -10
  25. monai/visualize/class_activation_maps.py +5 -8
  26. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/METADATA +21 -17
  27. monai_weekly-1.5.dev2508.dist-info/RECORD +1185 -0
  28. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/top_level.txt +1 -0
  29. tests/apps/__init__.py +10 -0
  30. tests/apps/deepedit/__init__.py +10 -0
  31. tests/apps/deepedit/test_deepedit_transforms.py +314 -0
  32. tests/apps/deepgrow/__init__.py +10 -0
  33. tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
  34. tests/apps/deepgrow/transforms/__init__.py +10 -0
  35. tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
  36. tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
  37. tests/apps/detection/__init__.py +10 -0
  38. tests/apps/detection/metrics/__init__.py +10 -0
  39. tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
  40. tests/apps/detection/networks/__init__.py +10 -0
  41. tests/apps/detection/networks/test_retinanet.py +210 -0
  42. tests/apps/detection/networks/test_retinanet_detector.py +203 -0
  43. tests/apps/detection/test_box_transform.py +370 -0
  44. tests/apps/detection/utils/__init__.py +10 -0
  45. tests/apps/detection/utils/test_anchor_box.py +88 -0
  46. tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
  47. tests/apps/detection/utils/test_box_coder.py +43 -0
  48. tests/apps/detection/utils/test_detector_boxselector.py +67 -0
  49. tests/apps/detection/utils/test_detector_utils.py +96 -0
  50. tests/apps/detection/utils/test_hardnegsampler.py +54 -0
  51. tests/apps/nuclick/__init__.py +10 -0
  52. tests/apps/nuclick/test_nuclick_transforms.py +259 -0
  53. tests/apps/pathology/__init__.py +10 -0
  54. tests/apps/pathology/handlers/__init__.py +10 -0
  55. tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
  56. tests/apps/pathology/test_lesion_froc.py +333 -0
  57. tests/apps/pathology/test_pathology_prob_nms.py +55 -0
  58. tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
  59. tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
  60. tests/apps/pathology/transforms/__init__.py +10 -0
  61. tests/apps/pathology/transforms/post/__init__.py +10 -0
  62. tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
  63. tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
  64. tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
  65. tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
  66. tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
  67. tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
  68. tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
  69. tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
  70. tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
  71. tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
  72. tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
  73. tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
  74. tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
  75. tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
  76. tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
  77. tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
  78. tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
  79. tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
  80. tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
  81. tests/apps/pathology/transforms/post/test_watershed.py +60 -0
  82. tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
  83. tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
  84. tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
  85. tests/apps/reconstruction/__init__.py +10 -0
  86. tests/apps/reconstruction/nets/__init__.py +10 -0
  87. tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
  88. tests/apps/reconstruction/test_complex_utils.py +77 -0
  89. tests/apps/reconstruction/test_fastmri_reader.py +82 -0
  90. tests/apps/reconstruction/test_mri_utils.py +37 -0
  91. tests/apps/reconstruction/transforms/__init__.py +10 -0
  92. tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
  93. tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
  94. tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
  95. tests/apps/test_auto3dseg_bundlegen.py +156 -0
  96. tests/apps/test_check_hash.py +53 -0
  97. tests/apps/test_cross_validation.py +74 -0
  98. tests/apps/test_decathlondataset.py +93 -0
  99. tests/apps/test_download_and_extract.py +70 -0
  100. tests/apps/test_download_url_yandex.py +45 -0
  101. tests/apps/test_mednistdataset.py +72 -0
  102. tests/apps/test_mmar_download.py +154 -0
  103. tests/apps/test_tciadataset.py +123 -0
  104. tests/apps/vista3d/__init__.py +10 -0
  105. tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
  106. tests/apps/vista3d/test_vista3d_sampler.py +100 -0
  107. tests/apps/vista3d/test_vista3d_transforms.py +94 -0
  108. tests/bundle/__init__.py +10 -0
  109. tests/bundle/test_bundle_ckpt_export.py +107 -0
  110. tests/bundle/test_bundle_download.py +435 -0
  111. tests/bundle/test_bundle_get_data.py +94 -0
  112. tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
  113. tests/bundle/test_bundle_trt_export.py +147 -0
  114. tests/bundle/test_bundle_utils.py +149 -0
  115. tests/bundle/test_bundle_verify_metadata.py +66 -0
  116. tests/bundle/test_bundle_verify_net.py +76 -0
  117. tests/bundle/test_bundle_workflow.py +272 -0
  118. tests/bundle/test_component_locator.py +38 -0
  119. tests/bundle/test_config_item.py +138 -0
  120. tests/bundle/test_config_parser.py +392 -0
  121. tests/bundle/test_reference_resolver.py +114 -0
  122. tests/config/__init__.py +10 -0
  123. tests/config/test_cv2_dist.py +53 -0
  124. tests/engines/__init__.py +10 -0
  125. tests/engines/test_ensemble_evaluator.py +94 -0
  126. tests/engines/test_prepare_batch_default.py +76 -0
  127. tests/engines/test_prepare_batch_default_dist.py +76 -0
  128. tests/engines/test_prepare_batch_diffusion.py +104 -0
  129. tests/engines/test_prepare_batch_extra_input.py +80 -0
  130. tests/fl/__init__.py +10 -0
  131. tests/fl/monai_algo/__init__.py +10 -0
  132. tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
  133. tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
  134. tests/fl/test_fl_monai_algo_stats.py +81 -0
  135. tests/fl/utils/__init__.py +10 -0
  136. tests/fl/utils/test_fl_exchange_object.py +63 -0
  137. tests/handlers/__init__.py +10 -0
  138. tests/handlers/test_handler_average_precision.py +79 -0
  139. tests/handlers/test_handler_checkpoint_loader.py +182 -0
  140. tests/handlers/test_handler_checkpoint_saver.py +233 -0
  141. tests/handlers/test_handler_classification_saver.py +64 -0
  142. tests/handlers/test_handler_classification_saver_dist.py +77 -0
  143. tests/handlers/test_handler_clearml_image.py +65 -0
  144. tests/handlers/test_handler_clearml_stats.py +65 -0
  145. tests/handlers/test_handler_confusion_matrix.py +104 -0
  146. tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
  147. tests/handlers/test_handler_decollate_batch.py +66 -0
  148. tests/handlers/test_handler_early_stop.py +68 -0
  149. tests/handlers/test_handler_garbage_collector.py +73 -0
  150. tests/handlers/test_handler_hausdorff_distance.py +111 -0
  151. tests/handlers/test_handler_ignite_metric.py +191 -0
  152. tests/handlers/test_handler_lr_scheduler.py +94 -0
  153. tests/handlers/test_handler_mean_dice.py +98 -0
  154. tests/handlers/test_handler_mean_iou.py +76 -0
  155. tests/handlers/test_handler_metrics_reloaded.py +149 -0
  156. tests/handlers/test_handler_metrics_saver.py +89 -0
  157. tests/handlers/test_handler_metrics_saver_dist.py +120 -0
  158. tests/handlers/test_handler_mlflow.py +296 -0
  159. tests/handlers/test_handler_nvtx.py +93 -0
  160. tests/handlers/test_handler_panoptic_quality.py +89 -0
  161. tests/handlers/test_handler_parameter_scheduler.py +136 -0
  162. tests/handlers/test_handler_post_processing.py +74 -0
  163. tests/handlers/test_handler_prob_map_producer.py +111 -0
  164. tests/handlers/test_handler_regression_metrics.py +160 -0
  165. tests/handlers/test_handler_regression_metrics_dist.py +245 -0
  166. tests/handlers/test_handler_rocauc.py +48 -0
  167. tests/handlers/test_handler_rocauc_dist.py +54 -0
  168. tests/handlers/test_handler_stats.py +281 -0
  169. tests/handlers/test_handler_surface_distance.py +113 -0
  170. tests/handlers/test_handler_tb_image.py +61 -0
  171. tests/handlers/test_handler_tb_stats.py +166 -0
  172. tests/handlers/test_handler_validation.py +59 -0
  173. tests/handlers/test_trt_compile.py +145 -0
  174. tests/handlers/test_write_metrics_reports.py +68 -0
  175. tests/inferers/__init__.py +10 -0
  176. tests/inferers/test_avg_merger.py +179 -0
  177. tests/inferers/test_controlnet_inferers.py +1388 -0
  178. tests/inferers/test_diffusion_inferer.py +236 -0
  179. tests/inferers/test_latent_diffusion_inferer.py +884 -0
  180. tests/inferers/test_patch_inferer.py +309 -0
  181. tests/inferers/test_saliency_inferer.py +55 -0
  182. tests/inferers/test_slice_inferer.py +57 -0
  183. tests/inferers/test_sliding_window_inference.py +377 -0
  184. tests/inferers/test_sliding_window_splitter.py +284 -0
  185. tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
  186. tests/inferers/test_zarr_avg_merger.py +326 -0
  187. tests/integration/__init__.py +10 -0
  188. tests/integration/test_auto3dseg_ensemble.py +211 -0
  189. tests/integration/test_auto3dseg_hpo.py +189 -0
  190. tests/integration/test_deepedit_interaction.py +122 -0
  191. tests/integration/test_downsample_block.py +50 -0
  192. tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
  193. tests/integration/test_integration_autorunner.py +201 -0
  194. tests/integration/test_integration_bundle_run.py +240 -0
  195. tests/integration/test_integration_classification_2d.py +282 -0
  196. tests/integration/test_integration_determinism.py +95 -0
  197. tests/integration/test_integration_fast_train.py +231 -0
  198. tests/integration/test_integration_gpu_customization.py +159 -0
  199. tests/integration/test_integration_lazy_samples.py +219 -0
  200. tests/integration/test_integration_nnunetv2_runner.py +96 -0
  201. tests/integration/test_integration_segmentation_3d.py +304 -0
  202. tests/integration/test_integration_sliding_window.py +100 -0
  203. tests/integration/test_integration_stn.py +133 -0
  204. tests/integration/test_integration_unet_2d.py +67 -0
  205. tests/integration/test_integration_workers.py +61 -0
  206. tests/integration/test_integration_workflows.py +365 -0
  207. tests/integration/test_integration_workflows_adversarial.py +173 -0
  208. tests/integration/test_integration_workflows_gan.py +158 -0
  209. tests/integration/test_loader_semaphore.py +48 -0
  210. tests/integration/test_mapping_filed.py +122 -0
  211. tests/integration/test_meta_affine.py +183 -0
  212. tests/integration/test_metatensor_integration.py +114 -0
  213. tests/integration/test_module_list.py +76 -0
  214. tests/integration/test_one_of.py +283 -0
  215. tests/integration/test_pad_collation.py +124 -0
  216. tests/integration/test_reg_loss_integration.py +107 -0
  217. tests/integration/test_retinanet_predict_utils.py +154 -0
  218. tests/integration/test_seg_loss_integration.py +159 -0
  219. tests/integration/test_spatial_combine_transforms.py +185 -0
  220. tests/integration/test_testtimeaugmentation.py +186 -0
  221. tests/integration/test_vis_gradbased.py +69 -0
  222. tests/integration/test_vista3d_utils.py +159 -0
  223. tests/losses/__init__.py +10 -0
  224. tests/losses/deform/__init__.py +10 -0
  225. tests/losses/deform/test_bending_energy.py +88 -0
  226. tests/losses/deform/test_diffusion_loss.py +117 -0
  227. tests/losses/image_dissimilarity/__init__.py +10 -0
  228. tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
  229. tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
  230. tests/losses/test_adversarial_loss.py +94 -0
  231. tests/losses/test_barlow_twins_loss.py +109 -0
  232. tests/losses/test_cldice_loss.py +51 -0
  233. tests/losses/test_contrastive_loss.py +86 -0
  234. tests/losses/test_dice_ce_loss.py +123 -0
  235. tests/losses/test_dice_focal_loss.py +124 -0
  236. tests/losses/test_dice_loss.py +227 -0
  237. tests/losses/test_ds_loss.py +189 -0
  238. tests/losses/test_focal_loss.py +379 -0
  239. tests/losses/test_generalized_dice_focal_loss.py +85 -0
  240. tests/losses/test_generalized_dice_loss.py +221 -0
  241. tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
  242. tests/losses/test_giou_loss.py +62 -0
  243. tests/losses/test_hausdorff_loss.py +264 -0
  244. tests/losses/test_masked_dice_loss.py +152 -0
  245. tests/losses/test_masked_loss.py +87 -0
  246. tests/losses/test_multi_scale.py +86 -0
  247. tests/losses/test_nacl_loss.py +167 -0
  248. tests/losses/test_perceptual_loss.py +122 -0
  249. tests/losses/test_spectral_loss.py +86 -0
  250. tests/losses/test_ssim_loss.py +59 -0
  251. tests/losses/test_sure_loss.py +72 -0
  252. tests/losses/test_tversky_loss.py +198 -0
  253. tests/losses/test_unified_focal_loss.py +66 -0
  254. tests/metrics/__init__.py +10 -0
  255. tests/metrics/test_compute_average_precision.py +162 -0
  256. tests/metrics/test_compute_confusion_matrix.py +294 -0
  257. tests/metrics/test_compute_f_beta.py +80 -0
  258. tests/metrics/test_compute_fid_metric.py +40 -0
  259. tests/metrics/test_compute_froc.py +143 -0
  260. tests/metrics/test_compute_generalized_dice.py +240 -0
  261. tests/metrics/test_compute_meandice.py +306 -0
  262. tests/metrics/test_compute_meaniou.py +223 -0
  263. tests/metrics/test_compute_mmd_metric.py +56 -0
  264. tests/metrics/test_compute_multiscalessim_metric.py +83 -0
  265. tests/metrics/test_compute_panoptic_quality.py +113 -0
  266. tests/metrics/test_compute_regression_metrics.py +196 -0
  267. tests/metrics/test_compute_roc_auc.py +155 -0
  268. tests/metrics/test_compute_variance.py +147 -0
  269. tests/metrics/test_cumulative.py +63 -0
  270. tests/metrics/test_cumulative_average.py +74 -0
  271. tests/metrics/test_cumulative_average_dist.py +48 -0
  272. tests/metrics/test_hausdorff_distance.py +209 -0
  273. tests/metrics/test_label_quality_score.py +134 -0
  274. tests/metrics/test_loss_metric.py +57 -0
  275. tests/metrics/test_metrics_reloaded.py +96 -0
  276. tests/metrics/test_ssim_metric.py +78 -0
  277. tests/metrics/test_surface_dice.py +416 -0
  278. tests/metrics/test_surface_distance.py +186 -0
  279. tests/networks/__init__.py +10 -0
  280. tests/networks/blocks/__init__.py +10 -0
  281. tests/networks/blocks/dints_block/__init__.py +10 -0
  282. tests/networks/blocks/dints_block/test_acn_block.py +41 -0
  283. tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
  284. tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
  285. tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
  286. tests/networks/blocks/test_adn.py +86 -0
  287. tests/networks/blocks/test_convolutions.py +156 -0
  288. tests/networks/blocks/test_crf_cpu.py +513 -0
  289. tests/networks/blocks/test_crf_cuda.py +528 -0
  290. tests/networks/blocks/test_crossattention.py +185 -0
  291. tests/networks/blocks/test_denseblock.py +105 -0
  292. tests/networks/blocks/test_dynunet_block.py +116 -0
  293. tests/networks/blocks/test_fpn_block.py +88 -0
  294. tests/networks/blocks/test_localnet_block.py +121 -0
  295. tests/networks/blocks/test_mlp.py +78 -0
  296. tests/networks/blocks/test_patchembedding.py +212 -0
  297. tests/networks/blocks/test_regunet_block.py +103 -0
  298. tests/networks/blocks/test_se_block.py +85 -0
  299. tests/networks/blocks/test_se_blocks.py +78 -0
  300. tests/networks/blocks/test_segresnet_block.py +57 -0
  301. tests/networks/blocks/test_selfattention.py +232 -0
  302. tests/networks/blocks/test_simple_aspp.py +87 -0
  303. tests/networks/blocks/test_spatialattention.py +55 -0
  304. tests/networks/blocks/test_subpixel_upsample.py +87 -0
  305. tests/networks/blocks/test_text_encoding.py +49 -0
  306. tests/networks/blocks/test_transformerblock.py +90 -0
  307. tests/networks/blocks/test_unetr_block.py +158 -0
  308. tests/networks/blocks/test_upsample_block.py +134 -0
  309. tests/networks/blocks/warp/__init__.py +10 -0
  310. tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
  311. tests/networks/blocks/warp/test_warp.py +250 -0
  312. tests/networks/layers/__init__.py +10 -0
  313. tests/networks/layers/filtering/__init__.py +10 -0
  314. tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
  315. tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
  316. tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
  317. tests/networks/layers/filtering/test_phl_cpu.py +259 -0
  318. tests/networks/layers/filtering/test_phl_cuda.py +167 -0
  319. tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
  320. tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
  321. tests/networks/layers/test_affine_transform.py +385 -0
  322. tests/networks/layers/test_apply_filter.py +89 -0
  323. tests/networks/layers/test_channel_pad.py +51 -0
  324. tests/networks/layers/test_conjugate_gradient.py +56 -0
  325. tests/networks/layers/test_drop_path.py +46 -0
  326. tests/networks/layers/test_gaussian.py +317 -0
  327. tests/networks/layers/test_gaussian_filter.py +206 -0
  328. tests/networks/layers/test_get_layers.py +65 -0
  329. tests/networks/layers/test_gmm.py +314 -0
  330. tests/networks/layers/test_grid_pull.py +93 -0
  331. tests/networks/layers/test_hilbert_transform.py +131 -0
  332. tests/networks/layers/test_lltm.py +62 -0
  333. tests/networks/layers/test_median_filter.py +52 -0
  334. tests/networks/layers/test_polyval.py +55 -0
  335. tests/networks/layers/test_preset_filters.py +136 -0
  336. tests/networks/layers/test_savitzky_golay_filter.py +141 -0
  337. tests/networks/layers/test_separable_filter.py +87 -0
  338. tests/networks/layers/test_skip_connection.py +48 -0
  339. tests/networks/layers/test_vector_quantizer.py +89 -0
  340. tests/networks/layers/test_weight_init.py +50 -0
  341. tests/networks/nets/__init__.py +10 -0
  342. tests/networks/nets/dints/__init__.py +10 -0
  343. tests/networks/nets/dints/test_dints_cell.py +110 -0
  344. tests/networks/nets/dints/test_dints_mixop.py +84 -0
  345. tests/networks/nets/regunet/__init__.py +10 -0
  346. tests/networks/nets/regunet/test_localnet.py +86 -0
  347. tests/networks/nets/regunet/test_regunet.py +88 -0
  348. tests/networks/nets/test_ahnet.py +224 -0
  349. tests/networks/nets/test_attentionunet.py +88 -0
  350. tests/networks/nets/test_autoencoder.py +95 -0
  351. tests/networks/nets/test_autoencoderkl.py +337 -0
  352. tests/networks/nets/test_basic_unet.py +102 -0
  353. tests/networks/nets/test_basic_unetplusplus.py +109 -0
  354. tests/networks/nets/test_bundle_init_bundle.py +55 -0
  355. tests/networks/nets/test_cell_sam_wrapper.py +58 -0
  356. tests/networks/nets/test_controlnet.py +215 -0
  357. tests/networks/nets/test_daf3d.py +62 -0
  358. tests/networks/nets/test_densenet.py +121 -0
  359. tests/networks/nets/test_diffusion_model_unet.py +585 -0
  360. tests/networks/nets/test_dints_network.py +168 -0
  361. tests/networks/nets/test_discriminator.py +59 -0
  362. tests/networks/nets/test_dynunet.py +181 -0
  363. tests/networks/nets/test_efficientnet.py +400 -0
  364. tests/networks/nets/test_flexible_unet.py +341 -0
  365. tests/networks/nets/test_fullyconnectednet.py +69 -0
  366. tests/networks/nets/test_generator.py +59 -0
  367. tests/networks/nets/test_globalnet.py +103 -0
  368. tests/networks/nets/test_highresnet.py +67 -0
  369. tests/networks/nets/test_hovernet.py +218 -0
  370. tests/networks/nets/test_mednext.py +122 -0
  371. tests/networks/nets/test_milmodel.py +92 -0
  372. tests/networks/nets/test_net_adapter.py +68 -0
  373. tests/networks/nets/test_network_consistency.py +86 -0
  374. tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
  375. tests/networks/nets/test_quicknat.py +57 -0
  376. tests/networks/nets/test_resnet.py +340 -0
  377. tests/networks/nets/test_segresnet.py +120 -0
  378. tests/networks/nets/test_segresnet_ds.py +156 -0
  379. tests/networks/nets/test_senet.py +151 -0
  380. tests/networks/nets/test_spade_autoencoderkl.py +295 -0
  381. tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
  382. tests/networks/nets/test_spade_vaegan.py +140 -0
  383. tests/networks/nets/test_swin_unetr.py +139 -0
  384. tests/networks/nets/test_torchvision_fc_model.py +201 -0
  385. tests/networks/nets/test_transchex.py +84 -0
  386. tests/networks/nets/test_transformer.py +108 -0
  387. tests/networks/nets/test_unet.py +208 -0
  388. tests/networks/nets/test_unetr.py +137 -0
  389. tests/networks/nets/test_varautoencoder.py +127 -0
  390. tests/networks/nets/test_vista3d.py +84 -0
  391. tests/networks/nets/test_vit.py +139 -0
  392. tests/networks/nets/test_vitautoenc.py +112 -0
  393. tests/networks/nets/test_vnet.py +81 -0
  394. tests/networks/nets/test_voxelmorph.py +280 -0
  395. tests/networks/nets/test_vqvae.py +274 -0
  396. tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
  397. tests/networks/schedulers/__init__.py +10 -0
  398. tests/networks/schedulers/test_scheduler_ddim.py +83 -0
  399. tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
  400. tests/networks/schedulers/test_scheduler_pndm.py +108 -0
  401. tests/networks/test_bundle_onnx_export.py +71 -0
  402. tests/networks/test_convert_to_onnx.py +106 -0
  403. tests/networks/test_convert_to_torchscript.py +46 -0
  404. tests/networks/test_convert_to_trt.py +79 -0
  405. tests/networks/test_save_state.py +73 -0
  406. tests/networks/test_to_onehot.py +63 -0
  407. tests/networks/test_varnet.py +63 -0
  408. tests/networks/utils/__init__.py +10 -0
  409. tests/networks/utils/test_copy_model_state.py +187 -0
  410. tests/networks/utils/test_eval_mode.py +34 -0
  411. tests/networks/utils/test_freeze_layers.py +61 -0
  412. tests/networks/utils/test_replace_module.py +98 -0
  413. tests/networks/utils/test_train_mode.py +34 -0
  414. tests/optimizers/__init__.py +10 -0
  415. tests/optimizers/test_generate_param_groups.py +105 -0
  416. tests/optimizers/test_lr_finder.py +108 -0
  417. tests/optimizers/test_lr_scheduler.py +71 -0
  418. tests/optimizers/test_optim_novograd.py +100 -0
  419. tests/profile_subclass/__init__.py +10 -0
  420. tests/profile_subclass/cprofile_profiling.py +29 -0
  421. tests/profile_subclass/min_classes.py +30 -0
  422. tests/profile_subclass/profiling.py +73 -0
  423. tests/profile_subclass/pyspy_profiling.py +41 -0
  424. tests/transforms/__init__.py +10 -0
  425. tests/transforms/compose/__init__.py +10 -0
  426. tests/transforms/compose/test_compose.py +758 -0
  427. tests/transforms/compose/test_some_of.py +258 -0
  428. tests/transforms/croppad/__init__.py +10 -0
  429. tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
  430. tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
  431. tests/transforms/functional/__init__.py +10 -0
  432. tests/transforms/functional/test_apply.py +75 -0
  433. tests/transforms/functional/test_resample.py +50 -0
  434. tests/transforms/intensity/__init__.py +10 -0
  435. tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
  436. tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
  437. tests/transforms/intensity/test_foreground_mask.py +98 -0
  438. tests/transforms/intensity/test_foreground_maskd.py +106 -0
  439. tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
  440. tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
  441. tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
  442. tests/transforms/inverse/__init__.py +10 -0
  443. tests/transforms/inverse/test_inverse_array.py +76 -0
  444. tests/transforms/inverse/test_traceable_transform.py +59 -0
  445. tests/transforms/post/__init__.py +10 -0
  446. tests/transforms/post/test_label_filterd.py +78 -0
  447. tests/transforms/post/test_probnms.py +72 -0
  448. tests/transforms/post/test_probnmsd.py +79 -0
  449. tests/transforms/post/test_remove_small_objects.py +102 -0
  450. tests/transforms/spatial/__init__.py +10 -0
  451. tests/transforms/spatial/test_convert_box_points.py +119 -0
  452. tests/transforms/spatial/test_grid_patch.py +134 -0
  453. tests/transforms/spatial/test_grid_patchd.py +102 -0
  454. tests/transforms/spatial/test_rand_grid_patch.py +150 -0
  455. tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
  456. tests/transforms/spatial/test_spatial_resampled.py +124 -0
  457. tests/transforms/test_activations.py +120 -0
  458. tests/transforms/test_activationsd.py +64 -0
  459. tests/transforms/test_adaptors.py +160 -0
  460. tests/transforms/test_add_coordinate_channels.py +53 -0
  461. tests/transforms/test_add_coordinate_channelsd.py +67 -0
  462. tests/transforms/test_add_extreme_points_channel.py +80 -0
  463. tests/transforms/test_add_extreme_points_channeld.py +77 -0
  464. tests/transforms/test_adjust_contrast.py +70 -0
  465. tests/transforms/test_adjust_contrastd.py +64 -0
  466. tests/transforms/test_affine.py +245 -0
  467. tests/transforms/test_affine_grid.py +152 -0
  468. tests/transforms/test_affined.py +190 -0
  469. tests/transforms/test_as_channel_last.py +38 -0
  470. tests/transforms/test_as_channel_lastd.py +44 -0
  471. tests/transforms/test_as_discrete.py +81 -0
  472. tests/transforms/test_as_discreted.py +82 -0
  473. tests/transforms/test_border_pad.py +49 -0
  474. tests/transforms/test_border_padd.py +45 -0
  475. tests/transforms/test_bounding_rect.py +54 -0
  476. tests/transforms/test_bounding_rectd.py +53 -0
  477. tests/transforms/test_cast_to_type.py +63 -0
  478. tests/transforms/test_cast_to_typed.py +74 -0
  479. tests/transforms/test_center_scale_crop.py +55 -0
  480. tests/transforms/test_center_scale_cropd.py +56 -0
  481. tests/transforms/test_center_spatial_crop.py +56 -0
  482. tests/transforms/test_center_spatial_cropd.py +63 -0
  483. tests/transforms/test_classes_to_indices.py +93 -0
  484. tests/transforms/test_classes_to_indicesd.py +110 -0
  485. tests/transforms/test_clip_intensity_percentiles.py +196 -0
  486. tests/transforms/test_clip_intensity_percentilesd.py +193 -0
  487. tests/transforms/test_compose_get_number_conversions.py +127 -0
  488. tests/transforms/test_concat_itemsd.py +82 -0
  489. tests/transforms/test_convert_to_multi_channel.py +59 -0
  490. tests/transforms/test_convert_to_multi_channeld.py +37 -0
  491. tests/transforms/test_copy_itemsd.py +86 -0
  492. tests/transforms/test_create_grid_and_affine.py +274 -0
  493. tests/transforms/test_crop_foreground.py +164 -0
  494. tests/transforms/test_crop_foregroundd.py +205 -0
  495. tests/transforms/test_cucim_dict_transform.py +142 -0
  496. tests/transforms/test_cucim_transform.py +141 -0
  497. tests/transforms/test_data_stats.py +221 -0
  498. tests/transforms/test_data_statsd.py +249 -0
  499. tests/transforms/test_delete_itemsd.py +58 -0
  500. tests/transforms/test_detect_envelope.py +159 -0
  501. tests/transforms/test_distance_transform_edt.py +202 -0
  502. tests/transforms/test_divisible_pad.py +49 -0
  503. tests/transforms/test_divisible_padd.py +42 -0
  504. tests/transforms/test_ensure_channel_first.py +113 -0
  505. tests/transforms/test_ensure_channel_firstd.py +85 -0
  506. tests/transforms/test_ensure_type.py +94 -0
  507. tests/transforms/test_ensure_typed.py +110 -0
  508. tests/transforms/test_fg_bg_to_indices.py +83 -0
  509. tests/transforms/test_fg_bg_to_indicesd.py +78 -0
  510. tests/transforms/test_fill_holes.py +207 -0
  511. tests/transforms/test_fill_holesd.py +209 -0
  512. tests/transforms/test_flatten_sub_keysd.py +64 -0
  513. tests/transforms/test_flip.py +83 -0
  514. tests/transforms/test_flipd.py +90 -0
  515. tests/transforms/test_fourier.py +70 -0
  516. tests/transforms/test_gaussian_sharpen.py +92 -0
  517. tests/transforms/test_gaussian_sharpend.py +92 -0
  518. tests/transforms/test_gaussian_smooth.py +96 -0
  519. tests/transforms/test_gaussian_smoothd.py +96 -0
  520. tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
  521. tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
  522. tests/transforms/test_generate_spatial_bounding_box.py +114 -0
  523. tests/transforms/test_get_extreme_points.py +57 -0
  524. tests/transforms/test_gibbs_noise.py +73 -0
  525. tests/transforms/test_gibbs_noised.py +88 -0
  526. tests/transforms/test_grid_distortion.py +113 -0
  527. tests/transforms/test_grid_distortiond.py +87 -0
  528. tests/transforms/test_grid_split.py +88 -0
  529. tests/transforms/test_grid_splitd.py +96 -0
  530. tests/transforms/test_histogram_normalize.py +59 -0
  531. tests/transforms/test_histogram_normalized.py +59 -0
  532. tests/transforms/test_image_filter.py +259 -0
  533. tests/transforms/test_intensity_stats.py +73 -0
  534. tests/transforms/test_intensity_statsd.py +90 -0
  535. tests/transforms/test_inverse.py +521 -0
  536. tests/transforms/test_inverse_collation.py +147 -0
  537. tests/transforms/test_invert.py +105 -0
  538. tests/transforms/test_invertd.py +142 -0
  539. tests/transforms/test_k_space_spike_noise.py +81 -0
  540. tests/transforms/test_k_space_spike_noised.py +98 -0
  541. tests/transforms/test_keep_largest_connected_component.py +419 -0
  542. tests/transforms/test_keep_largest_connected_componentd.py +348 -0
  543. tests/transforms/test_label_filter.py +78 -0
  544. tests/transforms/test_label_to_contour.py +179 -0
  545. tests/transforms/test_label_to_contourd.py +182 -0
  546. tests/transforms/test_label_to_mask.py +69 -0
  547. tests/transforms/test_label_to_maskd.py +70 -0
  548. tests/transforms/test_load_image.py +502 -0
  549. tests/transforms/test_load_imaged.py +198 -0
  550. tests/transforms/test_load_spacing_orientation.py +149 -0
  551. tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
  552. tests/transforms/test_map_binary_to_indices.py +75 -0
  553. tests/transforms/test_map_classes_to_indices.py +135 -0
  554. tests/transforms/test_map_label_value.py +89 -0
  555. tests/transforms/test_map_label_valued.py +85 -0
  556. tests/transforms/test_map_transform.py +45 -0
  557. tests/transforms/test_mask_intensity.py +74 -0
  558. tests/transforms/test_mask_intensityd.py +68 -0
  559. tests/transforms/test_mean_ensemble.py +77 -0
  560. tests/transforms/test_mean_ensembled.py +91 -0
  561. tests/transforms/test_median_smooth.py +41 -0
  562. tests/transforms/test_median_smoothd.py +65 -0
  563. tests/transforms/test_morphological_ops.py +101 -0
  564. tests/transforms/test_nifti_endianness.py +107 -0
  565. tests/transforms/test_normalize_intensity.py +143 -0
  566. tests/transforms/test_normalize_intensityd.py +81 -0
  567. tests/transforms/test_nvtx_decorator.py +289 -0
  568. tests/transforms/test_nvtx_transform.py +143 -0
  569. tests/transforms/test_orientation.py +247 -0
  570. tests/transforms/test_orientationd.py +112 -0
  571. tests/transforms/test_rand_adjust_contrast.py +45 -0
  572. tests/transforms/test_rand_adjust_contrastd.py +44 -0
  573. tests/transforms/test_rand_affine.py +201 -0
  574. tests/transforms/test_rand_affine_grid.py +212 -0
  575. tests/transforms/test_rand_affined.py +281 -0
  576. tests/transforms/test_rand_axis_flip.py +50 -0
  577. tests/transforms/test_rand_axis_flipd.py +50 -0
  578. tests/transforms/test_rand_bias_field.py +69 -0
  579. tests/transforms/test_rand_bias_fieldd.py +65 -0
  580. tests/transforms/test_rand_coarse_dropout.py +110 -0
  581. tests/transforms/test_rand_coarse_dropoutd.py +107 -0
  582. tests/transforms/test_rand_coarse_shuffle.py +65 -0
  583. tests/transforms/test_rand_coarse_shuffled.py +59 -0
  584. tests/transforms/test_rand_crop_by_label_classes.py +170 -0
  585. tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
  586. tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
  587. tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
  588. tests/transforms/test_rand_cucim_dict_transform.py +162 -0
  589. tests/transforms/test_rand_cucim_transform.py +162 -0
  590. tests/transforms/test_rand_deform_grid.py +138 -0
  591. tests/transforms/test_rand_elastic_2d.py +127 -0
  592. tests/transforms/test_rand_elastic_3d.py +104 -0
  593. tests/transforms/test_rand_elasticd_2d.py +177 -0
  594. tests/transforms/test_rand_elasticd_3d.py +156 -0
  595. tests/transforms/test_rand_flip.py +60 -0
  596. tests/transforms/test_rand_flipd.py +55 -0
  597. tests/transforms/test_rand_gaussian_noise.py +48 -0
  598. tests/transforms/test_rand_gaussian_noised.py +54 -0
  599. tests/transforms/test_rand_gaussian_sharpen.py +140 -0
  600. tests/transforms/test_rand_gaussian_sharpend.py +143 -0
  601. tests/transforms/test_rand_gaussian_smooth.py +98 -0
  602. tests/transforms/test_rand_gaussian_smoothd.py +98 -0
  603. tests/transforms/test_rand_gibbs_noise.py +103 -0
  604. tests/transforms/test_rand_gibbs_noised.py +117 -0
  605. tests/transforms/test_rand_grid_distortion.py +99 -0
  606. tests/transforms/test_rand_grid_distortiond.py +90 -0
  607. tests/transforms/test_rand_histogram_shift.py +92 -0
  608. tests/transforms/test_rand_k_space_spike_noise.py +92 -0
  609. tests/transforms/test_rand_k_space_spike_noised.py +76 -0
  610. tests/transforms/test_rand_rician_noise.py +52 -0
  611. tests/transforms/test_rand_rician_noised.py +52 -0
  612. tests/transforms/test_rand_rotate.py +166 -0
  613. tests/transforms/test_rand_rotate90.py +100 -0
  614. tests/transforms/test_rand_rotate90d.py +112 -0
  615. tests/transforms/test_rand_rotated.py +187 -0
  616. tests/transforms/test_rand_scale_crop.py +78 -0
  617. tests/transforms/test_rand_scale_cropd.py +98 -0
  618. tests/transforms/test_rand_scale_intensity.py +54 -0
  619. tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
  620. tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
  621. tests/transforms/test_rand_scale_intensityd.py +53 -0
  622. tests/transforms/test_rand_shift_intensity.py +52 -0
  623. tests/transforms/test_rand_shift_intensityd.py +67 -0
  624. tests/transforms/test_rand_simulate_low_resolution.py +83 -0
  625. tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
  626. tests/transforms/test_rand_spatial_crop.py +107 -0
  627. tests/transforms/test_rand_spatial_crop_samples.py +128 -0
  628. tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
  629. tests/transforms/test_rand_spatial_cropd.py +112 -0
  630. tests/transforms/test_rand_std_shift_intensity.py +43 -0
  631. tests/transforms/test_rand_std_shift_intensityd.py +38 -0
  632. tests/transforms/test_rand_zoom.py +105 -0
  633. tests/transforms/test_rand_zoomd.py +108 -0
  634. tests/transforms/test_randidentity.py +49 -0
  635. tests/transforms/test_random_order.py +144 -0
  636. tests/transforms/test_randtorchvisiond.py +65 -0
  637. tests/transforms/test_regularization.py +139 -0
  638. tests/transforms/test_remove_repeated_channel.py +34 -0
  639. tests/transforms/test_remove_repeated_channeld.py +44 -0
  640. tests/transforms/test_repeat_channel.py +34 -0
  641. tests/transforms/test_repeat_channeld.py +41 -0
  642. tests/transforms/test_resample_backends.py +65 -0
  643. tests/transforms/test_resample_to_match.py +110 -0
  644. tests/transforms/test_resample_to_matchd.py +93 -0
  645. tests/transforms/test_resampler.py +165 -0
  646. tests/transforms/test_resize.py +140 -0
  647. tests/transforms/test_resize_with_pad_or_crop.py +91 -0
  648. tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
  649. tests/transforms/test_resized.py +163 -0
  650. tests/transforms/test_rotate.py +160 -0
  651. tests/transforms/test_rotate90.py +212 -0
  652. tests/transforms/test_rotate90d.py +106 -0
  653. tests/transforms/test_rotated.py +179 -0
  654. tests/transforms/test_save_classificationd.py +109 -0
  655. tests/transforms/test_save_image.py +80 -0
  656. tests/transforms/test_save_imaged.py +130 -0
  657. tests/transforms/test_savitzky_golay_smooth.py +73 -0
  658. tests/transforms/test_savitzky_golay_smoothd.py +73 -0
  659. tests/transforms/test_scale_intensity.py +76 -0
  660. tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
  661. tests/transforms/test_scale_intensity_range.py +41 -0
  662. tests/transforms/test_scale_intensity_ranged.py +40 -0
  663. tests/transforms/test_scale_intensityd.py +57 -0
  664. tests/transforms/test_select_itemsd.py +41 -0
  665. tests/transforms/test_shift_intensity.py +31 -0
  666. tests/transforms/test_shift_intensityd.py +44 -0
  667. tests/transforms/test_signal_continuouswavelet.py +44 -0
  668. tests/transforms/test_signal_fillempty.py +52 -0
  669. tests/transforms/test_signal_fillemptyd.py +60 -0
  670. tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
  671. tests/transforms/test_signal_rand_add_sine.py +52 -0
  672. tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
  673. tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
  674. tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
  675. tests/transforms/test_signal_rand_drop.py +50 -0
  676. tests/transforms/test_signal_rand_scale.py +52 -0
  677. tests/transforms/test_signal_rand_shift.py +55 -0
  678. tests/transforms/test_signal_remove_frequency.py +71 -0
  679. tests/transforms/test_smooth_field.py +177 -0
  680. tests/transforms/test_sobel_gradient.py +189 -0
  681. tests/transforms/test_sobel_gradientd.py +212 -0
  682. tests/transforms/test_spacing.py +381 -0
  683. tests/transforms/test_spacingd.py +178 -0
  684. tests/transforms/test_spatial_crop.py +82 -0
  685. tests/transforms/test_spatial_cropd.py +74 -0
  686. tests/transforms/test_spatial_pad.py +57 -0
  687. tests/transforms/test_spatial_padd.py +43 -0
  688. tests/transforms/test_spatial_resample.py +235 -0
  689. tests/transforms/test_squeezedim.py +62 -0
  690. tests/transforms/test_squeezedimd.py +98 -0
  691. tests/transforms/test_std_shift_intensity.py +76 -0
  692. tests/transforms/test_std_shift_intensityd.py +74 -0
  693. tests/transforms/test_threshold_intensity.py +38 -0
  694. tests/transforms/test_threshold_intensityd.py +58 -0
  695. tests/transforms/test_to_contiguous.py +47 -0
  696. tests/transforms/test_to_cupy.py +112 -0
  697. tests/transforms/test_to_cupyd.py +76 -0
  698. tests/transforms/test_to_device.py +42 -0
  699. tests/transforms/test_to_deviced.py +37 -0
  700. tests/transforms/test_to_numpy.py +85 -0
  701. tests/transforms/test_to_numpyd.py +68 -0
  702. tests/transforms/test_to_pil.py +52 -0
  703. tests/transforms/test_to_pild.py +50 -0
  704. tests/transforms/test_to_tensor.py +60 -0
  705. tests/transforms/test_to_tensord.py +71 -0
  706. tests/transforms/test_torchvision.py +66 -0
  707. tests/transforms/test_torchvisiond.py +63 -0
  708. tests/transforms/test_transform.py +62 -0
  709. tests/transforms/test_transpose.py +41 -0
  710. tests/transforms/test_transposed.py +52 -0
  711. tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
  712. tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
  713. tests/transforms/test_vote_ensemble.py +84 -0
  714. tests/transforms/test_vote_ensembled.py +107 -0
  715. tests/transforms/test_with_allow_missing_keys.py +76 -0
  716. tests/transforms/test_zoom.py +120 -0
  717. tests/transforms/test_zoomd.py +94 -0
  718. tests/transforms/transform/__init__.py +10 -0
  719. tests/transforms/transform/test_randomizable.py +52 -0
  720. tests/transforms/transform/test_randomizable_transform_type.py +37 -0
  721. tests/transforms/utility/__init__.py +10 -0
  722. tests/transforms/utility/test_apply_transform_to_points.py +81 -0
  723. tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
  724. tests/transforms/utility/test_identity.py +29 -0
  725. tests/transforms/utility/test_identityd.py +30 -0
  726. tests/transforms/utility/test_lambda.py +71 -0
  727. tests/transforms/utility/test_lambdad.py +83 -0
  728. tests/transforms/utility/test_rand_lambda.py +87 -0
  729. tests/transforms/utility/test_rand_lambdad.py +77 -0
  730. tests/transforms/utility/test_simulatedelay.py +36 -0
  731. tests/transforms/utility/test_simulatedelayd.py +36 -0
  732. tests/transforms/utility/test_splitdim.py +52 -0
  733. tests/transforms/utility/test_splitdimd.py +96 -0
  734. tests/transforms/utils/__init__.py +10 -0
  735. tests/transforms/utils/test_correct_crop_centers.py +36 -0
  736. tests/transforms/utils/test_get_unique_labels.py +45 -0
  737. tests/transforms/utils/test_print_transform_backends.py +29 -0
  738. tests/transforms/utils/test_soft_clip.py +125 -0
  739. tests/utils/__init__.py +10 -0
  740. tests/utils/enums/__init__.py +10 -0
  741. tests/utils/enums/test_hovernet_loss.py +190 -0
  742. tests/utils/enums/test_ordering.py +289 -0
  743. tests/utils/enums/test_wsireader.py +663 -0
  744. tests/utils/misc/__init__.py +10 -0
  745. tests/utils/misc/test_ensure_tuple.py +53 -0
  746. tests/utils/misc/test_monai_env_vars.py +44 -0
  747. tests/utils/misc/test_monai_utils_misc.py +103 -0
  748. tests/utils/misc/test_str2bool.py +34 -0
  749. tests/utils/misc/test_str2list.py +33 -0
  750. tests/utils/test_alias.py +44 -0
  751. tests/utils/test_component_store.py +73 -0
  752. tests/utils/test_deprecated.py +455 -0
  753. tests/utils/test_enum_bound_interp.py +75 -0
  754. tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
  755. tests/utils/test_get_package_version.py +34 -0
  756. tests/utils/test_handler_logfile.py +84 -0
  757. tests/utils/test_handler_metric_logger.py +62 -0
  758. tests/utils/test_list_to_dict.py +43 -0
  759. tests/utils/test_look_up_option.py +87 -0
  760. tests/utils/test_optional_import.py +80 -0
  761. tests/utils/test_pad_mode.py +39 -0
  762. tests/utils/test_profiling.py +208 -0
  763. tests/utils/test_rankfilter_dist.py +77 -0
  764. tests/utils/test_require_pkg.py +83 -0
  765. tests/utils/test_sample_slices.py +43 -0
  766. tests/utils/test_set_determinism.py +74 -0
  767. tests/utils/test_squeeze_unsqueeze.py +71 -0
  768. tests/utils/test_state_cacher.py +67 -0
  769. tests/utils/test_torchscript_utils.py +113 -0
  770. tests/utils/test_version.py +91 -0
  771. tests/utils/test_version_after.py +65 -0
  772. tests/utils/type_conversion/__init__.py +10 -0
  773. tests/utils/type_conversion/test_convert_data_type.py +152 -0
  774. tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
  775. tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
  776. tests/visualize/__init__.py +10 -0
  777. tests/visualize/test_img2tensorboard.py +46 -0
  778. tests/visualize/test_occlusion_sensitivity.py +128 -0
  779. tests/visualize/test_plot_2d_or_3d_image.py +74 -0
  780. tests/visualize/test_vis_cam.py +98 -0
  781. tests/visualize/test_vis_gradcam.py +211 -0
  782. tests/visualize/utils/__init__.py +10 -0
  783. tests/visualize/utils/test_blend_images.py +63 -0
  784. tests/visualize/utils/test_matshow3d.py +133 -0
  785. monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
  786. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/LICENSE +0 -0
  787. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/WHEEL +0 -0
monai/__init__.py CHANGED
@@ -136,4 +136,4 @@ except BaseException:
136
136
 
137
137
  if MONAIEnvVars.debug():
138
138
  raise
139
- __commit_id__ = "8dcb9dc0e594059d87ce9882c20fe5a59340a6b2"
139
+ __commit_id__ = "a7905909e785d1ef24103c32a2d3a5a36e1059a2"
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2025-02-09T02:26:31+0000",
11
+ "date": "2025-02-23T02:28:09+0000",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "5d0616a0fd56273619eb40ed1e6a22683d697f60",
15
- "version": "1.5.dev2506"
14
+ "full-revisionid": "e55b5cbfbbba1800a968a9c06b2deaaa5c9bec54",
15
+ "version": "1.5.dev2508"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -18,7 +18,6 @@ import numpy as np
18
18
  import torch
19
19
 
20
20
  from monai.config import KeysCollection
21
- from monai.networks.utils import pytorch_after
22
21
  from monai.transforms import MapTransform
23
22
  from monai.utils.misc import ImageMetaKey
24
23
 
@@ -74,9 +73,7 @@ class EnsureSameShaped(MapTransform):
74
73
  f", the metadata was not updated {filename}."
75
74
  )
76
75
  d[key] = torch.nn.functional.interpolate(
77
- input=d[key].unsqueeze(0),
78
- size=image_shape,
79
- mode="nearest-exact" if pytorch_after(1, 11) else "nearest",
76
+ input=d[key].unsqueeze(0), size=image_shape, mode="nearest-exact"
80
77
  ).squeeze(0)
81
78
  else:
82
79
  raise ValueError(
monai/data/utils.py CHANGED
@@ -50,7 +50,6 @@ from monai.utils import (
50
50
  issequenceiterable,
51
51
  look_up_option,
52
52
  optional_import,
53
- pytorch_after,
54
53
  )
55
54
 
56
55
  pd, _ = optional_import("pandas")
@@ -450,12 +449,9 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
450
449
  Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor`
451
450
  and so should not be used as a collate function directly in dataloaders.
452
451
  """
453
- if pytorch_after(1, 13):
454
- from torch.utils.data._utils.collate import collate_tensor_fn # imported here for pylint/mypy issues
452
+ from torch.utils.data._utils.collate import collate_tensor_fn # imported here for pylint/mypy issues
455
453
 
456
- collated = collate_tensor_fn(batch)
457
- else:
458
- collated = default_collate(batch)
454
+ collated = collate_tensor_fn(batch)
459
455
 
460
456
  meta_dicts = [i.meta or TraceKeys.NONE for i in batch]
461
457
  common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])
@@ -494,18 +490,15 @@ def list_data_collate(batch: Sequence):
494
490
  Need to use this collate if apply some transforms that can generate batch data.
495
491
 
496
492
  """
493
+ from torch.utils.data._utils.collate import default_collate_fn_map
497
494
 
498
- if pytorch_after(1, 13):
499
- # needs to go here to avoid circular import
500
- from torch.utils.data._utils.collate import default_collate_fn_map
501
-
502
- from monai.data.meta_tensor import MetaTensor
495
+ from monai.data.meta_tensor import MetaTensor
503
496
 
504
- default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn})
497
+ default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn})
505
498
  elem = batch[0]
506
499
  data = [i for k in batch for i in k] if isinstance(elem, list) else batch
507
500
  key = None
508
- collate_fn = default_collate if pytorch_after(1, 13) else collate_meta_tensor
501
+ collate_fn = default_collate
509
502
  try:
510
503
  if config.USE_META_DICT:
511
504
  data = pickle_operations(data) # bc 0.9.0
@@ -11,6 +11,7 @@
11
11
 
12
12
  from __future__ import annotations
13
13
 
14
+ from .average_precision import AveragePrecision
14
15
  from .checkpoint_loader import CheckpointLoader
15
16
  from .checkpoint_saver import CheckpointSaver
16
17
  from .classification_saver import ClassificationSaver
@@ -0,0 +1,53 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ from collections.abc import Callable
15
+
16
+ from monai.handlers.ignite_metric import IgniteMetricHandler
17
+ from monai.metrics import AveragePrecisionMetric
18
+ from monai.utils import Average
19
+
20
+
21
+ class AveragePrecision(IgniteMetricHandler):
22
+ """
23
+ Computes Average Precision (AP).
24
+ accumulating predictions and the ground-truth during an epoch and applying `compute_average_precision`.
25
+
26
+ Args:
27
+ average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
28
+ Type of averaging performed if not binary classification. Defaults to ``"macro"``.
29
+
30
+ - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
31
+ This does not take label imbalance into account.
32
+ - ``"weighted"``: calculate metrics for each label, and find their average,
33
+ weighted by support (the number of true instances for each label).
34
+ - ``"micro"``: calculate metrics globally by considering each element of the label
35
+ indicator matrix as a label.
36
+ - ``"none"``: the scores for each class are returned.
37
+
38
+ output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
39
+ construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
40
+ lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
41
+ `engine.state` and `output_transform` inherit from the ignite concept:
42
+ https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
43
+ https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
44
+
45
+ Note:
46
+ Average Precision expects y to be comprised of 0's and 1's.
47
+ y_pred must either be probability estimates or confidence values.
48
+
49
+ """
50
+
51
+ def __init__(self, average: Average | str = Average.MACRO, output_transform: Callable = lambda x: x) -> None:
52
+ metric_fn = AveragePrecisionMetric(average=Average(average))
53
+ super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)
monai/inferers/inferer.py CHANGED
@@ -1202,15 +1202,16 @@ class LatentDiffusionInferer(DiffusionInferer):
1202
1202
 
1203
1203
  if self.autoencoder_latent_shape is not None:
1204
1204
  latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
1205
- latent_intermediates = [
1206
- torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
1207
- ]
1205
+ if save_intermediates:
1206
+ latent_intermediates = [
1207
+ torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
1208
+ for l in latent_intermediates
1209
+ ]
1208
1210
 
1209
1211
  decode = autoencoder_model.decode_stage_2_outputs
1210
1212
  if isinstance(autoencoder_model, SPADEAutoencoderKL):
1211
1213
  decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
1212
1214
  image = decode(latent / self.scale_factor)
1213
-
1214
1215
  if save_intermediates:
1215
1216
  intermediates = []
1216
1217
  for latent_intermediate in latent_intermediates:
@@ -1727,9 +1728,11 @@ class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer):
1727
1728
 
1728
1729
  if self.autoencoder_latent_shape is not None:
1729
1730
  latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
1730
- latent_intermediates = [
1731
- torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
1732
- ]
1731
+ if save_intermediates:
1732
+ latent_intermediates = [
1733
+ torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
1734
+ for l in latent_intermediates
1735
+ ]
1733
1736
 
1734
1737
  decode = autoencoder_model.decode_stage_2_outputs
1735
1738
  if isinstance(autoencoder_model, SPADEAutoencoderKL):
monai/inferers/utils.py CHANGED
@@ -31,11 +31,10 @@ from monai.utils import (
31
31
  fall_back_tuple,
32
32
  look_up_option,
33
33
  optional_import,
34
- pytorch_after,
35
34
  )
36
35
 
37
36
  tqdm, _ = optional_import("tqdm", name="tqdm")
38
- _nearest_mode = "nearest-exact" if pytorch_after(1, 11) else "nearest"
37
+ _nearest_mode = "nearest-exact"
39
38
 
40
39
  __all__ = ["sliding_window_inference"]
41
40
 
monai/losses/dice.py CHANGED
@@ -25,7 +25,7 @@ from monai.losses.focal_loss import FocalLoss
25
25
  from monai.losses.spatial_mask import MaskedLoss
26
26
  from monai.losses.utils import compute_tp_fp_fn
27
27
  from monai.networks import one_hot
28
- from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after
28
+ from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option
29
29
 
30
30
 
31
31
  class DiceLoss(_Loss):
@@ -738,12 +738,7 @@ class DiceCELoss(_Loss):
738
738
  batch=batch,
739
739
  weight=dice_weight,
740
740
  )
741
- if pytorch_after(1, 10):
742
- self.cross_entropy = nn.CrossEntropyLoss(
743
- weight=weight, reduction=reduction, label_smoothing=label_smoothing
744
- )
745
- else:
746
- self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
741
+ self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction, label_smoothing=label_smoothing)
747
742
  self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction)
748
743
  if lambda_dice < 0.0:
749
744
  raise ValueError("lambda_dice should be no less than 0.0.")
@@ -751,7 +746,6 @@ class DiceCELoss(_Loss):
751
746
  raise ValueError("lambda_ce should be no less than 0.0.")
752
747
  self.lambda_dice = lambda_dice
753
748
  self.lambda_ce = lambda_ce
754
- self.old_pt_ver = not pytorch_after(1, 10)
755
749
 
756
750
  def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
757
751
  """
@@ -764,12 +758,6 @@ class DiceCELoss(_Loss):
764
758
  if n_pred_ch != n_target_ch and n_target_ch == 1:
765
759
  target = torch.squeeze(target, dim=1)
766
760
  target = target.long()
767
- elif self.old_pt_ver:
768
- warnings.warn(
769
- f"Multichannel targets are not supported in this older Pytorch version {torch.__version__}. "
770
- "Using argmax (as a workaround) to convert target to a single channel."
771
- )
772
- target = torch.argmax(target, dim=1)
773
761
  elif not torch.is_floating_point(target):
774
762
  target = target.to(dtype=input.dtype)
775
763
 
monai/losses/ds_loss.py CHANGED
@@ -17,8 +17,6 @@ import torch
17
17
  import torch.nn.functional as F
18
18
  from torch.nn.modules.loss import _Loss
19
19
 
20
- from monai.utils import pytorch_after
21
-
22
20
 
23
21
  class DeepSupervisionLoss(_Loss):
24
22
  """
@@ -42,7 +40,7 @@ class DeepSupervisionLoss(_Loss):
42
40
  self.loss = loss
43
41
  self.weight_mode = weight_mode
44
42
  self.weights = weights
45
- self.interp_mode = "nearest-exact" if pytorch_after(1, 11) else "nearest"
43
+ self.interp_mode = "nearest-exact"
46
44
 
47
45
  def get_weights(self, levels: int = 1) -> list[float]:
48
46
  """
monai/metrics/__init__.py CHANGED
@@ -12,6 +12,7 @@
12
12
  from __future__ import annotations
13
13
 
14
14
  from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
15
+ from .average_precision import AveragePrecisionMetric, compute_average_precision
15
16
  from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
16
17
  from .cumulative_average import CumulativeAverage
17
18
  from .f_beta_score import FBetaScore
@@ -0,0 +1,187 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import warnings
15
+ from typing import TYPE_CHECKING, cast
16
+
17
+ import numpy as np
18
+
19
+ if TYPE_CHECKING:
20
+ import numpy.typing as npt
21
+
22
+ import torch
23
+
24
+ from monai.utils import Average, look_up_option
25
+
26
+ from .metric import CumulativeIterationMetric
27
+
28
+
29
+ class AveragePrecisionMetric(CumulativeIterationMetric):
30
+ """
31
+ Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
32
+ imbalanced. It can take values between 0.0 and 1.0, 1.0 being the best possible score.
33
+ It summarizes a Precision-Recall curve as the weighted mean of precisions achieved at each
34
+ threshold, with the increase in recall from the previous threshold used as the weight:
35
+
36
+ .. math::
37
+ \\text{AP} = \\sum_n (R_n - R_{n-1}) P_n
38
+ :label: ap
39
+
40
+ where :math:`P_n` and :math:`R_n` are the precision and recall at the :math:`n^{th}` threshold.
41
+
42
+ Referring to: `sklearn.metrics.average_precision_score
43
+ <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.
44
+
45
+ The input `y_pred` and `y` can be a list of `channel-first` Tensor or a `batch-first` Tensor.
46
+
47
+ Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
48
+
49
+ Args:
50
+ average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
51
+ Type of averaging performed if not binary classification.
52
+ Defaults to ``"macro"``.
53
+
54
+ - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
55
+ This does not take label imbalance into account.
56
+ - ``"weighted"``: calculate metrics for each label, and find their average,
57
+ weighted by support (the number of true instances for each label).
58
+ - ``"micro"``: calculate metrics globally by considering each element of the label
59
+ indicator matrix as a label.
60
+ - ``"none"``: the scores for each class are returned.
61
+
62
+ """
63
+
64
+ def __init__(self, average: Average | str = Average.MACRO) -> None:
65
+ super().__init__()
66
+ self.average = average
67
+
68
+ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override]
69
+ return y_pred, y
70
+
71
+ def aggregate(self, average: Average | str | None = None) -> np.ndarray | float | npt.ArrayLike:
72
+ """
73
+ Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration,
74
+ This function reads the buffers and computes the Average Precision.
75
+
76
+ Args:
77
+ average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
78
+ Type of averaging performed if not binary classification. Defaults to `self.average`.
79
+
80
+ """
81
+ y_pred, y = self.get_buffer()
82
+ # compute final value and do metric reduction
83
+ if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
84
+ raise ValueError("y_pred and y must be PyTorch Tensor.")
85
+
86
+ return compute_average_precision(y_pred=y_pred, y=y, average=average or self.average)
87
+
88
+
89
+ def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float:
90
+ if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)):
91
+ raise AssertionError("y and y_pred must be 1 dimension data with same length.")
92
+ y_unique = y.unique()
93
+ if len(y_unique) == 1:
94
+ warnings.warn(f"y values can not be all {y_unique.item()}, skip AP computation and return `Nan`.")
95
+ return float("nan")
96
+ if not y_unique.equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)):
97
+ warnings.warn(f"y values must be 0 or 1, but in {y_unique.tolist()}, skip AP computation and return `Nan`.")
98
+ return float("nan")
99
+
100
+ n = len(y)
101
+ indices = y_pred.argsort(descending=True)
102
+ y = y[indices].cpu().numpy() # type: ignore[assignment]
103
+ y_pred = y_pred[indices].cpu().numpy() # type: ignore[assignment]
104
+ npos = ap = tmp_pos = 0.0
105
+
106
+ for i in range(n):
107
+ y_i = cast(float, y[i])
108
+ if i + 1 < n and y_pred[i] == y_pred[i + 1]:
109
+ tmp_pos += y_i
110
+ else:
111
+ tmp_pos += y_i
112
+ npos += tmp_pos
113
+ ap += tmp_pos * npos / (i + 1)
114
+ tmp_pos = 0
115
+
116
+ return ap / npos
117
+
118
+
119
+ def compute_average_precision(
120
+ y_pred: torch.Tensor, y: torch.Tensor, average: Average | str = Average.MACRO
121
+ ) -> np.ndarray | float | npt.ArrayLike:
122
+ """Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
123
+ imbalanced. It summarizes a Precision-Recall according to equation :eq:`ap`.
124
+ Referring to: `sklearn.metrics.average_precision_score
125
+ <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.
126
+
127
+ Args:
128
+ y_pred: input data to compute, typical classification model output.
129
+ the first dim must be batch, if multi-classes, it must be in One-Hot format.
130
+ for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
131
+ y: ground truth to compute AP metric, the first dim must be batch.
132
+ if multi-classes, it must be in One-Hot format.
133
+ for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
134
+ average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
135
+ Type of averaging performed if not binary classification.
136
+ Defaults to ``"macro"``.
137
+
138
+ - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
139
+ This does not take label imbalance into account.
140
+ - ``"weighted"``: calculate metrics for each label, and find their average,
141
+ weighted by support (the number of true instances for each label).
142
+ - ``"micro"``: calculate metrics globally by considering each element of the label
143
+ indicator matrix as a label.
144
+ - ``"none"``: the scores for each class are returned.
145
+
146
+ Raises:
147
+ ValueError: When ``y_pred`` dimension is not one of [1, 2].
148
+ ValueError: When ``y`` dimension is not one of [1, 2].
149
+ ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].
150
+
151
+ Note:
152
+ Average Precision expects y to be comprised of 0's and 1's. `y_pred` must be either prob. estimates or confidence values.
153
+
154
+ """
155
+ y_pred_ndim = y_pred.ndimension()
156
+ y_ndim = y.ndimension()
157
+ if y_pred_ndim not in (1, 2):
158
+ raise ValueError(
159
+ f"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}."
160
+ )
161
+ if y_ndim not in (1, 2):
162
+ raise ValueError(f"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.")
163
+ if y_pred_ndim == 2 and y_pred.shape[1] == 1:
164
+ y_pred = y_pred.squeeze(dim=-1)
165
+ y_pred_ndim = 1
166
+ if y_ndim == 2 and y.shape[1] == 1:
167
+ y = y.squeeze(dim=-1)
168
+
169
+ if y_pred_ndim == 1:
170
+ return _calculate(y_pred, y)
171
+
172
+ if y.shape != y_pred.shape:
173
+ raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.")
174
+
175
+ average = look_up_option(average, Average)
176
+ if average == Average.MICRO:
177
+ return _calculate(y_pred.flatten(), y.flatten())
178
+ y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
179
+ ap_values = [_calculate(y_pred_, y_) for y_pred_, y_ in zip(y_pred, y)]
180
+ if average == Average.NONE:
181
+ return ap_values
182
+ if average == Average.MACRO:
183
+ return np.mean(ap_values)
184
+ if average == Average.WEIGHTED:
185
+ weights = [sum(y_) for y_ in y]
186
+ return np.average(ap_values, weights=weights) # type: ignore[no-any-return]
187
+ raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].')
@@ -31,7 +31,6 @@ from monai.utils import (
31
31
  issequenceiterable,
32
32
  look_up_option,
33
33
  optional_import,
34
- pytorch_after,
35
34
  )
36
35
 
37
36
  _C, _ = optional_import("monai._C")
@@ -293,14 +292,7 @@ def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tenso
293
292
  x = x.view(1, kernel.shape[0], *spatials)
294
293
  conv = [F.conv1d, F.conv2d, F.conv3d][n_spatial - 1]
295
294
  if "padding" not in kwargs:
296
- if pytorch_after(1, 10):
297
- kwargs["padding"] = "same"
298
- else:
299
- # even-sized kernels are not supported
300
- kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
301
- elif kwargs["padding"] == "same" and not pytorch_after(1, 10):
302
- # even-sized kernels are not supported
303
- kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
295
+ kwargs["padding"] = "same"
304
296
 
305
297
  if "stride" not in kwargs:
306
298
  kwargs["stride"] = 1
@@ -372,11 +364,7 @@ class SavitzkyGolayFilter(nn.Module):
372
364
  a = idx ** torch.arange(order + 1, dtype=torch.float, device="cpu").reshape(-1, 1)
373
365
  y = torch.zeros(order + 1, dtype=torch.float, device="cpu")
374
366
  y[0] = 1.0
375
- return (
376
- torch.lstsq(y, a).solution.squeeze() # type: ignore
377
- if not pytorch_after(1, 11)
378
- else torch.linalg.lstsq(a, y).solution.squeeze()
379
- )
367
+ return torch.linalg.lstsq(a, y).solution.squeeze()
380
368
 
381
369
 
382
370
  class HilbertTransform(nn.Module):
monai/networks/utils.py CHANGED
@@ -31,7 +31,7 @@ import torch.nn as nn
31
31
  from monai.apps.utils import get_logger
32
32
  from monai.config import PathLike
33
33
  from monai.utils.misc import ensure_tuple, save_obj, set_determinism
34
- from monai.utils.module import look_up_option, optional_import, pytorch_after
34
+ from monai.utils.module import look_up_option, optional_import
35
35
  from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
36
36
 
37
37
  onnx, _ = optional_import("onnx")
@@ -676,15 +676,6 @@ def convert_to_onnx(
676
676
  torch_versioned_kwargs["verify"] = verify
677
677
  verify = False
678
678
  else:
679
- if not pytorch_after(1, 10):
680
- if "example_outputs" not in kwargs:
681
- # https://github.com/pytorch/pytorch/blob/release/1.9/torch/onnx/__init__.py#L182
682
- raise TypeError(
683
- "example_outputs is required in scripting mode before PyTorch 1.10."
684
- "Please provide example outputs or use trace mode to export onnx model."
685
- )
686
- torch_versioned_kwargs["example_outputs"] = kwargs["example_outputs"]
687
- del kwargs["example_outputs"]
688
679
  mode_to_export = torch.jit.script(model, **kwargs)
689
680
 
690
681
  if torch.is_tensor(inputs) or isinstance(inputs, dict):
@@ -746,8 +737,7 @@ def convert_to_onnx(
746
737
  # compare onnx/ort and PyTorch results
747
738
  for r1, r2 in zip(torch_out, onnx_out):
748
739
  if isinstance(r1, torch.Tensor):
749
- assert_fn = torch.testing.assert_close if pytorch_after(1, 11) else torch.testing.assert_allclose
750
- assert_fn(r1.cpu(), convert_to_tensor(r2, dtype=r1.dtype), rtol=rtol, atol=atol) # type: ignore
740
+ torch.testing.assert_close(r1.cpu(), convert_to_tensor(r2, dtype=r1.dtype), rtol=rtol, atol=atol) # type: ignore
751
741
 
752
742
  return onnx_model
753
743
 
@@ -817,8 +807,7 @@ def convert_to_torchscript(
817
807
  # compare TorchScript and PyTorch results
818
808
  for r1, r2 in zip(torch_out, torchscript_out):
819
809
  if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor):
820
- assert_fn = torch.testing.assert_close if pytorch_after(1, 11) else torch.testing.assert_allclose
821
- assert_fn(r1, r2, rtol=rtol, atol=atol) # type: ignore
810
+ torch.testing.assert_close(r1, r2, rtol=rtol, atol=atol) # type: ignore
822
811
 
823
812
  return script_module
824
813
 
@@ -1031,8 +1020,7 @@ def convert_to_trt(
1031
1020
  # compare TorchScript and PyTorch results
1032
1021
  for r1, r2 in zip(torch_out, trt_out):
1033
1022
  if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor):
1034
- assert_fn = torch.testing.assert_close if pytorch_after(1, 11) else torch.testing.assert_allclose
1035
- assert_fn(r1, r2, rtol=rtol, atol=atol) # type: ignore
1023
+ torch.testing.assert_close(r1, r2, rtol=rtol, atol=atol) # type: ignore
1036
1024
 
1037
1025
  return trt_model
1038
1026
 
@@ -47,7 +47,7 @@ __all__ = ["Compose", "OneOf", "RandomOrder", "SomeOf", "execute_compose"]
47
47
  def execute_compose(
48
48
  data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
49
49
  transforms: Sequence[Any],
50
- map_items: bool = True,
50
+ map_items: bool | int = True,
51
51
  unpack_items: bool = False,
52
52
  start: int = 0,
53
53
  end: int | None = None,
@@ -65,8 +65,13 @@ def execute_compose(
65
65
  Args:
66
66
  data: a tensor-like object to be transformed
67
67
  transforms: a sequence of transforms to be carried out
68
- map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
69
- defaults to `True`.
68
+ map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
69
+ it can behave as follows:
70
+ - Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
71
+ to the first level of items in `data`.
72
+ - If an integer is provided, it specifies the maximum level of nesting to which the transformation
73
+ should be recursively applied. This allows handling multi-sample transforms that are applied after
74
+ another multi-sample transform, while controlling how deep the mapping goes.
70
75
  unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
71
76
  defaults to `False`.
72
77
  start: the index of the first transform to be executed. If not set, this defaults to 0
@@ -205,8 +210,14 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
205
210
 
206
211
  Args:
207
212
  transforms: sequence of callables.
208
- map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
209
- defaults to `True`.
213
+ map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
214
+ it can behave as follows:
215
+
216
+ - Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
217
+ to the first level of items in `data`.
218
+ - If an integer is provided, it specifies the maximum level of nesting to which the transformation
219
+ should be recursively applied. This allows handling multi-sample transforms that are applied after
220
+ another multi-sample transform, while controlling how deep the mapping goes.
210
221
  unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
211
222
  defaults to `False`.
212
223
  log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
@@ -227,7 +238,7 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
227
238
  def __init__(
228
239
  self,
229
240
  transforms: Sequence[Callable] | Callable | None = None,
230
- map_items: bool = True,
241
+ map_items: bool | int = True,
231
242
  unpack_items: bool = False,
232
243
  log_stats: bool | str = False,
233
244
  lazy: bool | None = False,
@@ -238,9 +249,9 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
238
249
  if transforms is None:
239
250
  transforms = []
240
251
 
241
- if not isinstance(map_items, bool):
252
+ if not isinstance(map_items, (bool, int)):
242
253
  raise ValueError(
243
- f"Argument 'map_items' should be boolean. Got {type(map_items)}."
254
+ f"Argument 'map_items' should be boolean or int. Got {type(map_items)}."
244
255
  "Check brackets when passing a sequence of callables."
245
256
  )
246
257
 
@@ -391,8 +402,14 @@ class OneOf(Compose):
391
402
  transforms: sequence of callables.
392
403
  weights: probabilities corresponding to each callable in transforms.
393
404
  Probabilities are normalized to sum to one.
394
- map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
395
- defaults to `True`.
405
+ map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
406
+ it can behave as follows:
407
+
408
+ - Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
409
+ to the first level of items in `data`.
410
+ - If an integer is provided, it specifies the maximum level of nesting to which the transformation
411
+ should be recursively applied. This allows handling multi-sample transforms that are applied after
412
+ another multi-sample transform, while controlling how deep the mapping goes.
396
413
  unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
397
414
  defaults to `False`.
398
415
  log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
@@ -414,7 +431,7 @@ class OneOf(Compose):
414
431
  self,
415
432
  transforms: Sequence[Callable] | Callable | None = None,
416
433
  weights: Sequence[float] | float | None = None,
417
- map_items: bool = True,
434
+ map_items: bool | int = True,
418
435
  unpack_items: bool = False,
419
436
  log_stats: bool | str = False,
420
437
  lazy: bool | None = False,
@@ -56,7 +56,6 @@ from monai.utils import (
56
56
  ensure_tuple_rep,
57
57
  fall_back_tuple,
58
58
  look_up_option,
59
- pytorch_after,
60
59
  )
61
60
 
62
61
  __all__ = [
@@ -392,11 +391,7 @@ class Crop(InvertibleTransform, LazyTransform):
392
391
  roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device="cpu")
393
392
  roi_size_t = convert_to_tensor(data=roi_size, dtype=torch.int16, wrap_sequence=True, device="cpu")
394
393
  _zeros = torch.zeros_like(roi_center_t)
395
- half = (
396
- torch.divide(roi_size_t, 2, rounding_mode="floor")
397
- if pytorch_after(1, 8)
398
- else torch.floor_divide(roi_size_t, 2)
399
- )
394
+ half = torch.divide(roi_size_t, 2, rounding_mode="floor")
400
395
  roi_start_t = torch.maximum(roi_center_t - half, _zeros)
401
396
  roi_end_t = torch.maximum(roi_start_t + roi_size_t, roi_start_t)
402
397
  else: