monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2507__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (776):
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/transforms.py +1 -4
  4. monai/data/utils.py +6 -13
  5. monai/inferers/utils.py +1 -2
  6. monai/losses/dice.py +2 -14
  7. monai/losses/ds_loss.py +1 -3
  8. monai/networks/layers/simplelayers.py +2 -14
  9. monai/networks/utils.py +4 -16
  10. monai/transforms/compose.py +28 -11
  11. monai/transforms/croppad/array.py +1 -6
  12. monai/transforms/io/array.py +0 -1
  13. monai/transforms/transform.py +15 -6
  14. monai/transforms/utils.py +1 -2
  15. monai/utils/tf32.py +0 -10
  16. monai/visualize/class_activation_maps.py +5 -8
  17. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/METADATA +2 -2
  18. monai_weekly-1.5.dev2507.dist-info/RECORD +1181 -0
  19. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/top_level.txt +1 -0
  20. tests/apps/__init__.py +10 -0
  21. tests/apps/deepedit/__init__.py +10 -0
  22. tests/apps/deepedit/test_deepedit_transforms.py +314 -0
  23. tests/apps/deepgrow/__init__.py +10 -0
  24. tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
  25. tests/apps/deepgrow/transforms/__init__.py +10 -0
  26. tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
  27. tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
  28. tests/apps/detection/__init__.py +10 -0
  29. tests/apps/detection/metrics/__init__.py +10 -0
  30. tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
  31. tests/apps/detection/networks/__init__.py +10 -0
  32. tests/apps/detection/networks/test_retinanet.py +210 -0
  33. tests/apps/detection/networks/test_retinanet_detector.py +203 -0
  34. tests/apps/detection/test_box_transform.py +370 -0
  35. tests/apps/detection/utils/__init__.py +10 -0
  36. tests/apps/detection/utils/test_anchor_box.py +88 -0
  37. tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
  38. tests/apps/detection/utils/test_box_coder.py +43 -0
  39. tests/apps/detection/utils/test_detector_boxselector.py +67 -0
  40. tests/apps/detection/utils/test_detector_utils.py +96 -0
  41. tests/apps/detection/utils/test_hardnegsampler.py +54 -0
  42. tests/apps/nuclick/__init__.py +10 -0
  43. tests/apps/nuclick/test_nuclick_transforms.py +259 -0
  44. tests/apps/pathology/__init__.py +10 -0
  45. tests/apps/pathology/handlers/__init__.py +10 -0
  46. tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
  47. tests/apps/pathology/test_lesion_froc.py +333 -0
  48. tests/apps/pathology/test_pathology_prob_nms.py +55 -0
  49. tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
  50. tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
  51. tests/apps/pathology/transforms/__init__.py +10 -0
  52. tests/apps/pathology/transforms/post/__init__.py +10 -0
  53. tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
  54. tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
  55. tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
  56. tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
  57. tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
  58. tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
  59. tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
  60. tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
  61. tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
  62. tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
  63. tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
  64. tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
  65. tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
  66. tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
  67. tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
  68. tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
  69. tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
  70. tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
  71. tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
  72. tests/apps/pathology/transforms/post/test_watershed.py +60 -0
  73. tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
  74. tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
  75. tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
  76. tests/apps/reconstruction/__init__.py +10 -0
  77. tests/apps/reconstruction/nets/__init__.py +10 -0
  78. tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
  79. tests/apps/reconstruction/test_complex_utils.py +77 -0
  80. tests/apps/reconstruction/test_fastmri_reader.py +82 -0
  81. tests/apps/reconstruction/test_mri_utils.py +37 -0
  82. tests/apps/reconstruction/transforms/__init__.py +10 -0
  83. tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
  84. tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
  85. tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
  86. tests/apps/test_auto3dseg_bundlegen.py +156 -0
  87. tests/apps/test_check_hash.py +53 -0
  88. tests/apps/test_cross_validation.py +74 -0
  89. tests/apps/test_decathlondataset.py +93 -0
  90. tests/apps/test_download_and_extract.py +70 -0
  91. tests/apps/test_download_url_yandex.py +45 -0
  92. tests/apps/test_mednistdataset.py +72 -0
  93. tests/apps/test_mmar_download.py +154 -0
  94. tests/apps/test_tciadataset.py +123 -0
  95. tests/apps/vista3d/__init__.py +10 -0
  96. tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
  97. tests/apps/vista3d/test_vista3d_sampler.py +100 -0
  98. tests/apps/vista3d/test_vista3d_transforms.py +94 -0
  99. tests/bundle/__init__.py +10 -0
  100. tests/bundle/test_bundle_ckpt_export.py +107 -0
  101. tests/bundle/test_bundle_download.py +435 -0
  102. tests/bundle/test_bundle_get_data.py +94 -0
  103. tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
  104. tests/bundle/test_bundle_trt_export.py +147 -0
  105. tests/bundle/test_bundle_utils.py +149 -0
  106. tests/bundle/test_bundle_verify_metadata.py +66 -0
  107. tests/bundle/test_bundle_verify_net.py +76 -0
  108. tests/bundle/test_bundle_workflow.py +272 -0
  109. tests/bundle/test_component_locator.py +38 -0
  110. tests/bundle/test_config_item.py +138 -0
  111. tests/bundle/test_config_parser.py +392 -0
  112. tests/bundle/test_reference_resolver.py +114 -0
  113. tests/config/__init__.py +10 -0
  114. tests/config/test_cv2_dist.py +53 -0
  115. tests/engines/__init__.py +10 -0
  116. tests/engines/test_ensemble_evaluator.py +94 -0
  117. tests/engines/test_prepare_batch_default.py +76 -0
  118. tests/engines/test_prepare_batch_default_dist.py +76 -0
  119. tests/engines/test_prepare_batch_diffusion.py +104 -0
  120. tests/engines/test_prepare_batch_extra_input.py +80 -0
  121. tests/fl/__init__.py +10 -0
  122. tests/fl/monai_algo/__init__.py +10 -0
  123. tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
  124. tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
  125. tests/fl/test_fl_monai_algo_stats.py +81 -0
  126. tests/fl/utils/__init__.py +10 -0
  127. tests/fl/utils/test_fl_exchange_object.py +63 -0
  128. tests/handlers/__init__.py +10 -0
  129. tests/handlers/test_handler_checkpoint_loader.py +182 -0
  130. tests/handlers/test_handler_checkpoint_saver.py +233 -0
  131. tests/handlers/test_handler_classification_saver.py +64 -0
  132. tests/handlers/test_handler_classification_saver_dist.py +77 -0
  133. tests/handlers/test_handler_clearml_image.py +65 -0
  134. tests/handlers/test_handler_clearml_stats.py +65 -0
  135. tests/handlers/test_handler_confusion_matrix.py +104 -0
  136. tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
  137. tests/handlers/test_handler_decollate_batch.py +66 -0
  138. tests/handlers/test_handler_early_stop.py +68 -0
  139. tests/handlers/test_handler_garbage_collector.py +73 -0
  140. tests/handlers/test_handler_hausdorff_distance.py +111 -0
  141. tests/handlers/test_handler_ignite_metric.py +191 -0
  142. tests/handlers/test_handler_lr_scheduler.py +94 -0
  143. tests/handlers/test_handler_mean_dice.py +98 -0
  144. tests/handlers/test_handler_mean_iou.py +76 -0
  145. tests/handlers/test_handler_metrics_reloaded.py +149 -0
  146. tests/handlers/test_handler_metrics_saver.py +89 -0
  147. tests/handlers/test_handler_metrics_saver_dist.py +120 -0
  148. tests/handlers/test_handler_mlflow.py +296 -0
  149. tests/handlers/test_handler_nvtx.py +93 -0
  150. tests/handlers/test_handler_panoptic_quality.py +89 -0
  151. tests/handlers/test_handler_parameter_scheduler.py +136 -0
  152. tests/handlers/test_handler_post_processing.py +74 -0
  153. tests/handlers/test_handler_prob_map_producer.py +111 -0
  154. tests/handlers/test_handler_regression_metrics.py +160 -0
  155. tests/handlers/test_handler_regression_metrics_dist.py +245 -0
  156. tests/handlers/test_handler_rocauc.py +48 -0
  157. tests/handlers/test_handler_rocauc_dist.py +54 -0
  158. tests/handlers/test_handler_stats.py +281 -0
  159. tests/handlers/test_handler_surface_distance.py +113 -0
  160. tests/handlers/test_handler_tb_image.py +61 -0
  161. tests/handlers/test_handler_tb_stats.py +166 -0
  162. tests/handlers/test_handler_validation.py +59 -0
  163. tests/handlers/test_trt_compile.py +145 -0
  164. tests/handlers/test_write_metrics_reports.py +68 -0
  165. tests/inferers/__init__.py +10 -0
  166. tests/inferers/test_avg_merger.py +179 -0
  167. tests/inferers/test_controlnet_inferers.py +1310 -0
  168. tests/inferers/test_diffusion_inferer.py +236 -0
  169. tests/inferers/test_latent_diffusion_inferer.py +824 -0
  170. tests/inferers/test_patch_inferer.py +309 -0
  171. tests/inferers/test_saliency_inferer.py +55 -0
  172. tests/inferers/test_slice_inferer.py +57 -0
  173. tests/inferers/test_sliding_window_inference.py +377 -0
  174. tests/inferers/test_sliding_window_splitter.py +284 -0
  175. tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
  176. tests/inferers/test_zarr_avg_merger.py +326 -0
  177. tests/integration/__init__.py +10 -0
  178. tests/integration/test_auto3dseg_ensemble.py +211 -0
  179. tests/integration/test_auto3dseg_hpo.py +189 -0
  180. tests/integration/test_deepedit_interaction.py +122 -0
  181. tests/integration/test_downsample_block.py +50 -0
  182. tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
  183. tests/integration/test_integration_autorunner.py +201 -0
  184. tests/integration/test_integration_bundle_run.py +240 -0
  185. tests/integration/test_integration_classification_2d.py +282 -0
  186. tests/integration/test_integration_determinism.py +95 -0
  187. tests/integration/test_integration_fast_train.py +231 -0
  188. tests/integration/test_integration_gpu_customization.py +159 -0
  189. tests/integration/test_integration_lazy_samples.py +219 -0
  190. tests/integration/test_integration_nnunetv2_runner.py +96 -0
  191. tests/integration/test_integration_segmentation_3d.py +304 -0
  192. tests/integration/test_integration_sliding_window.py +100 -0
  193. tests/integration/test_integration_stn.py +133 -0
  194. tests/integration/test_integration_unet_2d.py +67 -0
  195. tests/integration/test_integration_workers.py +61 -0
  196. tests/integration/test_integration_workflows.py +365 -0
  197. tests/integration/test_integration_workflows_adversarial.py +173 -0
  198. tests/integration/test_integration_workflows_gan.py +158 -0
  199. tests/integration/test_loader_semaphore.py +48 -0
  200. tests/integration/test_mapping_filed.py +122 -0
  201. tests/integration/test_meta_affine.py +183 -0
  202. tests/integration/test_metatensor_integration.py +114 -0
  203. tests/integration/test_module_list.py +76 -0
  204. tests/integration/test_one_of.py +283 -0
  205. tests/integration/test_pad_collation.py +124 -0
  206. tests/integration/test_reg_loss_integration.py +107 -0
  207. tests/integration/test_retinanet_predict_utils.py +154 -0
  208. tests/integration/test_seg_loss_integration.py +159 -0
  209. tests/integration/test_spatial_combine_transforms.py +185 -0
  210. tests/integration/test_testtimeaugmentation.py +186 -0
  211. tests/integration/test_vis_gradbased.py +69 -0
  212. tests/integration/test_vista3d_utils.py +159 -0
  213. tests/losses/__init__.py +10 -0
  214. tests/losses/deform/__init__.py +10 -0
  215. tests/losses/deform/test_bending_energy.py +88 -0
  216. tests/losses/deform/test_diffusion_loss.py +117 -0
  217. tests/losses/image_dissimilarity/__init__.py +10 -0
  218. tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
  219. tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
  220. tests/losses/test_adversarial_loss.py +94 -0
  221. tests/losses/test_barlow_twins_loss.py +109 -0
  222. tests/losses/test_cldice_loss.py +51 -0
  223. tests/losses/test_contrastive_loss.py +86 -0
  224. tests/losses/test_dice_ce_loss.py +123 -0
  225. tests/losses/test_dice_focal_loss.py +124 -0
  226. tests/losses/test_dice_loss.py +227 -0
  227. tests/losses/test_ds_loss.py +189 -0
  228. tests/losses/test_focal_loss.py +379 -0
  229. tests/losses/test_generalized_dice_focal_loss.py +85 -0
  230. tests/losses/test_generalized_dice_loss.py +221 -0
  231. tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
  232. tests/losses/test_giou_loss.py +62 -0
  233. tests/losses/test_hausdorff_loss.py +264 -0
  234. tests/losses/test_masked_dice_loss.py +152 -0
  235. tests/losses/test_masked_loss.py +87 -0
  236. tests/losses/test_multi_scale.py +86 -0
  237. tests/losses/test_nacl_loss.py +167 -0
  238. tests/losses/test_perceptual_loss.py +122 -0
  239. tests/losses/test_spectral_loss.py +86 -0
  240. tests/losses/test_ssim_loss.py +59 -0
  241. tests/losses/test_sure_loss.py +72 -0
  242. tests/losses/test_tversky_loss.py +198 -0
  243. tests/losses/test_unified_focal_loss.py +66 -0
  244. tests/metrics/__init__.py +10 -0
  245. tests/metrics/test_compute_confusion_matrix.py +294 -0
  246. tests/metrics/test_compute_f_beta.py +80 -0
  247. tests/metrics/test_compute_fid_metric.py +40 -0
  248. tests/metrics/test_compute_froc.py +143 -0
  249. tests/metrics/test_compute_generalized_dice.py +240 -0
  250. tests/metrics/test_compute_meandice.py +306 -0
  251. tests/metrics/test_compute_meaniou.py +223 -0
  252. tests/metrics/test_compute_mmd_metric.py +56 -0
  253. tests/metrics/test_compute_multiscalessim_metric.py +83 -0
  254. tests/metrics/test_compute_panoptic_quality.py +113 -0
  255. tests/metrics/test_compute_regression_metrics.py +196 -0
  256. tests/metrics/test_compute_roc_auc.py +155 -0
  257. tests/metrics/test_compute_variance.py +147 -0
  258. tests/metrics/test_cumulative.py +63 -0
  259. tests/metrics/test_cumulative_average.py +74 -0
  260. tests/metrics/test_cumulative_average_dist.py +48 -0
  261. tests/metrics/test_hausdorff_distance.py +209 -0
  262. tests/metrics/test_label_quality_score.py +134 -0
  263. tests/metrics/test_loss_metric.py +57 -0
  264. tests/metrics/test_metrics_reloaded.py +96 -0
  265. tests/metrics/test_ssim_metric.py +78 -0
  266. tests/metrics/test_surface_dice.py +416 -0
  267. tests/metrics/test_surface_distance.py +186 -0
  268. tests/networks/__init__.py +10 -0
  269. tests/networks/blocks/__init__.py +10 -0
  270. tests/networks/blocks/dints_block/__init__.py +10 -0
  271. tests/networks/blocks/dints_block/test_acn_block.py +41 -0
  272. tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
  273. tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
  274. tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
  275. tests/networks/blocks/test_adn.py +86 -0
  276. tests/networks/blocks/test_convolutions.py +156 -0
  277. tests/networks/blocks/test_crf_cpu.py +513 -0
  278. tests/networks/blocks/test_crf_cuda.py +528 -0
  279. tests/networks/blocks/test_crossattention.py +185 -0
  280. tests/networks/blocks/test_denseblock.py +105 -0
  281. tests/networks/blocks/test_dynunet_block.py +116 -0
  282. tests/networks/blocks/test_fpn_block.py +88 -0
  283. tests/networks/blocks/test_localnet_block.py +121 -0
  284. tests/networks/blocks/test_mlp.py +78 -0
  285. tests/networks/blocks/test_patchembedding.py +212 -0
  286. tests/networks/blocks/test_regunet_block.py +103 -0
  287. tests/networks/blocks/test_se_block.py +85 -0
  288. tests/networks/blocks/test_se_blocks.py +78 -0
  289. tests/networks/blocks/test_segresnet_block.py +57 -0
  290. tests/networks/blocks/test_selfattention.py +232 -0
  291. tests/networks/blocks/test_simple_aspp.py +87 -0
  292. tests/networks/blocks/test_spatialattention.py +55 -0
  293. tests/networks/blocks/test_subpixel_upsample.py +87 -0
  294. tests/networks/blocks/test_text_encoding.py +49 -0
  295. tests/networks/blocks/test_transformerblock.py +90 -0
  296. tests/networks/blocks/test_unetr_block.py +158 -0
  297. tests/networks/blocks/test_upsample_block.py +134 -0
  298. tests/networks/blocks/warp/__init__.py +10 -0
  299. tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
  300. tests/networks/blocks/warp/test_warp.py +250 -0
  301. tests/networks/layers/__init__.py +10 -0
  302. tests/networks/layers/filtering/__init__.py +10 -0
  303. tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
  304. tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
  305. tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
  306. tests/networks/layers/filtering/test_phl_cpu.py +259 -0
  307. tests/networks/layers/filtering/test_phl_cuda.py +167 -0
  308. tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
  309. tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
  310. tests/networks/layers/test_affine_transform.py +385 -0
  311. tests/networks/layers/test_apply_filter.py +89 -0
  312. tests/networks/layers/test_channel_pad.py +51 -0
  313. tests/networks/layers/test_conjugate_gradient.py +56 -0
  314. tests/networks/layers/test_drop_path.py +46 -0
  315. tests/networks/layers/test_gaussian.py +317 -0
  316. tests/networks/layers/test_gaussian_filter.py +206 -0
  317. tests/networks/layers/test_get_layers.py +65 -0
  318. tests/networks/layers/test_gmm.py +314 -0
  319. tests/networks/layers/test_grid_pull.py +93 -0
  320. tests/networks/layers/test_hilbert_transform.py +131 -0
  321. tests/networks/layers/test_lltm.py +62 -0
  322. tests/networks/layers/test_median_filter.py +52 -0
  323. tests/networks/layers/test_polyval.py +55 -0
  324. tests/networks/layers/test_preset_filters.py +136 -0
  325. tests/networks/layers/test_savitzky_golay_filter.py +141 -0
  326. tests/networks/layers/test_separable_filter.py +87 -0
  327. tests/networks/layers/test_skip_connection.py +48 -0
  328. tests/networks/layers/test_vector_quantizer.py +89 -0
  329. tests/networks/layers/test_weight_init.py +50 -0
  330. tests/networks/nets/__init__.py +10 -0
  331. tests/networks/nets/dints/__init__.py +10 -0
  332. tests/networks/nets/dints/test_dints_cell.py +110 -0
  333. tests/networks/nets/dints/test_dints_mixop.py +84 -0
  334. tests/networks/nets/regunet/__init__.py +10 -0
  335. tests/networks/nets/regunet/test_localnet.py +86 -0
  336. tests/networks/nets/regunet/test_regunet.py +88 -0
  337. tests/networks/nets/test_ahnet.py +224 -0
  338. tests/networks/nets/test_attentionunet.py +88 -0
  339. tests/networks/nets/test_autoencoder.py +95 -0
  340. tests/networks/nets/test_autoencoderkl.py +337 -0
  341. tests/networks/nets/test_basic_unet.py +102 -0
  342. tests/networks/nets/test_basic_unetplusplus.py +109 -0
  343. tests/networks/nets/test_bundle_init_bundle.py +55 -0
  344. tests/networks/nets/test_cell_sam_wrapper.py +58 -0
  345. tests/networks/nets/test_controlnet.py +215 -0
  346. tests/networks/nets/test_daf3d.py +62 -0
  347. tests/networks/nets/test_densenet.py +121 -0
  348. tests/networks/nets/test_diffusion_model_unet.py +585 -0
  349. tests/networks/nets/test_dints_network.py +168 -0
  350. tests/networks/nets/test_discriminator.py +59 -0
  351. tests/networks/nets/test_dynunet.py +181 -0
  352. tests/networks/nets/test_efficientnet.py +400 -0
  353. tests/networks/nets/test_flexible_unet.py +341 -0
  354. tests/networks/nets/test_fullyconnectednet.py +69 -0
  355. tests/networks/nets/test_generator.py +59 -0
  356. tests/networks/nets/test_globalnet.py +103 -0
  357. tests/networks/nets/test_highresnet.py +67 -0
  358. tests/networks/nets/test_hovernet.py +218 -0
  359. tests/networks/nets/test_mednext.py +122 -0
  360. tests/networks/nets/test_milmodel.py +92 -0
  361. tests/networks/nets/test_net_adapter.py +68 -0
  362. tests/networks/nets/test_network_consistency.py +86 -0
  363. tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
  364. tests/networks/nets/test_quicknat.py +57 -0
  365. tests/networks/nets/test_resnet.py +340 -0
  366. tests/networks/nets/test_segresnet.py +120 -0
  367. tests/networks/nets/test_segresnet_ds.py +156 -0
  368. tests/networks/nets/test_senet.py +151 -0
  369. tests/networks/nets/test_spade_autoencoderkl.py +295 -0
  370. tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
  371. tests/networks/nets/test_spade_vaegan.py +140 -0
  372. tests/networks/nets/test_swin_unetr.py +139 -0
  373. tests/networks/nets/test_torchvision_fc_model.py +201 -0
  374. tests/networks/nets/test_transchex.py +84 -0
  375. tests/networks/nets/test_transformer.py +108 -0
  376. tests/networks/nets/test_unet.py +208 -0
  377. tests/networks/nets/test_unetr.py +137 -0
  378. tests/networks/nets/test_varautoencoder.py +127 -0
  379. tests/networks/nets/test_vista3d.py +84 -0
  380. tests/networks/nets/test_vit.py +139 -0
  381. tests/networks/nets/test_vitautoenc.py +112 -0
  382. tests/networks/nets/test_vnet.py +81 -0
  383. tests/networks/nets/test_voxelmorph.py +280 -0
  384. tests/networks/nets/test_vqvae.py +274 -0
  385. tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
  386. tests/networks/schedulers/__init__.py +10 -0
  387. tests/networks/schedulers/test_scheduler_ddim.py +83 -0
  388. tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
  389. tests/networks/schedulers/test_scheduler_pndm.py +108 -0
  390. tests/networks/test_bundle_onnx_export.py +71 -0
  391. tests/networks/test_convert_to_onnx.py +106 -0
  392. tests/networks/test_convert_to_torchscript.py +46 -0
  393. tests/networks/test_convert_to_trt.py +79 -0
  394. tests/networks/test_save_state.py +73 -0
  395. tests/networks/test_to_onehot.py +63 -0
  396. tests/networks/test_varnet.py +63 -0
  397. tests/networks/utils/__init__.py +10 -0
  398. tests/networks/utils/test_copy_model_state.py +187 -0
  399. tests/networks/utils/test_eval_mode.py +34 -0
  400. tests/networks/utils/test_freeze_layers.py +61 -0
  401. tests/networks/utils/test_replace_module.py +98 -0
  402. tests/networks/utils/test_train_mode.py +34 -0
  403. tests/optimizers/__init__.py +10 -0
  404. tests/optimizers/test_generate_param_groups.py +105 -0
  405. tests/optimizers/test_lr_finder.py +108 -0
  406. tests/optimizers/test_lr_scheduler.py +71 -0
  407. tests/optimizers/test_optim_novograd.py +100 -0
  408. tests/profile_subclass/__init__.py +10 -0
  409. tests/profile_subclass/cprofile_profiling.py +29 -0
  410. tests/profile_subclass/min_classes.py +30 -0
  411. tests/profile_subclass/profiling.py +73 -0
  412. tests/profile_subclass/pyspy_profiling.py +41 -0
  413. tests/transforms/__init__.py +10 -0
  414. tests/transforms/compose/__init__.py +10 -0
  415. tests/transforms/compose/test_compose.py +758 -0
  416. tests/transforms/compose/test_some_of.py +258 -0
  417. tests/transforms/croppad/__init__.py +10 -0
  418. tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
  419. tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
  420. tests/transforms/functional/__init__.py +10 -0
  421. tests/transforms/functional/test_apply.py +75 -0
  422. tests/transforms/functional/test_resample.py +50 -0
  423. tests/transforms/intensity/__init__.py +10 -0
  424. tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
  425. tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
  426. tests/transforms/intensity/test_foreground_mask.py +98 -0
  427. tests/transforms/intensity/test_foreground_maskd.py +106 -0
  428. tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
  429. tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
  430. tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
  431. tests/transforms/inverse/__init__.py +10 -0
  432. tests/transforms/inverse/test_inverse_array.py +76 -0
  433. tests/transforms/inverse/test_traceable_transform.py +59 -0
  434. tests/transforms/post/__init__.py +10 -0
  435. tests/transforms/post/test_label_filterd.py +78 -0
  436. tests/transforms/post/test_probnms.py +72 -0
  437. tests/transforms/post/test_probnmsd.py +79 -0
  438. tests/transforms/post/test_remove_small_objects.py +102 -0
  439. tests/transforms/spatial/__init__.py +10 -0
  440. tests/transforms/spatial/test_convert_box_points.py +119 -0
  441. tests/transforms/spatial/test_grid_patch.py +134 -0
  442. tests/transforms/spatial/test_grid_patchd.py +102 -0
  443. tests/transforms/spatial/test_rand_grid_patch.py +150 -0
  444. tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
  445. tests/transforms/spatial/test_spatial_resampled.py +124 -0
  446. tests/transforms/test_activations.py +120 -0
  447. tests/transforms/test_activationsd.py +64 -0
  448. tests/transforms/test_adaptors.py +160 -0
  449. tests/transforms/test_add_coordinate_channels.py +53 -0
  450. tests/transforms/test_add_coordinate_channelsd.py +67 -0
  451. tests/transforms/test_add_extreme_points_channel.py +80 -0
  452. tests/transforms/test_add_extreme_points_channeld.py +77 -0
  453. tests/transforms/test_adjust_contrast.py +70 -0
  454. tests/transforms/test_adjust_contrastd.py +64 -0
  455. tests/transforms/test_affine.py +245 -0
  456. tests/transforms/test_affine_grid.py +152 -0
  457. tests/transforms/test_affined.py +190 -0
  458. tests/transforms/test_as_channel_last.py +38 -0
  459. tests/transforms/test_as_channel_lastd.py +44 -0
  460. tests/transforms/test_as_discrete.py +81 -0
  461. tests/transforms/test_as_discreted.py +82 -0
  462. tests/transforms/test_border_pad.py +49 -0
  463. tests/transforms/test_border_padd.py +45 -0
  464. tests/transforms/test_bounding_rect.py +54 -0
  465. tests/transforms/test_bounding_rectd.py +53 -0
  466. tests/transforms/test_cast_to_type.py +63 -0
  467. tests/transforms/test_cast_to_typed.py +74 -0
  468. tests/transforms/test_center_scale_crop.py +55 -0
  469. tests/transforms/test_center_scale_cropd.py +56 -0
  470. tests/transforms/test_center_spatial_crop.py +56 -0
  471. tests/transforms/test_center_spatial_cropd.py +63 -0
  472. tests/transforms/test_classes_to_indices.py +93 -0
  473. tests/transforms/test_classes_to_indicesd.py +110 -0
  474. tests/transforms/test_clip_intensity_percentiles.py +196 -0
  475. tests/transforms/test_clip_intensity_percentilesd.py +193 -0
  476. tests/transforms/test_compose_get_number_conversions.py +127 -0
  477. tests/transforms/test_concat_itemsd.py +82 -0
  478. tests/transforms/test_convert_to_multi_channel.py +59 -0
  479. tests/transforms/test_convert_to_multi_channeld.py +37 -0
  480. tests/transforms/test_copy_itemsd.py +86 -0
  481. tests/transforms/test_create_grid_and_affine.py +274 -0
  482. tests/transforms/test_crop_foreground.py +164 -0
  483. tests/transforms/test_crop_foregroundd.py +205 -0
  484. tests/transforms/test_cucim_dict_transform.py +142 -0
  485. tests/transforms/test_cucim_transform.py +141 -0
  486. tests/transforms/test_data_stats.py +221 -0
  487. tests/transforms/test_data_statsd.py +249 -0
  488. tests/transforms/test_delete_itemsd.py +58 -0
  489. tests/transforms/test_detect_envelope.py +159 -0
  490. tests/transforms/test_distance_transform_edt.py +202 -0
  491. tests/transforms/test_divisible_pad.py +49 -0
  492. tests/transforms/test_divisible_padd.py +42 -0
  493. tests/transforms/test_ensure_channel_first.py +113 -0
  494. tests/transforms/test_ensure_channel_firstd.py +85 -0
  495. tests/transforms/test_ensure_type.py +94 -0
  496. tests/transforms/test_ensure_typed.py +110 -0
  497. tests/transforms/test_fg_bg_to_indices.py +83 -0
  498. tests/transforms/test_fg_bg_to_indicesd.py +78 -0
  499. tests/transforms/test_fill_holes.py +207 -0
  500. tests/transforms/test_fill_holesd.py +209 -0
  501. tests/transforms/test_flatten_sub_keysd.py +64 -0
  502. tests/transforms/test_flip.py +83 -0
  503. tests/transforms/test_flipd.py +90 -0
  504. tests/transforms/test_fourier.py +70 -0
  505. tests/transforms/test_gaussian_sharpen.py +92 -0
  506. tests/transforms/test_gaussian_sharpend.py +92 -0
  507. tests/transforms/test_gaussian_smooth.py +96 -0
  508. tests/transforms/test_gaussian_smoothd.py +96 -0
  509. tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
  510. tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
  511. tests/transforms/test_generate_spatial_bounding_box.py +114 -0
  512. tests/transforms/test_get_extreme_points.py +57 -0
  513. tests/transforms/test_gibbs_noise.py +75 -0
  514. tests/transforms/test_gibbs_noised.py +88 -0
  515. tests/transforms/test_grid_distortion.py +113 -0
  516. tests/transforms/test_grid_distortiond.py +87 -0
  517. tests/transforms/test_grid_split.py +88 -0
  518. tests/transforms/test_grid_splitd.py +96 -0
  519. tests/transforms/test_histogram_normalize.py +59 -0
  520. tests/transforms/test_histogram_normalized.py +59 -0
  521. tests/transforms/test_image_filter.py +259 -0
  522. tests/transforms/test_intensity_stats.py +73 -0
  523. tests/transforms/test_intensity_statsd.py +90 -0
  524. tests/transforms/test_inverse.py +521 -0
  525. tests/transforms/test_inverse_collation.py +147 -0
  526. tests/transforms/test_invert.py +105 -0
  527. tests/transforms/test_invertd.py +142 -0
  528. tests/transforms/test_k_space_spike_noise.py +81 -0
  529. tests/transforms/test_k_space_spike_noised.py +98 -0
  530. tests/transforms/test_keep_largest_connected_component.py +419 -0
  531. tests/transforms/test_keep_largest_connected_componentd.py +348 -0
  532. tests/transforms/test_label_filter.py +78 -0
  533. tests/transforms/test_label_to_contour.py +179 -0
  534. tests/transforms/test_label_to_contourd.py +182 -0
  535. tests/transforms/test_label_to_mask.py +69 -0
  536. tests/transforms/test_label_to_maskd.py +70 -0
  537. tests/transforms/test_load_image.py +502 -0
  538. tests/transforms/test_load_imaged.py +198 -0
  539. tests/transforms/test_load_spacing_orientation.py +149 -0
  540. tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
  541. tests/transforms/test_map_binary_to_indices.py +75 -0
  542. tests/transforms/test_map_classes_to_indices.py +135 -0
  543. tests/transforms/test_map_label_value.py +89 -0
  544. tests/transforms/test_map_label_valued.py +85 -0
  545. tests/transforms/test_map_transform.py +45 -0
  546. tests/transforms/test_mask_intensity.py +74 -0
  547. tests/transforms/test_mask_intensityd.py +68 -0
  548. tests/transforms/test_mean_ensemble.py +77 -0
  549. tests/transforms/test_mean_ensembled.py +91 -0
  550. tests/transforms/test_median_smooth.py +41 -0
  551. tests/transforms/test_median_smoothd.py +65 -0
  552. tests/transforms/test_morphological_ops.py +101 -0
  553. tests/transforms/test_nifti_endianness.py +107 -0
  554. tests/transforms/test_normalize_intensity.py +143 -0
  555. tests/transforms/test_normalize_intensityd.py +81 -0
  556. tests/transforms/test_nvtx_decorator.py +289 -0
  557. tests/transforms/test_nvtx_transform.py +143 -0
  558. tests/transforms/test_orientation.py +247 -0
  559. tests/transforms/test_orientationd.py +112 -0
  560. tests/transforms/test_rand_adjust_contrast.py +45 -0
  561. tests/transforms/test_rand_adjust_contrastd.py +44 -0
  562. tests/transforms/test_rand_affine.py +201 -0
  563. tests/transforms/test_rand_affine_grid.py +212 -0
  564. tests/transforms/test_rand_affined.py +281 -0
  565. tests/transforms/test_rand_axis_flip.py +50 -0
  566. tests/transforms/test_rand_axis_flipd.py +50 -0
  567. tests/transforms/test_rand_bias_field.py +69 -0
  568. tests/transforms/test_rand_bias_fieldd.py +65 -0
  569. tests/transforms/test_rand_coarse_dropout.py +110 -0
  570. tests/transforms/test_rand_coarse_dropoutd.py +107 -0
  571. tests/transforms/test_rand_coarse_shuffle.py +65 -0
  572. tests/transforms/test_rand_coarse_shuffled.py +59 -0
  573. tests/transforms/test_rand_crop_by_label_classes.py +170 -0
  574. tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
  575. tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
  576. tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
  577. tests/transforms/test_rand_cucim_dict_transform.py +162 -0
  578. tests/transforms/test_rand_cucim_transform.py +162 -0
  579. tests/transforms/test_rand_deform_grid.py +138 -0
  580. tests/transforms/test_rand_elastic_2d.py +127 -0
  581. tests/transforms/test_rand_elastic_3d.py +104 -0
  582. tests/transforms/test_rand_elasticd_2d.py +177 -0
  583. tests/transforms/test_rand_elasticd_3d.py +156 -0
  584. tests/transforms/test_rand_flip.py +60 -0
  585. tests/transforms/test_rand_flipd.py +55 -0
  586. tests/transforms/test_rand_gaussian_noise.py +48 -0
  587. tests/transforms/test_rand_gaussian_noised.py +54 -0
  588. tests/transforms/test_rand_gaussian_sharpen.py +140 -0
  589. tests/transforms/test_rand_gaussian_sharpend.py +143 -0
  590. tests/transforms/test_rand_gaussian_smooth.py +98 -0
  591. tests/transforms/test_rand_gaussian_smoothd.py +98 -0
  592. tests/transforms/test_rand_gibbs_noise.py +103 -0
  593. tests/transforms/test_rand_gibbs_noised.py +117 -0
  594. tests/transforms/test_rand_grid_distortion.py +99 -0
  595. tests/transforms/test_rand_grid_distortiond.py +90 -0
  596. tests/transforms/test_rand_histogram_shift.py +92 -0
  597. tests/transforms/test_rand_k_space_spike_noise.py +92 -0
  598. tests/transforms/test_rand_k_space_spike_noised.py +76 -0
  599. tests/transforms/test_rand_rician_noise.py +52 -0
  600. tests/transforms/test_rand_rician_noised.py +52 -0
  601. tests/transforms/test_rand_rotate.py +166 -0
  602. tests/transforms/test_rand_rotate90.py +100 -0
  603. tests/transforms/test_rand_rotate90d.py +112 -0
  604. tests/transforms/test_rand_rotated.py +187 -0
  605. tests/transforms/test_rand_scale_crop.py +78 -0
  606. tests/transforms/test_rand_scale_cropd.py +98 -0
  607. tests/transforms/test_rand_scale_intensity.py +54 -0
  608. tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
  609. tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
  610. tests/transforms/test_rand_scale_intensityd.py +53 -0
  611. tests/transforms/test_rand_shift_intensity.py +52 -0
  612. tests/transforms/test_rand_shift_intensityd.py +67 -0
  613. tests/transforms/test_rand_simulate_low_resolution.py +83 -0
  614. tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
  615. tests/transforms/test_rand_spatial_crop.py +107 -0
  616. tests/transforms/test_rand_spatial_crop_samples.py +128 -0
  617. tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
  618. tests/transforms/test_rand_spatial_cropd.py +112 -0
  619. tests/transforms/test_rand_std_shift_intensity.py +43 -0
  620. tests/transforms/test_rand_std_shift_intensityd.py +38 -0
  621. tests/transforms/test_rand_zoom.py +105 -0
  622. tests/transforms/test_rand_zoomd.py +108 -0
  623. tests/transforms/test_randidentity.py +49 -0
  624. tests/transforms/test_random_order.py +144 -0
  625. tests/transforms/test_randtorchvisiond.py +65 -0
  626. tests/transforms/test_regularization.py +139 -0
  627. tests/transforms/test_remove_repeated_channel.py +34 -0
  628. tests/transforms/test_remove_repeated_channeld.py +44 -0
  629. tests/transforms/test_repeat_channel.py +34 -0
  630. tests/transforms/test_repeat_channeld.py +41 -0
  631. tests/transforms/test_resample_backends.py +65 -0
  632. tests/transforms/test_resample_to_match.py +110 -0
  633. tests/transforms/test_resample_to_matchd.py +93 -0
  634. tests/transforms/test_resampler.py +165 -0
  635. tests/transforms/test_resize.py +140 -0
  636. tests/transforms/test_resize_with_pad_or_crop.py +91 -0
  637. tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
  638. tests/transforms/test_resized.py +163 -0
  639. tests/transforms/test_rotate.py +160 -0
  640. tests/transforms/test_rotate90.py +212 -0
  641. tests/transforms/test_rotate90d.py +106 -0
  642. tests/transforms/test_rotated.py +179 -0
  643. tests/transforms/test_save_classificationd.py +109 -0
  644. tests/transforms/test_save_image.py +80 -0
  645. tests/transforms/test_save_imaged.py +130 -0
  646. tests/transforms/test_savitzky_golay_smooth.py +73 -0
  647. tests/transforms/test_savitzky_golay_smoothd.py +73 -0
  648. tests/transforms/test_scale_intensity.py +76 -0
  649. tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
  650. tests/transforms/test_scale_intensity_range.py +41 -0
  651. tests/transforms/test_scale_intensity_ranged.py +40 -0
  652. tests/transforms/test_scale_intensityd.py +57 -0
  653. tests/transforms/test_select_itemsd.py +41 -0
  654. tests/transforms/test_shift_intensity.py +31 -0
  655. tests/transforms/test_shift_intensityd.py +44 -0
  656. tests/transforms/test_signal_continuouswavelet.py +44 -0
  657. tests/transforms/test_signal_fillempty.py +52 -0
  658. tests/transforms/test_signal_fillemptyd.py +60 -0
  659. tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
  660. tests/transforms/test_signal_rand_add_sine.py +52 -0
  661. tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
  662. tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
  663. tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
  664. tests/transforms/test_signal_rand_drop.py +50 -0
  665. tests/transforms/test_signal_rand_scale.py +52 -0
  666. tests/transforms/test_signal_rand_shift.py +55 -0
  667. tests/transforms/test_signal_remove_frequency.py +71 -0
  668. tests/transforms/test_smooth_field.py +177 -0
  669. tests/transforms/test_sobel_gradient.py +189 -0
  670. tests/transforms/test_sobel_gradientd.py +212 -0
  671. tests/transforms/test_spacing.py +381 -0
  672. tests/transforms/test_spacingd.py +178 -0
  673. tests/transforms/test_spatial_crop.py +82 -0
  674. tests/transforms/test_spatial_cropd.py +74 -0
  675. tests/transforms/test_spatial_pad.py +57 -0
  676. tests/transforms/test_spatial_padd.py +43 -0
  677. tests/transforms/test_spatial_resample.py +235 -0
  678. tests/transforms/test_squeezedim.py +62 -0
  679. tests/transforms/test_squeezedimd.py +98 -0
  680. tests/transforms/test_std_shift_intensity.py +76 -0
  681. tests/transforms/test_std_shift_intensityd.py +74 -0
  682. tests/transforms/test_threshold_intensity.py +38 -0
  683. tests/transforms/test_threshold_intensityd.py +58 -0
  684. tests/transforms/test_to_contiguous.py +47 -0
  685. tests/transforms/test_to_cupy.py +112 -0
  686. tests/transforms/test_to_cupyd.py +76 -0
  687. tests/transforms/test_to_device.py +42 -0
  688. tests/transforms/test_to_deviced.py +37 -0
  689. tests/transforms/test_to_numpy.py +85 -0
  690. tests/transforms/test_to_numpyd.py +68 -0
  691. tests/transforms/test_to_pil.py +52 -0
  692. tests/transforms/test_to_pild.py +50 -0
  693. tests/transforms/test_to_tensor.py +60 -0
  694. tests/transforms/test_to_tensord.py +71 -0
  695. tests/transforms/test_torchvision.py +66 -0
  696. tests/transforms/test_torchvisiond.py +63 -0
  697. tests/transforms/test_transform.py +62 -0
  698. tests/transforms/test_transpose.py +41 -0
  699. tests/transforms/test_transposed.py +52 -0
  700. tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
  701. tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
  702. tests/transforms/test_vote_ensemble.py +84 -0
  703. tests/transforms/test_vote_ensembled.py +107 -0
  704. tests/transforms/test_with_allow_missing_keys.py +76 -0
  705. tests/transforms/test_zoom.py +120 -0
  706. tests/transforms/test_zoomd.py +94 -0
  707. tests/transforms/transform/__init__.py +10 -0
  708. tests/transforms/transform/test_randomizable.py +52 -0
  709. tests/transforms/transform/test_randomizable_transform_type.py +37 -0
  710. tests/transforms/utility/__init__.py +10 -0
  711. tests/transforms/utility/test_apply_transform_to_points.py +81 -0
  712. tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
  713. tests/transforms/utility/test_identity.py +29 -0
  714. tests/transforms/utility/test_identityd.py +30 -0
  715. tests/transforms/utility/test_lambda.py +71 -0
  716. tests/transforms/utility/test_lambdad.py +83 -0
  717. tests/transforms/utility/test_rand_lambda.py +87 -0
  718. tests/transforms/utility/test_rand_lambdad.py +77 -0
  719. tests/transforms/utility/test_simulatedelay.py +36 -0
  720. tests/transforms/utility/test_simulatedelayd.py +36 -0
  721. tests/transforms/utility/test_splitdim.py +52 -0
  722. tests/transforms/utility/test_splitdimd.py +96 -0
  723. tests/transforms/utils/__init__.py +10 -0
  724. tests/transforms/utils/test_correct_crop_centers.py +36 -0
  725. tests/transforms/utils/test_get_unique_labels.py +45 -0
  726. tests/transforms/utils/test_print_transform_backends.py +29 -0
  727. tests/transforms/utils/test_soft_clip.py +125 -0
  728. tests/utils/__init__.py +10 -0
  729. tests/utils/enums/__init__.py +10 -0
  730. tests/utils/enums/test_hovernet_loss.py +190 -0
  731. tests/utils/enums/test_ordering.py +289 -0
  732. tests/utils/enums/test_wsireader.py +663 -0
  733. tests/utils/misc/__init__.py +10 -0
  734. tests/utils/misc/test_ensure_tuple.py +53 -0
  735. tests/utils/misc/test_monai_env_vars.py +44 -0
  736. tests/utils/misc/test_monai_utils_misc.py +103 -0
  737. tests/utils/misc/test_str2bool.py +34 -0
  738. tests/utils/misc/test_str2list.py +33 -0
  739. tests/utils/test_alias.py +44 -0
  740. tests/utils/test_component_store.py +73 -0
  741. tests/utils/test_deprecated.py +455 -0
  742. tests/utils/test_enum_bound_interp.py +75 -0
  743. tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
  744. tests/utils/test_get_package_version.py +34 -0
  745. tests/utils/test_handler_logfile.py +84 -0
  746. tests/utils/test_handler_metric_logger.py +62 -0
  747. tests/utils/test_list_to_dict.py +43 -0
  748. tests/utils/test_look_up_option.py +87 -0
  749. tests/utils/test_optional_import.py +80 -0
  750. tests/utils/test_pad_mode.py +39 -0
  751. tests/utils/test_profiling.py +208 -0
  752. tests/utils/test_rankfilter_dist.py +77 -0
  753. tests/utils/test_require_pkg.py +83 -0
  754. tests/utils/test_sample_slices.py +43 -0
  755. tests/utils/test_set_determinism.py +74 -0
  756. tests/utils/test_squeeze_unsqueeze.py +71 -0
  757. tests/utils/test_state_cacher.py +67 -0
  758. tests/utils/test_torchscript_utils.py +113 -0
  759. tests/utils/test_version.py +91 -0
  760. tests/utils/test_version_after.py +65 -0
  761. tests/utils/type_conversion/__init__.py +10 -0
  762. tests/utils/type_conversion/test_convert_data_type.py +152 -0
  763. tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
  764. tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
  765. tests/visualize/__init__.py +10 -0
  766. tests/visualize/test_img2tensorboard.py +46 -0
  767. tests/visualize/test_occlusion_sensitivity.py +128 -0
  768. tests/visualize/test_plot_2d_or_3d_image.py +74 -0
  769. tests/visualize/test_vis_cam.py +98 -0
  770. tests/visualize/test_vis_gradcam.py +211 -0
  771. tests/visualize/utils/__init__.py +10 -0
  772. tests/visualize/utils/test_blend_images.py +63 -0
  773. tests/visualize/utils/test_matshow3d.py +133 -0
  774. monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
  775. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/LICENSE +0 -0
  776. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/WHEEL +0 -0
@@ -0,0 +1,304 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import os
15
+ import shutil
16
+ import tempfile
17
+ import unittest
18
+ from glob import glob
19
+
20
+ import nibabel as nib
21
+ import numpy as np
22
+ import torch
23
+
24
+ import monai
25
+ from monai.data import create_test_image_3d, decollate_batch
26
+ from monai.inferers import sliding_window_inference
27
+ from monai.metrics import DiceMetric
28
+ from monai.networks import eval_mode
29
+ from monai.networks.nets import UNet
30
+ from monai.transforms import (
31
+ Activations,
32
+ AsDiscrete,
33
+ Compose,
34
+ EnsureChannelFirstd,
35
+ LoadImaged,
36
+ RandCropByPosNegLabeld,
37
+ RandRotate90d,
38
+ SaveImage,
39
+ ScaleIntensityd,
40
+ Spacingd,
41
+ )
42
+ from monai.utils import optional_import, set_determinism
43
+ from monai.visualize import plot_2d_or_3d_image
44
+ from tests.test_utils import DistTestCase, TimedCall, skip_if_quick
45
+ from tests.testing_data.integration_answers import test_integration_value
46
+
47
# tensorboard is an optional dependency; ``optional_import`` returns a placeholder that is
# expected to fail only when ``SummaryWriter`` is actually used (see MONAI's optional_import docs).
SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter")

# task name under which the reference values are looked up via ``test_integration_value``
TASK = "integration_segmentation_3d"


def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, None)):
    """Train a small 3D UNet on the NIfTI pairs in ``root_dir`` with periodic validation.

    Args:
        root_dir: directory holding ``img*.nii.gz`` / ``seg*.nii.gz`` pairs. TensorBoard logs go to
            ``root_dir/runs`` and the best checkpoint to ``root_dir/best_metric_model.pth``.
        device: torch device string for the model and the batches.
        cachedataset: training dataset flavour: ``2`` -> ``CacheDataset`` (cache_rate 0.8, process
            runtime cache), ``3`` -> ``LMDBDataset`` cached in ``root_dir``, anything else -> ``Dataset``.
        readers: reader names passed to ``LoadImaged`` for the training and validation pipelines,
            respectively (``None`` keeps MONAI's default reader selection).

    Returns:
        tuple ``(epoch_loss_values, best_metric)``: the average training loss of every epoch and
        the best validation mean Dice (``-1`` if validation never ran).
    """
    max_epochs = 6
    val_interval = 2
    # sliding-window settings used for validation (same size as the training crops)
    sw_batch_size, roi_size = 4, (96, 96, 96)

    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"], reader=readers[0]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
        ]
    )
    # fixed seed for the random transforms so that repeated runs are comparable
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"], reader=readers[1]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
        ]
    )

    # create a training data loader
    if cachedataset == 2:
        train_ds = monai.data.CacheDataset(
            data=train_files, transform=train_transforms, cache_rate=0.8, runtime_cache="process"
        )
    elif cachedataset == 3:
        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms, cache_dir=root_dir)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    # create UNet, DiceLoss and Adam optimizer
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = []
    metric_values = []
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    # iterations per epoch (dataset length // loader batch size); used for the TensorBoard global step
    epoch_len = len(train_ds) // train_loader.batch_size
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"Epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                # read the scalar once rather than calling loss.item() three times
                loss_value = loss.item()
                epoch_loss += loss_value
                print(f"{step}/{epoch_len}, train_loss:{loss_value:0.4f}")
                writer.add_scalar("train_loss", loss_value, epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss:{epoch_loss:0.4f}")

            if (epoch + 1) % val_interval == 0:
                with eval_mode(model):
                    # keep the last validation batch around for the TensorBoard plots below
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        # decollate prediction into a list and execute post processing for every item
                        val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
                        # compute metrics
                        dice_metric(y_pred=val_outputs, y=val_labels)

                    metric = dice_metric.aggregate().item()
                    dice_metric.reset()
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), model_filename)
                        print("saved new best metric model")
                    print(
                        f"current epoch {epoch + 1} current mean dice: {metric:0.4f} "
                        f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
        print(f"train completed, best_metric: {best_metric:0.4f} at epoch: {best_metric_epoch}")
    finally:
        # release the TensorBoard writer even when training or validation raises
        writer.close()
    return epoch_loss_values, best_metric
178
+
179
+
180
def run_inference_test(root_dir, device="cuda:0"):
    """Segment every volume in ``root_dir`` with the checkpoint written by ``run_training_test``.

    Thresholded predictions are written under ``root_dir/output`` by ``SaveImage`` (as part of the
    post-processing chain) and scored against the matching ``seg*.nii.gz`` labels.

    Args:
        root_dir: directory holding the ``img*.nii.gz`` / ``seg*.nii.gz`` pairs and
            ``best_metric_model.pth``.
        device: torch device string for the model and the batches.

    Returns:
        the mean Dice aggregated over all volumes.
    """
    # same "img*" pattern as run_training_test, so both functions agree on which files are images
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    saver = SaveImage(
        output_dir=os.path.join(root_dir, "output"),
        dtype=np.float32,
        output_ext=".nii.gz",
        output_postfix="seg",
        mode="bilinear",
    )
    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    # the saver runs last, so the files on disk hold the thresholded predictions
    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5), saver])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(model_filename))
    # define sliding window size and batch size for windows inference
    sw_batch_size, roi_size = 4, (96, 96, 96)
    with eval_mode(model):
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            # decollate prediction into a list
            val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
            # compute metrics
            dice_metric(y_pred=val_outputs, y=val_labels)

    return dice_metric.aggregate().item()
234
+
235
+
236
@skip_if_quick
class IntegrationSegmentation3D(DistTestCase):
    """Train + infer integration test for 3D segmentation, checked against stored reference values."""

    def setUp(self):
        set_determinism(seed=0)

        self.data_dir = tempfile.mkdtemp()
        # 40 synthetic 128^3 image/label pairs written as NIfTI with an identity affine
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(self.data_dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(self.data_dir, f"seg{i:d}.nii.gz"))

        self.device = "cuda:0" if torch.cuda.is_available() else "cpu:0"

    def tearDown(self):
        set_determinism(seed=None)
        shutil.rmtree(self.data_dir)

    def train_and_infer(self, idx=0):
        """Run training and inference for configuration ``idx`` and compare with the stored answers.

        ``idx`` is forwarded to ``run_training_test`` as ``cachedataset`` and also picks the readers:
        ``1`` -> ITK for training and validation, ``2`` -> ITK for training / NiBabel for validation,
        anything else -> MONAI's default reader selection.

        Returns:
            list of the per-epoch training losses, the best validation metric, the inference metric,
            then the mean intensity of every saved output volume (in sorted file order).
        """
        set_determinism(0)
        readers = {1: ("itkreader", "itkreader"), 2: ("itkreader", "nibabelreader")}.get(idx, (None, None))
        losses, best_metric = run_training_test(self.data_dir, device=self.device, cachedataset=idx, readers=readers)
        infer_metric = run_inference_test(self.data_dir, device=self.device)

        # check training properties
        print("losses", losses)
        print("best metric", best_metric)
        print("infer metric", infer_metric)
        self.assertGreater(len(glob(os.path.join(self.data_dir, "runs"))), 0)
        model_file = os.path.join(self.data_dir, "best_metric_model.pth")
        self.assertTrue(os.path.exists(model_file))

        # check inference properties; every output volume is loaded and decoded only once
        output_files = sorted(glob(os.path.join(self.data_dir, "output", "img*", "*.nii.gz")))
        output_means = [np.mean(nib.load(output).get_fdata()) for output in output_files]
        print(output_means)
        results = [*losses, best_metric, infer_metric, *output_means]
        self.assertTrue(test_integration_value(TASK, key="losses", data=results[:6], rtol=1e-3))
        self.assertTrue(test_integration_value(TASK, key="best_metric", data=results[6], rtol=1e-2))
        self.assertTrue(test_integration_value(TASK, key="infer_metric", data=results[7], rtol=1e-2))
        self.assertTrue(test_integration_value(TASK, key="output_sums", data=results[8:], rtol=5e-2))
        return results

    def test_training(self):
        # every dataset/reader configuration must reproduce the numbers of configuration 0
        repeated = [self.train_and_infer(i) for i in range(4)]
        for other in repeated[1:]:
            np.testing.assert_allclose(repeated[0], other)

    @TimedCall(seconds=360, daemon=False)
    def test_timing(self):
        self.train_and_infer(idx=3)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,100 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import os
15
+ import tempfile
16
+ import unittest
17
+
18
+ import nibabel as nib
19
+ import numpy as np
20
+ import torch
21
+ from ignite.engine import Engine, Events
22
+ from torch.utils.data import DataLoader
23
+
24
+ from monai.data import ImageDataset, create_test_image_3d
25
+ from monai.inferers import sliding_window_inference
26
+ from monai.networks import eval_mode, predict_segmentation
27
+ from monai.networks.nets import UNet
28
+ from monai.transforms import EnsureChannelFirst, SaveImage
29
+ from monai.utils import set_determinism
30
+ from tests.test_utils import DistTestCase, TimedCall, make_nifti_image, skip_if_quick
31
+
32
+
33
def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"):
    """Segment a single NIfTI volume with a freshly initialised UNet through an ignite engine.

    Args:
        batch_size: number of sliding windows per network call (``sw_batch_size``); the data
            loader itself always yields one volume per iteration.
        img_name: path of the input image; its name is assumed to end in ``.nii.gz``.
        seg_name: path of the label, loaded alongside the image by ``ImageDataset``.
        output_dir: root directory for the ``SaveImage`` output.
        device: device for the network and the sliding-window inference.

    Returns:
        the path where the saved segmentation of ``img_name`` is expected.
    """
    dataset = ImageDataset(
        [img_name],
        [seg_name],
        transform=EnsureChannelFirst(channel_dim="no_channel"),
        seg_transform=EnsureChannelFirst(channel_dim="no_channel"),
        image_only=True,
    )
    data_loader = DataLoader(dataset, batch_size=1, pin_memory=torch.cuda.is_available())

    network = UNet(
        spatial_dims=3, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2
    ).to(device)
    window_size = (16, 32, 48)
    windows_per_call = batch_size

    writer = SaveImage(output_dir=output_dir, output_ext=".nii.gz", output_postfix="seg")

    def _infer_step(_engine, batch):
        # ImageDataset yields (image, label, ...); only the image is fed to the network
        volume = batch[0]
        with eval_mode(network):
            probabilities = sliding_window_inference(
                volume.to(device), window_size, windows_per_call, network, device=device
            )
        return predict_segmentation(probabilities)

    engine = Engine(_infer_step)

    @engine.on(Events.ITERATION_COMPLETED)
    def _save_predictions(current_engine):
        # write every item of the batch produced by the last iteration
        for prediction in current_engine.state.output:
            writer(prediction)

    engine.run(data_loader)

    stem = os.path.basename(img_name)[: -len(".nii.gz")]
    return os.path.join(output_dir, stem, f"{stem}_seg.nii.gz")
68
+
69
+
70
@skip_if_quick
class TestIntegrationSlidingWindow(DistTestCase):
    """Sliding-window inference on a small synthetic volume, checked against a recorded checksum."""

    def setUp(self):
        set_determinism(seed=0)

        image, label = create_test_image_3d(28, 25, 63, rad_max=10, noise_max=1, num_objs=4, num_seg_classes=1)
        self.img_name = make_nifti_image(image)
        self.seg_name = make_nifti_image(label)
        target = "cuda:0" if torch.cuda.is_available() else "cpu:0"
        self.device = torch.device(target)

    def tearDown(self):
        set_determinism(seed=None)
        # delete the temporary NIfTI files, tolerating ones that are already gone
        for path in (self.img_name, self.seg_name):
            if os.path.exists(path):
                os.remove(path)

    @TimedCall(seconds=20)
    def test_training(self):
        # re-seed immediately before the run; setUp's data generation may already have drawn
        # random numbers, and the recorded sum below depends on this exact seeding
        set_determinism(seed=0)
        with tempfile.TemporaryDirectory() as scratch_dir:
            saved_path = run_test(
                batch_size=2, img_name=self.img_name, seg_name=self.seg_name, output_dir=scratch_dir, device=self.device
            )
            prediction = nib.load(saved_path).get_fdata()
            np.testing.assert_allclose(np.sum(prediction), 33621)
            np.testing.assert_allclose(prediction.shape, (28, 25, 63))


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,133 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import torch.optim as optim
21
+
22
+ from monai.data import create_test_image_2d
23
+ from monai.networks.layers import AffineTransform
24
+ from monai.utils import set_determinism
25
+ from tests.test_utils import DistTestCase, TimedCall
26
+
27
+
28
class STNBenchmark(nn.Module):
    """
    Small spatial transformer network comparing two resampling back-ends.

    It compares MONAI's ``AffineTransform`` with PyTorch's
    ``F.affine_grid``/``F.grid_sample``.

    adapted from https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

    Args:
        is_ref: if True, ``forward`` resamples with ``F.affine_grid``/``F.grid_sample``;
            otherwise it resamples with MONAI's ``AffineTransform``.
        reverse_indexing: forwarded to ``AffineTransform``; only used when ``is_ref`` is False.
    """

    def __init__(self, is_ref=True, reverse_indexing=False):
        super().__init__()
        self.is_ref = is_ref
        # Single-channel input. For 28x28 images the spatial size goes
        # 28 -conv7-> 22 -pool-> 11 -conv5-> 7 -pool-> 3,
        # which matches the 10 * 3 * 3 flattening in _predict_theta.
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
        )
        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(nn.Linear(10 * 3 * 3, 32), nn.ReLU(True), nn.Linear(32, 3 * 2))
        # Initialize the weights/bias with the identity transformation.
        # Done in-place without autograd tracking instead of going through ``.data``.
        with torch.no_grad():
            self.fc_loc[2].weight.zero_()
            self.fc_loc[2].bias.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
        if not self.is_ref:
            self.xform = AffineTransform(align_corners=False, normalized=True, reverse_indexing=reverse_indexing)

    def _predict_theta(self, x):
        """Regress one 2x3 affine matrix per batch item of ``x``; returns shape (batch, 2, 3)."""
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        return theta.view(-1, 2, 3)

    # Spatial transformer network forward function (PyTorch reference resampling)
    def stn_ref(self, x):
        """Warp ``x`` with the predicted affine using ``F.affine_grid`` + ``F.grid_sample``."""
        theta = self._predict_theta(x)
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

    def stn(self, x):
        """Warp ``x`` with the predicted affine using MONAI's ``AffineTransform``."""
        theta = self._predict_theta(x)
        return self.xform(x, theta, spatial_size=x.size()[2:])

    def forward(self, x):
        """Dispatch to the reference or the MONAI implementation according to ``is_ref``."""
        if self.is_ref:
            return self.stn_ref(x)
        return self.stn(x)
75
+
76
+
77
def compare_2d(is_ref=True, device=None, reverse_indexing=False, batch_size=32, num_steps=20):
    """
    Train an ``STNBenchmark`` to warp one batch of synthetic 2D images towards another.

    Args:
        is_ref: use the PyTorch reference resampling (True) or MONAI's ``AffineTransform`` (False).
        device: device for the images and the model.
        reverse_indexing: forwarded to ``STNBenchmark``.
        batch_size: number of 28x28 synthetic images in each of the two batches.
        num_steps: number of SGD steps; must be at least 1.

    Returns:
        A tuple of three items:
        - the model output for the source batch after training, as a numpy array;
        - the loss of the last training step, computed before that step's parameter update;
        - the loss of the first training step.

    Raises:
        ValueError: if ``num_steps`` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, got {num_steps}.")
    # source images first, then targets: keeps the random stream identical across runs
    img_a = [create_test_image_2d(28, 28, 5, rad_max=6, noise_max=1)[0][None] for _ in range(batch_size)]
    img_b = [create_test_image_2d(28, 28, 5, rad_max=6, noise_max=1)[0][None] for _ in range(batch_size)]
    img_a = torch.as_tensor(np.stack(img_a, axis=0), device=device)
    img_b = torch.as_tensor(np.stack(img_b, axis=0), device=device)
    model = STNBenchmark(is_ref=is_ref, reverse_indexing=reverse_indexing).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    model.train()
    init_loss = None
    for _ in range(num_steps):
        optimizer.zero_grad()
        output_a = model(img_a)
        loss = torch.mean((output_a - img_b) ** 2)
        if init_loss is None:
            init_loss = loss.item()
        loss.backward()
        optimizer.step()
    # the final pass only produces outputs for comparison, so no autograd graph is needed
    with torch.no_grad():
        final_output = model(img_a)
    return final_output.cpu().numpy(), loss.item(), init_loss
98
+
99
+
100
class TestSpatialTransformerCore(DistTestCase):
    """Compare STN training through MONAI's ``AffineTransform`` with the PyTorch reference."""

    def setUp(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")

    def tearDown(self):
        set_determinism(seed=None)

    @TimedCall(seconds=100, skip_timing=not torch.cuda.is_available())
    def test_training(self):
        """
        Check that backpropagation through ``AffineTransform`` matches the reference.

        Runs with and without ``reverse_indexing``. The outputs, initial loss and
        final loss must match the ``F.affine_grid``/``F.grid_sample`` reference
        implementation.
        """
        atol = 1e-5
        set_determinism(seed=0)
        out_ref, loss_ref, init_loss_ref = compare_2d(True, self.device)
        print(out_ref.shape, loss_ref, init_loss_ref)

        for reverse_indexing in (False, True):
            # reseed so each run sees the same images and the same initial weights as the reference
            set_determinism(seed=0)
            out, loss, init_loss = compare_2d(False, self.device, reverse_indexing)
            print(out.shape, loss, init_loss)
            np.testing.assert_allclose(out_ref, out, atol=atol)
            np.testing.assert_allclose(init_loss_ref, init_loss, atol=atol)
            np.testing.assert_allclose(loss_ref, loss, atol=atol)
130
+
131
+
132
if __name__ == "__main__":
    # run the test cases in this module when it is executed directly as a script
    unittest.main()
@@ -0,0 +1,67 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import numpy as np
17
+ import torch
18
+ from ignite.engine import create_supervised_trainer
19
+ from torch.utils.data import DataLoader, Dataset
20
+
21
+ from monai.data import create_test_image_2d
22
+ from monai.losses import DiceLoss
23
+ from monai.networks.nets import BasicUNet, UNet
24
+ from tests.test_utils import DistTestCase, TimedCall, skip_if_quick
25
+
26
+
27
def run_test(net_name="basicunet", batch_size=64, train_steps=100, device="cuda:0"):
    """
    Train a small 2D segmentation network for one epoch on synthetic images.

    Args:
        net_name: ``"basicunet"`` or ``"unet"``.
        batch_size: mini-batch size of the data loader.
        train_steps: number of synthetic samples in the dataset. One epoch therefore
            runs ``ceil(train_steps / batch_size)`` iterations.
        device: device used for training.

    Returns:
        The trainer's ``state.output`` after the last iteration, which the caller
        uses as the final loss value.

    Raises:
        ValueError: if ``net_name`` is not one of the supported networks.
    """

    class _TestBatch(Dataset):
        # each access draws a fresh random image/label pair; the index is ignored
        def __getitem__(self, _unused_id):
            im, seg = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1)
            return im[None], seg[None].astype(np.float32)

        def __len__(self):
            return train_steps

    if net_name == "basicunet":
        net = BasicUNet(spatial_dims=2, in_channels=1, out_channels=1, features=(4, 8, 8, 16, 16, 32))
    elif net_name == "unet":
        net = UNet(
            spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2
        )
    else:
        # previously fell through with net=None and failed later with an opaque AttributeError
        raise ValueError(f"Unsupported net_name: {net_name!r}, expected 'basicunet' or 'unet'.")
    net.to(device)

    loss_fn = DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-4)
    src = DataLoader(_TestBatch(), batch_size=batch_size)

    # the fifth positional argument is ignite's ``non_blocking`` flag
    trainer = create_supervised_trainer(net, opt, loss_fn, device, False)

    trainer.run(src, 1)
    return trainer.state.output
54
+
55
+
56
@skip_if_quick
class TestIntegrationUnet2D(DistTestCase):
    """Smoke-train BasicUNet and UNet on synthetic 2D data and bound the final loss."""

    @TimedCall(seconds=20, daemon=False)
    def test_unet_training(self):
        # the device does not depend on the network, so pick it once
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")
        for n in ["basicunet", "unet"]:
            loss = run_test(net_name=n, device=device)
            print(loss)
            # same check as before (loss <= 0.85), written in natural order with the net name on failure
            self.assertLessEqual(loss, 0.85, msg=f"final loss of {n} is too high")
64
+
65
+
66
if __name__ == "__main__":
    # run the test cases in this module when it is executed directly as a script
    unittest.main()
@@ -0,0 +1,61 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import numpy as np
17
+ import torch
18
+
19
+ from monai.data import DataLoader
20
+ from monai.utils import set_determinism
21
+ from tests.test_utils import DistTestCase, SkipIfBeforePyTorchVersion, TimedCall, skip_if_no_cuda, skip_if_quick
22
+
23
+
24
def run_loading_test(num_workers=50, device=None, pw=False):
    """
    Multi-worker DataLoader stress test.

    Iterates two epochs over the integers 0..9999 (shuffled, batches of 300).
    Each batch is moved to ``device`` and CUDA memory usage is checked after
    every transfer.

    The memory checks use ``torch.cuda.memory_allocated``, so they presumably
    only pass when ``device`` is a CUDA device.

    Args:
        num_workers: number of DataLoader worker processes.
        device: target device; defaults to ``"cuda:0"`` when CUDA is available, else ``"cpu"``.
        pw: forwarded to the DataLoader as ``persistent_workers``.

    Returns:
        List with the last element of every batch, across both epochs.

    Raises:
        AssertionError: if any CUDA memory is allocated at the start of an epoch,
            or if the allocation after moving a batch is not within (0, 5000) bytes.
    """
    set_determinism(seed=0)
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    train_ds = list(range(10000))
    train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=num_workers, persistent_workers=pw)
    answer = []
    for _ in range(2):
        # every batch tensor from the previous epoch must have been freed
        np.testing.assert_equal(torch.cuda.memory_allocated(), 0)
        for batch_data in train_loader:
            x = batch_data.to(device)
            mem = torch.cuda.memory_allocated()
            if not 0 < mem < 5000:
                raise AssertionError(f"unexpected CUDA memory allocated after moving a batch: {mem} bytes")
            answer.append(x[-1].item())
            del x  # release the batch so the next epoch starts from zero allocated memory
    return answer
41
+
42
+
43
@skip_if_quick
@skip_if_no_cuda
@SkipIfBeforePyTorchVersion((1, 9))
class IntegrationLoading(DistTestCase):
    """Stress-test multi-worker loading with and without persistent workers."""

    def tearDown(self):
        set_determinism(seed=None)

    @TimedCall(seconds=5000, skip_timing=not torch.cuda.is_available(), daemon=False)
    def test_timing(self):
        first_items = []
        for persistent in (False, True):
            first_items.append(run_loading_test(pw=persistent)[0])
            # test for deterministic first epoch in two settings
            np.testing.assert_allclose(first_items[-1], first_items[0])
58
+
59
+
60
if __name__ == "__main__":
    # run the test cases in this module when it is executed directly as a script
    unittest.main()