monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2508__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (787)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/transforms.py +1 -4
  4. monai/data/utils.py +6 -13
  5. monai/handlers/__init__.py +1 -0
  6. monai/handlers/average_precision.py +53 -0
  7. monai/inferers/inferer.py +10 -7
  8. monai/inferers/utils.py +1 -2
  9. monai/losses/dice.py +2 -14
  10. monai/losses/ds_loss.py +1 -3
  11. monai/metrics/__init__.py +1 -0
  12. monai/metrics/average_precision.py +187 -0
  13. monai/networks/layers/simplelayers.py +2 -14
  14. monai/networks/utils.py +4 -16
  15. monai/transforms/compose.py +28 -11
  16. monai/transforms/croppad/array.py +1 -6
  17. monai/transforms/io/array.py +0 -1
  18. monai/transforms/transform.py +15 -6
  19. monai/transforms/utility/array.py +2 -12
  20. monai/transforms/utils.py +1 -2
  21. monai/transforms/utils_pytorch_numpy_unification.py +2 -4
  22. monai/utils/enums.py +3 -2
  23. monai/utils/module.py +6 -6
  24. monai/utils/tf32.py +0 -10
  25. monai/visualize/class_activation_maps.py +5 -8
  26. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/METADATA +21 -17
  27. monai_weekly-1.5.dev2508.dist-info/RECORD +1185 -0
  28. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/top_level.txt +1 -0
  29. tests/apps/__init__.py +10 -0
  30. tests/apps/deepedit/__init__.py +10 -0
  31. tests/apps/deepedit/test_deepedit_transforms.py +314 -0
  32. tests/apps/deepgrow/__init__.py +10 -0
  33. tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
  34. tests/apps/deepgrow/transforms/__init__.py +10 -0
  35. tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
  36. tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
  37. tests/apps/detection/__init__.py +10 -0
  38. tests/apps/detection/metrics/__init__.py +10 -0
  39. tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
  40. tests/apps/detection/networks/__init__.py +10 -0
  41. tests/apps/detection/networks/test_retinanet.py +210 -0
  42. tests/apps/detection/networks/test_retinanet_detector.py +203 -0
  43. tests/apps/detection/test_box_transform.py +370 -0
  44. tests/apps/detection/utils/__init__.py +10 -0
  45. tests/apps/detection/utils/test_anchor_box.py +88 -0
  46. tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
  47. tests/apps/detection/utils/test_box_coder.py +43 -0
  48. tests/apps/detection/utils/test_detector_boxselector.py +67 -0
  49. tests/apps/detection/utils/test_detector_utils.py +96 -0
  50. tests/apps/detection/utils/test_hardnegsampler.py +54 -0
  51. tests/apps/nuclick/__init__.py +10 -0
  52. tests/apps/nuclick/test_nuclick_transforms.py +259 -0
  53. tests/apps/pathology/__init__.py +10 -0
  54. tests/apps/pathology/handlers/__init__.py +10 -0
  55. tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
  56. tests/apps/pathology/test_lesion_froc.py +333 -0
  57. tests/apps/pathology/test_pathology_prob_nms.py +55 -0
  58. tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
  59. tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
  60. tests/apps/pathology/transforms/__init__.py +10 -0
  61. tests/apps/pathology/transforms/post/__init__.py +10 -0
  62. tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
  63. tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
  64. tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
  65. tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
  66. tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
  67. tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
  68. tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
  69. tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
  70. tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
  71. tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
  72. tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
  73. tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
  74. tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
  75. tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
  76. tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
  77. tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
  78. tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
  79. tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
  80. tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
  81. tests/apps/pathology/transforms/post/test_watershed.py +60 -0
  82. tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
  83. tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
  84. tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
  85. tests/apps/reconstruction/__init__.py +10 -0
  86. tests/apps/reconstruction/nets/__init__.py +10 -0
  87. tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
  88. tests/apps/reconstruction/test_complex_utils.py +77 -0
  89. tests/apps/reconstruction/test_fastmri_reader.py +82 -0
  90. tests/apps/reconstruction/test_mri_utils.py +37 -0
  91. tests/apps/reconstruction/transforms/__init__.py +10 -0
  92. tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
  93. tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
  94. tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
  95. tests/apps/test_auto3dseg_bundlegen.py +156 -0
  96. tests/apps/test_check_hash.py +53 -0
  97. tests/apps/test_cross_validation.py +74 -0
  98. tests/apps/test_decathlondataset.py +93 -0
  99. tests/apps/test_download_and_extract.py +70 -0
  100. tests/apps/test_download_url_yandex.py +45 -0
  101. tests/apps/test_mednistdataset.py +72 -0
  102. tests/apps/test_mmar_download.py +154 -0
  103. tests/apps/test_tciadataset.py +123 -0
  104. tests/apps/vista3d/__init__.py +10 -0
  105. tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
  106. tests/apps/vista3d/test_vista3d_sampler.py +100 -0
  107. tests/apps/vista3d/test_vista3d_transforms.py +94 -0
  108. tests/bundle/__init__.py +10 -0
  109. tests/bundle/test_bundle_ckpt_export.py +107 -0
  110. tests/bundle/test_bundle_download.py +435 -0
  111. tests/bundle/test_bundle_get_data.py +94 -0
  112. tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
  113. tests/bundle/test_bundle_trt_export.py +147 -0
  114. tests/bundle/test_bundle_utils.py +149 -0
  115. tests/bundle/test_bundle_verify_metadata.py +66 -0
  116. tests/bundle/test_bundle_verify_net.py +76 -0
  117. tests/bundle/test_bundle_workflow.py +272 -0
  118. tests/bundle/test_component_locator.py +38 -0
  119. tests/bundle/test_config_item.py +138 -0
  120. tests/bundle/test_config_parser.py +392 -0
  121. tests/bundle/test_reference_resolver.py +114 -0
  122. tests/config/__init__.py +10 -0
  123. tests/config/test_cv2_dist.py +53 -0
  124. tests/engines/__init__.py +10 -0
  125. tests/engines/test_ensemble_evaluator.py +94 -0
  126. tests/engines/test_prepare_batch_default.py +76 -0
  127. tests/engines/test_prepare_batch_default_dist.py +76 -0
  128. tests/engines/test_prepare_batch_diffusion.py +104 -0
  129. tests/engines/test_prepare_batch_extra_input.py +80 -0
  130. tests/fl/__init__.py +10 -0
  131. tests/fl/monai_algo/__init__.py +10 -0
  132. tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
  133. tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
  134. tests/fl/test_fl_monai_algo_stats.py +81 -0
  135. tests/fl/utils/__init__.py +10 -0
  136. tests/fl/utils/test_fl_exchange_object.py +63 -0
  137. tests/handlers/__init__.py +10 -0
  138. tests/handlers/test_handler_average_precision.py +79 -0
  139. tests/handlers/test_handler_checkpoint_loader.py +182 -0
  140. tests/handlers/test_handler_checkpoint_saver.py +233 -0
  141. tests/handlers/test_handler_classification_saver.py +64 -0
  142. tests/handlers/test_handler_classification_saver_dist.py +77 -0
  143. tests/handlers/test_handler_clearml_image.py +65 -0
  144. tests/handlers/test_handler_clearml_stats.py +65 -0
  145. tests/handlers/test_handler_confusion_matrix.py +104 -0
  146. tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
  147. tests/handlers/test_handler_decollate_batch.py +66 -0
  148. tests/handlers/test_handler_early_stop.py +68 -0
  149. tests/handlers/test_handler_garbage_collector.py +73 -0
  150. tests/handlers/test_handler_hausdorff_distance.py +111 -0
  151. tests/handlers/test_handler_ignite_metric.py +191 -0
  152. tests/handlers/test_handler_lr_scheduler.py +94 -0
  153. tests/handlers/test_handler_mean_dice.py +98 -0
  154. tests/handlers/test_handler_mean_iou.py +76 -0
  155. tests/handlers/test_handler_metrics_reloaded.py +149 -0
  156. tests/handlers/test_handler_metrics_saver.py +89 -0
  157. tests/handlers/test_handler_metrics_saver_dist.py +120 -0
  158. tests/handlers/test_handler_mlflow.py +296 -0
  159. tests/handlers/test_handler_nvtx.py +93 -0
  160. tests/handlers/test_handler_panoptic_quality.py +89 -0
  161. tests/handlers/test_handler_parameter_scheduler.py +136 -0
  162. tests/handlers/test_handler_post_processing.py +74 -0
  163. tests/handlers/test_handler_prob_map_producer.py +111 -0
  164. tests/handlers/test_handler_regression_metrics.py +160 -0
  165. tests/handlers/test_handler_regression_metrics_dist.py +245 -0
  166. tests/handlers/test_handler_rocauc.py +48 -0
  167. tests/handlers/test_handler_rocauc_dist.py +54 -0
  168. tests/handlers/test_handler_stats.py +281 -0
  169. tests/handlers/test_handler_surface_distance.py +113 -0
  170. tests/handlers/test_handler_tb_image.py +61 -0
  171. tests/handlers/test_handler_tb_stats.py +166 -0
  172. tests/handlers/test_handler_validation.py +59 -0
  173. tests/handlers/test_trt_compile.py +145 -0
  174. tests/handlers/test_write_metrics_reports.py +68 -0
  175. tests/inferers/__init__.py +10 -0
  176. tests/inferers/test_avg_merger.py +179 -0
  177. tests/inferers/test_controlnet_inferers.py +1388 -0
  178. tests/inferers/test_diffusion_inferer.py +236 -0
  179. tests/inferers/test_latent_diffusion_inferer.py +884 -0
  180. tests/inferers/test_patch_inferer.py +309 -0
  181. tests/inferers/test_saliency_inferer.py +55 -0
  182. tests/inferers/test_slice_inferer.py +57 -0
  183. tests/inferers/test_sliding_window_inference.py +377 -0
  184. tests/inferers/test_sliding_window_splitter.py +284 -0
  185. tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
  186. tests/inferers/test_zarr_avg_merger.py +326 -0
  187. tests/integration/__init__.py +10 -0
  188. tests/integration/test_auto3dseg_ensemble.py +211 -0
  189. tests/integration/test_auto3dseg_hpo.py +189 -0
  190. tests/integration/test_deepedit_interaction.py +122 -0
  191. tests/integration/test_downsample_block.py +50 -0
  192. tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
  193. tests/integration/test_integration_autorunner.py +201 -0
  194. tests/integration/test_integration_bundle_run.py +240 -0
  195. tests/integration/test_integration_classification_2d.py +282 -0
  196. tests/integration/test_integration_determinism.py +95 -0
  197. tests/integration/test_integration_fast_train.py +231 -0
  198. tests/integration/test_integration_gpu_customization.py +159 -0
  199. tests/integration/test_integration_lazy_samples.py +219 -0
  200. tests/integration/test_integration_nnunetv2_runner.py +96 -0
  201. tests/integration/test_integration_segmentation_3d.py +304 -0
  202. tests/integration/test_integration_sliding_window.py +100 -0
  203. tests/integration/test_integration_stn.py +133 -0
  204. tests/integration/test_integration_unet_2d.py +67 -0
  205. tests/integration/test_integration_workers.py +61 -0
  206. tests/integration/test_integration_workflows.py +365 -0
  207. tests/integration/test_integration_workflows_adversarial.py +173 -0
  208. tests/integration/test_integration_workflows_gan.py +158 -0
  209. tests/integration/test_loader_semaphore.py +48 -0
  210. tests/integration/test_mapping_filed.py +122 -0
  211. tests/integration/test_meta_affine.py +183 -0
  212. tests/integration/test_metatensor_integration.py +114 -0
  213. tests/integration/test_module_list.py +76 -0
  214. tests/integration/test_one_of.py +283 -0
  215. tests/integration/test_pad_collation.py +124 -0
  216. tests/integration/test_reg_loss_integration.py +107 -0
  217. tests/integration/test_retinanet_predict_utils.py +154 -0
  218. tests/integration/test_seg_loss_integration.py +159 -0
  219. tests/integration/test_spatial_combine_transforms.py +185 -0
  220. tests/integration/test_testtimeaugmentation.py +186 -0
  221. tests/integration/test_vis_gradbased.py +69 -0
  222. tests/integration/test_vista3d_utils.py +159 -0
  223. tests/losses/__init__.py +10 -0
  224. tests/losses/deform/__init__.py +10 -0
  225. tests/losses/deform/test_bending_energy.py +88 -0
  226. tests/losses/deform/test_diffusion_loss.py +117 -0
  227. tests/losses/image_dissimilarity/__init__.py +10 -0
  228. tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
  229. tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
  230. tests/losses/test_adversarial_loss.py +94 -0
  231. tests/losses/test_barlow_twins_loss.py +109 -0
  232. tests/losses/test_cldice_loss.py +51 -0
  233. tests/losses/test_contrastive_loss.py +86 -0
  234. tests/losses/test_dice_ce_loss.py +123 -0
  235. tests/losses/test_dice_focal_loss.py +124 -0
  236. tests/losses/test_dice_loss.py +227 -0
  237. tests/losses/test_ds_loss.py +189 -0
  238. tests/losses/test_focal_loss.py +379 -0
  239. tests/losses/test_generalized_dice_focal_loss.py +85 -0
  240. tests/losses/test_generalized_dice_loss.py +221 -0
  241. tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
  242. tests/losses/test_giou_loss.py +62 -0
  243. tests/losses/test_hausdorff_loss.py +264 -0
  244. tests/losses/test_masked_dice_loss.py +152 -0
  245. tests/losses/test_masked_loss.py +87 -0
  246. tests/losses/test_multi_scale.py +86 -0
  247. tests/losses/test_nacl_loss.py +167 -0
  248. tests/losses/test_perceptual_loss.py +122 -0
  249. tests/losses/test_spectral_loss.py +86 -0
  250. tests/losses/test_ssim_loss.py +59 -0
  251. tests/losses/test_sure_loss.py +72 -0
  252. tests/losses/test_tversky_loss.py +198 -0
  253. tests/losses/test_unified_focal_loss.py +66 -0
  254. tests/metrics/__init__.py +10 -0
  255. tests/metrics/test_compute_average_precision.py +162 -0
  256. tests/metrics/test_compute_confusion_matrix.py +294 -0
  257. tests/metrics/test_compute_f_beta.py +80 -0
  258. tests/metrics/test_compute_fid_metric.py +40 -0
  259. tests/metrics/test_compute_froc.py +143 -0
  260. tests/metrics/test_compute_generalized_dice.py +240 -0
  261. tests/metrics/test_compute_meandice.py +306 -0
  262. tests/metrics/test_compute_meaniou.py +223 -0
  263. tests/metrics/test_compute_mmd_metric.py +56 -0
  264. tests/metrics/test_compute_multiscalessim_metric.py +83 -0
  265. tests/metrics/test_compute_panoptic_quality.py +113 -0
  266. tests/metrics/test_compute_regression_metrics.py +196 -0
  267. tests/metrics/test_compute_roc_auc.py +155 -0
  268. tests/metrics/test_compute_variance.py +147 -0
  269. tests/metrics/test_cumulative.py +63 -0
  270. tests/metrics/test_cumulative_average.py +74 -0
  271. tests/metrics/test_cumulative_average_dist.py +48 -0
  272. tests/metrics/test_hausdorff_distance.py +209 -0
  273. tests/metrics/test_label_quality_score.py +134 -0
  274. tests/metrics/test_loss_metric.py +57 -0
  275. tests/metrics/test_metrics_reloaded.py +96 -0
  276. tests/metrics/test_ssim_metric.py +78 -0
  277. tests/metrics/test_surface_dice.py +416 -0
  278. tests/metrics/test_surface_distance.py +186 -0
  279. tests/networks/__init__.py +10 -0
  280. tests/networks/blocks/__init__.py +10 -0
  281. tests/networks/blocks/dints_block/__init__.py +10 -0
  282. tests/networks/blocks/dints_block/test_acn_block.py +41 -0
  283. tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
  284. tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
  285. tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
  286. tests/networks/blocks/test_adn.py +86 -0
  287. tests/networks/blocks/test_convolutions.py +156 -0
  288. tests/networks/blocks/test_crf_cpu.py +513 -0
  289. tests/networks/blocks/test_crf_cuda.py +528 -0
  290. tests/networks/blocks/test_crossattention.py +185 -0
  291. tests/networks/blocks/test_denseblock.py +105 -0
  292. tests/networks/blocks/test_dynunet_block.py +116 -0
  293. tests/networks/blocks/test_fpn_block.py +88 -0
  294. tests/networks/blocks/test_localnet_block.py +121 -0
  295. tests/networks/blocks/test_mlp.py +78 -0
  296. tests/networks/blocks/test_patchembedding.py +212 -0
  297. tests/networks/blocks/test_regunet_block.py +103 -0
  298. tests/networks/blocks/test_se_block.py +85 -0
  299. tests/networks/blocks/test_se_blocks.py +78 -0
  300. tests/networks/blocks/test_segresnet_block.py +57 -0
  301. tests/networks/blocks/test_selfattention.py +232 -0
  302. tests/networks/blocks/test_simple_aspp.py +87 -0
  303. tests/networks/blocks/test_spatialattention.py +55 -0
  304. tests/networks/blocks/test_subpixel_upsample.py +87 -0
  305. tests/networks/blocks/test_text_encoding.py +49 -0
  306. tests/networks/blocks/test_transformerblock.py +90 -0
  307. tests/networks/blocks/test_unetr_block.py +158 -0
  308. tests/networks/blocks/test_upsample_block.py +134 -0
  309. tests/networks/blocks/warp/__init__.py +10 -0
  310. tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
  311. tests/networks/blocks/warp/test_warp.py +250 -0
  312. tests/networks/layers/__init__.py +10 -0
  313. tests/networks/layers/filtering/__init__.py +10 -0
  314. tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
  315. tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
  316. tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
  317. tests/networks/layers/filtering/test_phl_cpu.py +259 -0
  318. tests/networks/layers/filtering/test_phl_cuda.py +167 -0
  319. tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
  320. tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
  321. tests/networks/layers/test_affine_transform.py +385 -0
  322. tests/networks/layers/test_apply_filter.py +89 -0
  323. tests/networks/layers/test_channel_pad.py +51 -0
  324. tests/networks/layers/test_conjugate_gradient.py +56 -0
  325. tests/networks/layers/test_drop_path.py +46 -0
  326. tests/networks/layers/test_gaussian.py +317 -0
  327. tests/networks/layers/test_gaussian_filter.py +206 -0
  328. tests/networks/layers/test_get_layers.py +65 -0
  329. tests/networks/layers/test_gmm.py +314 -0
  330. tests/networks/layers/test_grid_pull.py +93 -0
  331. tests/networks/layers/test_hilbert_transform.py +131 -0
  332. tests/networks/layers/test_lltm.py +62 -0
  333. tests/networks/layers/test_median_filter.py +52 -0
  334. tests/networks/layers/test_polyval.py +55 -0
  335. tests/networks/layers/test_preset_filters.py +136 -0
  336. tests/networks/layers/test_savitzky_golay_filter.py +141 -0
  337. tests/networks/layers/test_separable_filter.py +87 -0
  338. tests/networks/layers/test_skip_connection.py +48 -0
  339. tests/networks/layers/test_vector_quantizer.py +89 -0
  340. tests/networks/layers/test_weight_init.py +50 -0
  341. tests/networks/nets/__init__.py +10 -0
  342. tests/networks/nets/dints/__init__.py +10 -0
  343. tests/networks/nets/dints/test_dints_cell.py +110 -0
  344. tests/networks/nets/dints/test_dints_mixop.py +84 -0
  345. tests/networks/nets/regunet/__init__.py +10 -0
  346. tests/networks/nets/regunet/test_localnet.py +86 -0
  347. tests/networks/nets/regunet/test_regunet.py +88 -0
  348. tests/networks/nets/test_ahnet.py +224 -0
  349. tests/networks/nets/test_attentionunet.py +88 -0
  350. tests/networks/nets/test_autoencoder.py +95 -0
  351. tests/networks/nets/test_autoencoderkl.py +337 -0
  352. tests/networks/nets/test_basic_unet.py +102 -0
  353. tests/networks/nets/test_basic_unetplusplus.py +109 -0
  354. tests/networks/nets/test_bundle_init_bundle.py +55 -0
  355. tests/networks/nets/test_cell_sam_wrapper.py +58 -0
  356. tests/networks/nets/test_controlnet.py +215 -0
  357. tests/networks/nets/test_daf3d.py +62 -0
  358. tests/networks/nets/test_densenet.py +121 -0
  359. tests/networks/nets/test_diffusion_model_unet.py +585 -0
  360. tests/networks/nets/test_dints_network.py +168 -0
  361. tests/networks/nets/test_discriminator.py +59 -0
  362. tests/networks/nets/test_dynunet.py +181 -0
  363. tests/networks/nets/test_efficientnet.py +400 -0
  364. tests/networks/nets/test_flexible_unet.py +341 -0
  365. tests/networks/nets/test_fullyconnectednet.py +69 -0
  366. tests/networks/nets/test_generator.py +59 -0
  367. tests/networks/nets/test_globalnet.py +103 -0
  368. tests/networks/nets/test_highresnet.py +67 -0
  369. tests/networks/nets/test_hovernet.py +218 -0
  370. tests/networks/nets/test_mednext.py +122 -0
  371. tests/networks/nets/test_milmodel.py +92 -0
  372. tests/networks/nets/test_net_adapter.py +68 -0
  373. tests/networks/nets/test_network_consistency.py +86 -0
  374. tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
  375. tests/networks/nets/test_quicknat.py +57 -0
  376. tests/networks/nets/test_resnet.py +340 -0
  377. tests/networks/nets/test_segresnet.py +120 -0
  378. tests/networks/nets/test_segresnet_ds.py +156 -0
  379. tests/networks/nets/test_senet.py +151 -0
  380. tests/networks/nets/test_spade_autoencoderkl.py +295 -0
  381. tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
  382. tests/networks/nets/test_spade_vaegan.py +140 -0
  383. tests/networks/nets/test_swin_unetr.py +139 -0
  384. tests/networks/nets/test_torchvision_fc_model.py +201 -0
  385. tests/networks/nets/test_transchex.py +84 -0
  386. tests/networks/nets/test_transformer.py +108 -0
  387. tests/networks/nets/test_unet.py +208 -0
  388. tests/networks/nets/test_unetr.py +137 -0
  389. tests/networks/nets/test_varautoencoder.py +127 -0
  390. tests/networks/nets/test_vista3d.py +84 -0
  391. tests/networks/nets/test_vit.py +139 -0
  392. tests/networks/nets/test_vitautoenc.py +112 -0
  393. tests/networks/nets/test_vnet.py +81 -0
  394. tests/networks/nets/test_voxelmorph.py +280 -0
  395. tests/networks/nets/test_vqvae.py +274 -0
  396. tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
  397. tests/networks/schedulers/__init__.py +10 -0
  398. tests/networks/schedulers/test_scheduler_ddim.py +83 -0
  399. tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
  400. tests/networks/schedulers/test_scheduler_pndm.py +108 -0
  401. tests/networks/test_bundle_onnx_export.py +71 -0
  402. tests/networks/test_convert_to_onnx.py +106 -0
  403. tests/networks/test_convert_to_torchscript.py +46 -0
  404. tests/networks/test_convert_to_trt.py +79 -0
  405. tests/networks/test_save_state.py +73 -0
  406. tests/networks/test_to_onehot.py +63 -0
  407. tests/networks/test_varnet.py +63 -0
  408. tests/networks/utils/__init__.py +10 -0
  409. tests/networks/utils/test_copy_model_state.py +187 -0
  410. tests/networks/utils/test_eval_mode.py +34 -0
  411. tests/networks/utils/test_freeze_layers.py +61 -0
  412. tests/networks/utils/test_replace_module.py +98 -0
  413. tests/networks/utils/test_train_mode.py +34 -0
  414. tests/optimizers/__init__.py +10 -0
  415. tests/optimizers/test_generate_param_groups.py +105 -0
  416. tests/optimizers/test_lr_finder.py +108 -0
  417. tests/optimizers/test_lr_scheduler.py +71 -0
  418. tests/optimizers/test_optim_novograd.py +100 -0
  419. tests/profile_subclass/__init__.py +10 -0
  420. tests/profile_subclass/cprofile_profiling.py +29 -0
  421. tests/profile_subclass/min_classes.py +30 -0
  422. tests/profile_subclass/profiling.py +73 -0
  423. tests/profile_subclass/pyspy_profiling.py +41 -0
  424. tests/transforms/__init__.py +10 -0
  425. tests/transforms/compose/__init__.py +10 -0
  426. tests/transforms/compose/test_compose.py +758 -0
  427. tests/transforms/compose/test_some_of.py +258 -0
  428. tests/transforms/croppad/__init__.py +10 -0
  429. tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
  430. tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
  431. tests/transforms/functional/__init__.py +10 -0
  432. tests/transforms/functional/test_apply.py +75 -0
  433. tests/transforms/functional/test_resample.py +50 -0
  434. tests/transforms/intensity/__init__.py +10 -0
  435. tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
  436. tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
  437. tests/transforms/intensity/test_foreground_mask.py +98 -0
  438. tests/transforms/intensity/test_foreground_maskd.py +106 -0
  439. tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
  440. tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
  441. tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
  442. tests/transforms/inverse/__init__.py +10 -0
  443. tests/transforms/inverse/test_inverse_array.py +76 -0
  444. tests/transforms/inverse/test_traceable_transform.py +59 -0
  445. tests/transforms/post/__init__.py +10 -0
  446. tests/transforms/post/test_label_filterd.py +78 -0
  447. tests/transforms/post/test_probnms.py +72 -0
  448. tests/transforms/post/test_probnmsd.py +79 -0
  449. tests/transforms/post/test_remove_small_objects.py +102 -0
  450. tests/transforms/spatial/__init__.py +10 -0
  451. tests/transforms/spatial/test_convert_box_points.py +119 -0
  452. tests/transforms/spatial/test_grid_patch.py +134 -0
  453. tests/transforms/spatial/test_grid_patchd.py +102 -0
  454. tests/transforms/spatial/test_rand_grid_patch.py +150 -0
  455. tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
  456. tests/transforms/spatial/test_spatial_resampled.py +124 -0
  457. tests/transforms/test_activations.py +120 -0
  458. tests/transforms/test_activationsd.py +64 -0
  459. tests/transforms/test_adaptors.py +160 -0
  460. tests/transforms/test_add_coordinate_channels.py +53 -0
  461. tests/transforms/test_add_coordinate_channelsd.py +67 -0
  462. tests/transforms/test_add_extreme_points_channel.py +80 -0
  463. tests/transforms/test_add_extreme_points_channeld.py +77 -0
  464. tests/transforms/test_adjust_contrast.py +70 -0
  465. tests/transforms/test_adjust_contrastd.py +64 -0
  466. tests/transforms/test_affine.py +245 -0
  467. tests/transforms/test_affine_grid.py +152 -0
  468. tests/transforms/test_affined.py +190 -0
  469. tests/transforms/test_as_channel_last.py +38 -0
  470. tests/transforms/test_as_channel_lastd.py +44 -0
  471. tests/transforms/test_as_discrete.py +81 -0
  472. tests/transforms/test_as_discreted.py +82 -0
  473. tests/transforms/test_border_pad.py +49 -0
  474. tests/transforms/test_border_padd.py +45 -0
  475. tests/transforms/test_bounding_rect.py +54 -0
  476. tests/transforms/test_bounding_rectd.py +53 -0
  477. tests/transforms/test_cast_to_type.py +63 -0
  478. tests/transforms/test_cast_to_typed.py +74 -0
  479. tests/transforms/test_center_scale_crop.py +55 -0
  480. tests/transforms/test_center_scale_cropd.py +56 -0
  481. tests/transforms/test_center_spatial_crop.py +56 -0
  482. tests/transforms/test_center_spatial_cropd.py +63 -0
  483. tests/transforms/test_classes_to_indices.py +93 -0
  484. tests/transforms/test_classes_to_indicesd.py +110 -0
  485. tests/transforms/test_clip_intensity_percentiles.py +196 -0
  486. tests/transforms/test_clip_intensity_percentilesd.py +193 -0
  487. tests/transforms/test_compose_get_number_conversions.py +127 -0
  488. tests/transforms/test_concat_itemsd.py +82 -0
  489. tests/transforms/test_convert_to_multi_channel.py +59 -0
  490. tests/transforms/test_convert_to_multi_channeld.py +37 -0
  491. tests/transforms/test_copy_itemsd.py +86 -0
  492. tests/transforms/test_create_grid_and_affine.py +274 -0
  493. tests/transforms/test_crop_foreground.py +164 -0
  494. tests/transforms/test_crop_foregroundd.py +205 -0
  495. tests/transforms/test_cucim_dict_transform.py +142 -0
  496. tests/transforms/test_cucim_transform.py +141 -0
  497. tests/transforms/test_data_stats.py +221 -0
  498. tests/transforms/test_data_statsd.py +249 -0
  499. tests/transforms/test_delete_itemsd.py +58 -0
  500. tests/transforms/test_detect_envelope.py +159 -0
  501. tests/transforms/test_distance_transform_edt.py +202 -0
  502. tests/transforms/test_divisible_pad.py +49 -0
  503. tests/transforms/test_divisible_padd.py +42 -0
  504. tests/transforms/test_ensure_channel_first.py +113 -0
  505. tests/transforms/test_ensure_channel_firstd.py +85 -0
  506. tests/transforms/test_ensure_type.py +94 -0
  507. tests/transforms/test_ensure_typed.py +110 -0
  508. tests/transforms/test_fg_bg_to_indices.py +83 -0
  509. tests/transforms/test_fg_bg_to_indicesd.py +78 -0
  510. tests/transforms/test_fill_holes.py +207 -0
  511. tests/transforms/test_fill_holesd.py +209 -0
  512. tests/transforms/test_flatten_sub_keysd.py +64 -0
  513. tests/transforms/test_flip.py +83 -0
  514. tests/transforms/test_flipd.py +90 -0
  515. tests/transforms/test_fourier.py +70 -0
  516. tests/transforms/test_gaussian_sharpen.py +92 -0
  517. tests/transforms/test_gaussian_sharpend.py +92 -0
  518. tests/transforms/test_gaussian_smooth.py +96 -0
  519. tests/transforms/test_gaussian_smoothd.py +96 -0
  520. tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
  521. tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
  522. tests/transforms/test_generate_spatial_bounding_box.py +114 -0
  523. tests/transforms/test_get_extreme_points.py +57 -0
  524. tests/transforms/test_gibbs_noise.py +73 -0
  525. tests/transforms/test_gibbs_noised.py +88 -0
  526. tests/transforms/test_grid_distortion.py +113 -0
  527. tests/transforms/test_grid_distortiond.py +87 -0
  528. tests/transforms/test_grid_split.py +88 -0
  529. tests/transforms/test_grid_splitd.py +96 -0
  530. tests/transforms/test_histogram_normalize.py +59 -0
  531. tests/transforms/test_histogram_normalized.py +59 -0
  532. tests/transforms/test_image_filter.py +259 -0
  533. tests/transforms/test_intensity_stats.py +73 -0
  534. tests/transforms/test_intensity_statsd.py +90 -0
  535. tests/transforms/test_inverse.py +521 -0
  536. tests/transforms/test_inverse_collation.py +147 -0
  537. tests/transforms/test_invert.py +105 -0
  538. tests/transforms/test_invertd.py +142 -0
  539. tests/transforms/test_k_space_spike_noise.py +81 -0
  540. tests/transforms/test_k_space_spike_noised.py +98 -0
  541. tests/transforms/test_keep_largest_connected_component.py +419 -0
  542. tests/transforms/test_keep_largest_connected_componentd.py +348 -0
  543. tests/transforms/test_label_filter.py +78 -0
  544. tests/transforms/test_label_to_contour.py +179 -0
  545. tests/transforms/test_label_to_contourd.py +182 -0
  546. tests/transforms/test_label_to_mask.py +69 -0
  547. tests/transforms/test_label_to_maskd.py +70 -0
  548. tests/transforms/test_load_image.py +502 -0
  549. tests/transforms/test_load_imaged.py +198 -0
  550. tests/transforms/test_load_spacing_orientation.py +149 -0
  551. tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
  552. tests/transforms/test_map_binary_to_indices.py +75 -0
  553. tests/transforms/test_map_classes_to_indices.py +135 -0
  554. tests/transforms/test_map_label_value.py +89 -0
  555. tests/transforms/test_map_label_valued.py +85 -0
  556. tests/transforms/test_map_transform.py +45 -0
  557. tests/transforms/test_mask_intensity.py +74 -0
  558. tests/transforms/test_mask_intensityd.py +68 -0
  559. tests/transforms/test_mean_ensemble.py +77 -0
  560. tests/transforms/test_mean_ensembled.py +91 -0
  561. tests/transforms/test_median_smooth.py +41 -0
  562. tests/transforms/test_median_smoothd.py +65 -0
  563. tests/transforms/test_morphological_ops.py +101 -0
  564. tests/transforms/test_nifti_endianness.py +107 -0
  565. tests/transforms/test_normalize_intensity.py +143 -0
  566. tests/transforms/test_normalize_intensityd.py +81 -0
  567. tests/transforms/test_nvtx_decorator.py +289 -0
  568. tests/transforms/test_nvtx_transform.py +143 -0
  569. tests/transforms/test_orientation.py +247 -0
  570. tests/transforms/test_orientationd.py +112 -0
  571. tests/transforms/test_rand_adjust_contrast.py +45 -0
  572. tests/transforms/test_rand_adjust_contrastd.py +44 -0
  573. tests/transforms/test_rand_affine.py +201 -0
  574. tests/transforms/test_rand_affine_grid.py +212 -0
  575. tests/transforms/test_rand_affined.py +281 -0
  576. tests/transforms/test_rand_axis_flip.py +50 -0
  577. tests/transforms/test_rand_axis_flipd.py +50 -0
  578. tests/transforms/test_rand_bias_field.py +69 -0
  579. tests/transforms/test_rand_bias_fieldd.py +65 -0
  580. tests/transforms/test_rand_coarse_dropout.py +110 -0
  581. tests/transforms/test_rand_coarse_dropoutd.py +107 -0
  582. tests/transforms/test_rand_coarse_shuffle.py +65 -0
  583. tests/transforms/test_rand_coarse_shuffled.py +59 -0
  584. tests/transforms/test_rand_crop_by_label_classes.py +170 -0
  585. tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
  586. tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
  587. tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
  588. tests/transforms/test_rand_cucim_dict_transform.py +162 -0
  589. tests/transforms/test_rand_cucim_transform.py +162 -0
  590. tests/transforms/test_rand_deform_grid.py +138 -0
  591. tests/transforms/test_rand_elastic_2d.py +127 -0
  592. tests/transforms/test_rand_elastic_3d.py +104 -0
  593. tests/transforms/test_rand_elasticd_2d.py +177 -0
  594. tests/transforms/test_rand_elasticd_3d.py +156 -0
  595. tests/transforms/test_rand_flip.py +60 -0
  596. tests/transforms/test_rand_flipd.py +55 -0
  597. tests/transforms/test_rand_gaussian_noise.py +48 -0
  598. tests/transforms/test_rand_gaussian_noised.py +54 -0
  599. tests/transforms/test_rand_gaussian_sharpen.py +140 -0
  600. tests/transforms/test_rand_gaussian_sharpend.py +143 -0
  601. tests/transforms/test_rand_gaussian_smooth.py +98 -0
  602. tests/transforms/test_rand_gaussian_smoothd.py +98 -0
  603. tests/transforms/test_rand_gibbs_noise.py +103 -0
  604. tests/transforms/test_rand_gibbs_noised.py +117 -0
  605. tests/transforms/test_rand_grid_distortion.py +99 -0
  606. tests/transforms/test_rand_grid_distortiond.py +90 -0
  607. tests/transforms/test_rand_histogram_shift.py +92 -0
  608. tests/transforms/test_rand_k_space_spike_noise.py +92 -0
  609. tests/transforms/test_rand_k_space_spike_noised.py +76 -0
  610. tests/transforms/test_rand_rician_noise.py +52 -0
  611. tests/transforms/test_rand_rician_noised.py +52 -0
  612. tests/transforms/test_rand_rotate.py +166 -0
  613. tests/transforms/test_rand_rotate90.py +100 -0
  614. tests/transforms/test_rand_rotate90d.py +112 -0
  615. tests/transforms/test_rand_rotated.py +187 -0
  616. tests/transforms/test_rand_scale_crop.py +78 -0
  617. tests/transforms/test_rand_scale_cropd.py +98 -0
  618. tests/transforms/test_rand_scale_intensity.py +54 -0
  619. tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
  620. tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
  621. tests/transforms/test_rand_scale_intensityd.py +53 -0
  622. tests/transforms/test_rand_shift_intensity.py +52 -0
  623. tests/transforms/test_rand_shift_intensityd.py +67 -0
  624. tests/transforms/test_rand_simulate_low_resolution.py +83 -0
  625. tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
  626. tests/transforms/test_rand_spatial_crop.py +107 -0
  627. tests/transforms/test_rand_spatial_crop_samples.py +128 -0
  628. tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
  629. tests/transforms/test_rand_spatial_cropd.py +112 -0
  630. tests/transforms/test_rand_std_shift_intensity.py +43 -0
  631. tests/transforms/test_rand_std_shift_intensityd.py +38 -0
  632. tests/transforms/test_rand_zoom.py +105 -0
  633. tests/transforms/test_rand_zoomd.py +108 -0
  634. tests/transforms/test_randidentity.py +49 -0
  635. tests/transforms/test_random_order.py +144 -0
  636. tests/transforms/test_randtorchvisiond.py +65 -0
  637. tests/transforms/test_regularization.py +139 -0
  638. tests/transforms/test_remove_repeated_channel.py +34 -0
  639. tests/transforms/test_remove_repeated_channeld.py +44 -0
  640. tests/transforms/test_repeat_channel.py +34 -0
  641. tests/transforms/test_repeat_channeld.py +41 -0
  642. tests/transforms/test_resample_backends.py +65 -0
  643. tests/transforms/test_resample_to_match.py +110 -0
  644. tests/transforms/test_resample_to_matchd.py +93 -0
  645. tests/transforms/test_resampler.py +165 -0
  646. tests/transforms/test_resize.py +140 -0
  647. tests/transforms/test_resize_with_pad_or_crop.py +91 -0
  648. tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
  649. tests/transforms/test_resized.py +163 -0
  650. tests/transforms/test_rotate.py +160 -0
  651. tests/transforms/test_rotate90.py +212 -0
  652. tests/transforms/test_rotate90d.py +106 -0
  653. tests/transforms/test_rotated.py +179 -0
  654. tests/transforms/test_save_classificationd.py +109 -0
  655. tests/transforms/test_save_image.py +80 -0
  656. tests/transforms/test_save_imaged.py +130 -0
  657. tests/transforms/test_savitzky_golay_smooth.py +73 -0
  658. tests/transforms/test_savitzky_golay_smoothd.py +73 -0
  659. tests/transforms/test_scale_intensity.py +76 -0
  660. tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
  661. tests/transforms/test_scale_intensity_range.py +41 -0
  662. tests/transforms/test_scale_intensity_ranged.py +40 -0
  663. tests/transforms/test_scale_intensityd.py +57 -0
  664. tests/transforms/test_select_itemsd.py +41 -0
  665. tests/transforms/test_shift_intensity.py +31 -0
  666. tests/transforms/test_shift_intensityd.py +44 -0
  667. tests/transforms/test_signal_continuouswavelet.py +44 -0
  668. tests/transforms/test_signal_fillempty.py +52 -0
  669. tests/transforms/test_signal_fillemptyd.py +60 -0
  670. tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
  671. tests/transforms/test_signal_rand_add_sine.py +52 -0
  672. tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
  673. tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
  674. tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
  675. tests/transforms/test_signal_rand_drop.py +50 -0
  676. tests/transforms/test_signal_rand_scale.py +52 -0
  677. tests/transforms/test_signal_rand_shift.py +55 -0
  678. tests/transforms/test_signal_remove_frequency.py +71 -0
  679. tests/transforms/test_smooth_field.py +177 -0
  680. tests/transforms/test_sobel_gradient.py +189 -0
  681. tests/transforms/test_sobel_gradientd.py +212 -0
  682. tests/transforms/test_spacing.py +381 -0
  683. tests/transforms/test_spacingd.py +178 -0
  684. tests/transforms/test_spatial_crop.py +82 -0
  685. tests/transforms/test_spatial_cropd.py +74 -0
  686. tests/transforms/test_spatial_pad.py +57 -0
  687. tests/transforms/test_spatial_padd.py +43 -0
  688. tests/transforms/test_spatial_resample.py +235 -0
  689. tests/transforms/test_squeezedim.py +62 -0
  690. tests/transforms/test_squeezedimd.py +98 -0
  691. tests/transforms/test_std_shift_intensity.py +76 -0
  692. tests/transforms/test_std_shift_intensityd.py +74 -0
  693. tests/transforms/test_threshold_intensity.py +38 -0
  694. tests/transforms/test_threshold_intensityd.py +58 -0
  695. tests/transforms/test_to_contiguous.py +47 -0
  696. tests/transforms/test_to_cupy.py +112 -0
  697. tests/transforms/test_to_cupyd.py +76 -0
  698. tests/transforms/test_to_device.py +42 -0
  699. tests/transforms/test_to_deviced.py +37 -0
  700. tests/transforms/test_to_numpy.py +85 -0
  701. tests/transforms/test_to_numpyd.py +68 -0
  702. tests/transforms/test_to_pil.py +52 -0
  703. tests/transforms/test_to_pild.py +50 -0
  704. tests/transforms/test_to_tensor.py +60 -0
  705. tests/transforms/test_to_tensord.py +71 -0
  706. tests/transforms/test_torchvision.py +66 -0
  707. tests/transforms/test_torchvisiond.py +63 -0
  708. tests/transforms/test_transform.py +62 -0
  709. tests/transforms/test_transpose.py +41 -0
  710. tests/transforms/test_transposed.py +52 -0
  711. tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
  712. tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
  713. tests/transforms/test_vote_ensemble.py +84 -0
  714. tests/transforms/test_vote_ensembled.py +107 -0
  715. tests/transforms/test_with_allow_missing_keys.py +76 -0
  716. tests/transforms/test_zoom.py +120 -0
  717. tests/transforms/test_zoomd.py +94 -0
  718. tests/transforms/transform/__init__.py +10 -0
  719. tests/transforms/transform/test_randomizable.py +52 -0
  720. tests/transforms/transform/test_randomizable_transform_type.py +37 -0
  721. tests/transforms/utility/__init__.py +10 -0
  722. tests/transforms/utility/test_apply_transform_to_points.py +81 -0
  723. tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
  724. tests/transforms/utility/test_identity.py +29 -0
  725. tests/transforms/utility/test_identityd.py +30 -0
  726. tests/transforms/utility/test_lambda.py +71 -0
  727. tests/transforms/utility/test_lambdad.py +83 -0
  728. tests/transforms/utility/test_rand_lambda.py +87 -0
  729. tests/transforms/utility/test_rand_lambdad.py +77 -0
  730. tests/transforms/utility/test_simulatedelay.py +36 -0
  731. tests/transforms/utility/test_simulatedelayd.py +36 -0
  732. tests/transforms/utility/test_splitdim.py +52 -0
  733. tests/transforms/utility/test_splitdimd.py +96 -0
  734. tests/transforms/utils/__init__.py +10 -0
  735. tests/transforms/utils/test_correct_crop_centers.py +36 -0
  736. tests/transforms/utils/test_get_unique_labels.py +45 -0
  737. tests/transforms/utils/test_print_transform_backends.py +29 -0
  738. tests/transforms/utils/test_soft_clip.py +125 -0
  739. tests/utils/__init__.py +10 -0
  740. tests/utils/enums/__init__.py +10 -0
  741. tests/utils/enums/test_hovernet_loss.py +190 -0
  742. tests/utils/enums/test_ordering.py +289 -0
  743. tests/utils/enums/test_wsireader.py +663 -0
  744. tests/utils/misc/__init__.py +10 -0
  745. tests/utils/misc/test_ensure_tuple.py +53 -0
  746. tests/utils/misc/test_monai_env_vars.py +44 -0
  747. tests/utils/misc/test_monai_utils_misc.py +103 -0
  748. tests/utils/misc/test_str2bool.py +34 -0
  749. tests/utils/misc/test_str2list.py +33 -0
  750. tests/utils/test_alias.py +44 -0
  751. tests/utils/test_component_store.py +73 -0
  752. tests/utils/test_deprecated.py +455 -0
  753. tests/utils/test_enum_bound_interp.py +75 -0
  754. tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
  755. tests/utils/test_get_package_version.py +34 -0
  756. tests/utils/test_handler_logfile.py +84 -0
  757. tests/utils/test_handler_metric_logger.py +62 -0
  758. tests/utils/test_list_to_dict.py +43 -0
  759. tests/utils/test_look_up_option.py +87 -0
  760. tests/utils/test_optional_import.py +80 -0
  761. tests/utils/test_pad_mode.py +39 -0
  762. tests/utils/test_profiling.py +208 -0
  763. tests/utils/test_rankfilter_dist.py +77 -0
  764. tests/utils/test_require_pkg.py +83 -0
  765. tests/utils/test_sample_slices.py +43 -0
  766. tests/utils/test_set_determinism.py +74 -0
  767. tests/utils/test_squeeze_unsqueeze.py +71 -0
  768. tests/utils/test_state_cacher.py +67 -0
  769. tests/utils/test_torchscript_utils.py +113 -0
  770. tests/utils/test_version.py +91 -0
  771. tests/utils/test_version_after.py +65 -0
  772. tests/utils/type_conversion/__init__.py +10 -0
  773. tests/utils/type_conversion/test_convert_data_type.py +152 -0
  774. tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
  775. tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
  776. tests/visualize/__init__.py +10 -0
  777. tests/visualize/test_img2tensorboard.py +46 -0
  778. tests/visualize/test_occlusion_sensitivity.py +128 -0
  779. tests/visualize/test_plot_2d_or_3d_image.py +74 -0
  780. tests/visualize/test_vis_cam.py +98 -0
  781. tests/visualize/test_vis_gradcam.py +211 -0
  782. tests/visualize/utils/__init__.py +10 -0
  783. tests/visualize/utils/test_blend_images.py +63 -0
  784. tests/visualize/utils/test_matshow3d.py +133 -0
  785. monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
  786. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/LICENSE +0 -0
  787. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/WHEEL +0 -0
@@ -0,0 +1,232 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+ from unittest import skipUnless
16
+
17
+ import numpy as np
18
+ import torch
19
+ from parameterized import parameterized
20
+
21
+ from monai.networks import eval_mode
22
+ from monai.networks.blocks.selfattention import SABlock
23
+ from monai.networks.layers.factories import RelPosEmbedding
24
+ from monai.utils import optional_import
25
+ from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, test_script_save
26
+
27
einops, has_einops = optional_import("einops")

# Cartesian product of SABlock configurations. Each case is
# [constructor kwargs, input shape, expected output shape]; the output shape equals the input shape.
TEST_CASE_SABLOCK = [
    [
        {
            "hidden_size": hidden_size,
            "num_heads": num_heads,
            "dropout_rate": dropout_rate,
            "rel_pos_embedding": rel_pos_embedding,
            "input_size": input_size,
            "include_fc": include_fc,
            "use_combined_linear": use_combined_linear,
            # flash attention is only enabled when no relative positional embedding is used,
            # since SABlock rejects that combination (see test_rel_pos_embedding_with_flash_attention)
            "use_flash_attention": rel_pos_embedding is None,
        },
        (2, 512, hidden_size),
        (2, 512, hidden_size),
    ]
    for dropout_rate in np.linspace(0, 1, 4)
    for hidden_size in [360, 480, 600, 768]
    for num_heads in [4, 6, 8, 12]
    for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]
    for input_size in [(16, 32), (8, 8, 8)]
    for include_fc in [True, False]
    for use_combined_linear in [True, False]
]
52
+
53
+
54
class TestResBlock(unittest.TestCase):
    """Tests for ``SABlock`` (self-attention block).

    Despite the class name, every test here exercises ``SABlock``.
    """

    @parameterized.expand(TEST_CASE_SABLOCK)
    @skipUnless(has_einops, "Requires einops")
    @SkipIfBeforePyTorchVersion((2, 0))
    def test_shape(self, input_param, input_shape, expected_shape):
        """The block output has the expected shape (equal to the input shape in every case)."""
        net = SABlock(**input_param)
        with eval_mode(net):
            result = net(torch.randn(input_shape))
            self.assertEqual(result.shape, expected_shape)

    def test_ill_arg(self):
        """Invalid constructor arguments raise ``ValueError``."""
        # dropout_rate far outside [0, 1] (128 is also not divisible by 12 heads)
        with self.assertRaises(ValueError):
            SABlock(hidden_size=128, num_heads=12, dropout_rate=6.0)

        # hidden_size 620 is not divisible by 8 heads
        with self.assertRaises(ValueError):
            SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)

    @skipUnless(has_einops, "Requires einops")
    @SkipIfBeforePyTorchVersion((2, 0))
    def test_rel_pos_embedding_with_flash_attention(self):
        """Flash attention combined with a relative positional embedding raises ``ValueError``."""
        # num_heads=4 divides hidden_size=128, so the head-divisibility check cannot be what raises;
        # with num_heads=3 this test would pass even if the flash/rel-pos check did not exist.
        with self.assertRaises(ValueError):
            SABlock(
                hidden_size=128,
                num_heads=4,
                dropout_rate=0.1,
                use_flash_attention=True,
                save_attn=False,
                rel_pos_embedding=RelPosEmbedding.DECOMPOSED,
            )

    @skipUnless(has_einops, "Requires einops")
    @SkipIfBeforePyTorchVersion((1, 13))
    def test_save_attn_with_flash_attention(self):
        """``save_attn=True`` together with flash attention raises ``ValueError``."""
        # num_heads=4 divides hidden_size=128, so only the flash/save_attn conflict is tested
        with self.assertRaises(ValueError):
            SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, use_flash_attention=True, save_attn=True)

    def test_attention_dim_not_multiple_of_heads(self):
        """A hidden size that is not divisible by the number of heads raises ``ValueError``."""
        with self.assertRaises(ValueError):
            SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1)

    @skipUnless(has_einops, "Requires einops")
    def test_inner_dim_different(self):
        """An explicit ``dim_head`` different from hidden_size // num_heads is accepted."""
        SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30)

    def test_causal_no_sequence_length(self):
        """Causal attention without ``sequence_length`` raises ``ValueError``."""
        with self.assertRaises(ValueError):
            SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True)

    @skipUnless(has_einops, "Requires einops")
    @SkipIfBeforePyTorchVersion((2, 0))
    def test_causal_flash_attention(self):
        """A causal block using flash attention completes a forward pass."""
        block = SABlock(
            hidden_size=128,
            num_heads=1,
            dropout_rate=0.1,
            causal=True,
            sequence_length=16,
            save_attn=False,
            use_flash_attention=True,
        )
        input_shape = (1, 16, 128)
        # Check it runs correctly
        block(torch.randn(input_shape))

    @skipUnless(has_einops, "Requires einops")
    def test_causal(self):
        """With causal attention the saved attention matrix is zero above the diagonal."""
        block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True)
        input_shape = (1, 16, 128)
        block(torch.randn(input_shape))
        # the strictly upper triangular part of the attention matrix must be zero
        self.assertEqual(torch.triu(block.att_mat, diagonal=1).sum().item(), 0)

    @skipUnless(has_einops, "Requires einops")
    def test_masked_selfattention(self):
        """Columns of the attention matrix where ``attn_mask`` is False are all zero."""
        n = 64
        block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, sequence_length=16, save_attn=True)
        input_shape = (1, n, 128)
        # generate a random boolean mask of shape (1, n)
        mask = torch.randint(0, 2, (1, n)).bool()
        block(torch.randn(input_shape), attn_mask=mask)
        att_mat = block.att_mat.squeeze()
        masked_columns = att_mat[:, ~mask.squeeze(0)]
        # ensure all masked columns are zeros
        self.assertTrue(torch.allclose(masked_columns, torch.zeros_like(masked_columns)))

    @skipUnless(has_einops, "Requires einops")
    def test_causal_and_mask(self):
        """Passing ``attn_mask`` to a causal block raises ``ValueError`` in the forward pass."""
        # construction is valid (see test_causal); only the forward call is expected to raise
        block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64)
        inputs = torch.randn(2, 64, 128)
        mask = torch.randint(0, 2, (2, 64)).bool()
        with self.assertRaises(ValueError):
            block(inputs, attn_mask=mask)

    @skipUnless(has_einops, "Requires einops")
    def test_access_attn_matrix(self):
        """``att_mat`` is empty unless ``save_attn=True``, in which case it holds the attention weights."""
        # input format
        hidden_size = 128
        num_heads = 2
        dropout_rate = 0
        input_shape = (2, 256, hidden_size)

        # without save_attn the attention matrix is not retained
        no_matrix_access_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate)
        no_matrix_access_blk(torch.randn(input_shape))
        self.assertIsInstance(no_matrix_access_blk.att_mat, torch.Tensor)
        # no of elements is zero
        self.assertEqual(no_matrix_access_blk.att_mat.nelement(), 0)

        # with save_attn=True the attention matrix is accessible after the forward pass
        matrix_access_blk = SABlock(
            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
        )
        matrix_access_blk(torch.randn(input_shape))
        # expected layout (batch, num_heads, seq, seq); NOTE: batch == num_heads == 2 here,
        # so this assertion cannot tell those two axes apart.
        self.assertEqual(
            matrix_access_blk.att_mat.shape, (input_shape[0], num_heads, input_shape[1], input_shape[1])
        )

    @skipUnless(has_einops, "Requires einops")
    def test_number_of_parameters(self):
        """``dim_head`` (not ``num_heads`` alone) determines the number of trainable parameters."""

        def count_sablock_params(*args, **kwargs):
            """Count the number of trainable parameters in a SABlock."""
            sablock = SABlock(*args, **kwargs)
            return sum(x.numel() for x in sablock.parameters() if x.requires_grad)

        hidden_size = 128
        num_heads = 8
        default_dim_head = hidden_size // num_heads

        # Default dim_head is hidden_size // num_heads
        nparams_default = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads)
        nparams_like_default = count_sablock_params(
            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head
        )
        self.assertEqual(nparams_default, nparams_like_default)

        # Increasing dim_head should increase the number of parameters
        nparams_custom_large = count_sablock_params(
            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2
        )
        self.assertGreater(nparams_custom_large, nparams_default)

        # Decreasing dim_head should decrease the number of parameters
        nparams_custom_small = count_sablock_params(
            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2
        )
        self.assertGreater(nparams_default, nparams_custom_small)

        # Increasing the number of heads with the default behaviour should not change the number of params.
        nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2)
        self.assertEqual(nparams_default, nparams_default_more_heads)

    @parameterized.expand([[True, False], [True, True], [False, True], [False, False]])
    @skipUnless(has_einops, "Requires einops")
    @SkipIfBeforePyTorchVersion((2, 0))
    def test_script(self, include_fc, use_combined_linear):
        """The block can be scripted and saved for every include_fc/use_combined_linear combination."""
        input_param = {
            "hidden_size": 360,
            "num_heads": 4,
            "dropout_rate": 0.0,
            "rel_pos_embedding": None,
            "input_size": (16, 32),
            "include_fc": include_fc,
            "use_combined_linear": use_combined_linear,
        }
        net = SABlock(**input_param)
        input_shape = (2, 512, 360)
        test_data = torch.randn(input_shape)
        test_script_save(net, test_data)

    @skipUnless(has_einops, "Requires einops")
    @SkipIfBeforePyTorchVersion((2, 0))
    def test_flash_attention(self):
        """Blocks with and without flash attention, sharing weights, give matching outputs."""
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        for causal in [True, False]:
            input_param = {"hidden_size": 360, "num_heads": 4, "input_size": (16, 32), "causal": causal}
            block_w_flash_attention = SABlock(**input_param, use_flash_attention=True).to(device)
            block_wo_flash_attention = SABlock(**input_param, use_flash_attention=False).to(device)
            # copy weights so both blocks compute the same function
            block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict())
            test_data = torch.randn(2, 512, 360).to(device)

            out_1 = block_w_flash_attention(test_data)
            out_2 = block_wo_flash_attention(test_data)
            assert_allclose(out_1, out_2, atol=1e-4)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,87 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import torch
17
+ from parameterized import parameterized
18
+
19
+ from monai.networks import eval_mode
20
+ from monai.networks.blocks import SimpleASPP
21
+
22
# Valid SimpleASPP configurations: [constructor kwargs, input shape, expected output shape].
# Judging by these cases, output channels = conv_out_channels x number of branches
# (4 by default, or len(kernel_sizes) when given).
TEST_CASES = [
    # 2D input, 32 channels, batch of 7, non-affine batch norm
    [
        dict(spatial_dims=2, in_channels=32, conv_out_channels=3, norm_type=("batch", {"affine": False})),
        (7, 32, 18, 20),
        (7, 12, 18, 20),
    ],
    # 1D input, 4 channels, batch of 16, PReLU with one parameter per output channel
    [
        dict(spatial_dims=1, in_channels=4, conv_out_channels=8, acti_type=("PRELU", {"num_parameters": 32})),
        (16, 4, 17),
        (16, 32, 17),
    ],
    # 3D input, 3 channels, batch of 16, default kernel sizes / dilations
    [
        dict(spatial_dims=3, in_channels=3, conv_out_channels=2),
        (16, 3, 17, 18, 19),
        (16, 8, 17, 18, 19),
    ],
    # 3D input, 3 channels, batch of 16, three explicit kernel-size/dilation pairs
    [
        dict(spatial_dims=3, in_channels=3, conv_out_channels=2, kernel_sizes=(1, 3, 3), dilations=(1, 2, 4)),
        (16, 3, 17, 18, 19),
        (16, 6, 17, 18, 19),
    ],
]

# Invalid configurations: [constructor kwargs, input shape (unused), expected exception type].
TEST_ILL_CASES = [
    # kernel_sizes and dilations have different lengths
    [
        dict(spatial_dims=3, in_channels=3, conv_out_channels=2, kernel_sizes=(1, 3, 3), dilations=(1, 2)),
        (16, 3, 17, 18, 19),
        ValueError,
    ],
    # lengths match, but no padding is known for kernel size 4 with dilation 3
    [
        dict(spatial_dims=3, in_channels=3, conv_out_channels=2, kernel_sizes=(1, 3, 4), dilations=(1, 2, 3)),
        (16, 3, 17, 18, 19),
        NotImplementedError,
    ],
]
69
+
70
+
71
class TestChannelSELayer(unittest.TestCase):
    """Shape and argument-validation tests for ``SimpleASPP``.

    Despite the class name, these tests exercise ``SimpleASPP``.
    """

    @parameterized.expand(TEST_CASES)
    def test_shape(self, input_param, input_shape, expected_shape):
        aspp = SimpleASPP(**input_param)
        sample = torch.randn(input_shape)
        with eval_mode(aspp):
            output = aspp(sample)
            self.assertEqual(output.shape, expected_shape)

    @parameterized.expand(TEST_ILL_CASES)
    def test_ill_args(self, input_param, input_shape, error_type):
        # input_shape is part of each case tuple, but construction alone is expected to fail
        with self.assertRaises(error_type):
            SimpleASPP(**input_param)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,55 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+ from unittest import skipUnless
16
+
17
+ import torch
18
+ from parameterized import parameterized
19
+
20
+ from monai.networks import eval_mode
21
+ from monai.networks.blocks.spatialattention import SpatialAttentionBlock
22
+ from monai.utils import optional_import
23
+
24
einops, has_einops = optional_import("einops")

# [constructor kwargs, input shape, expected output shape]; the block preserves the input shape.
TEST_CASES = [
    [
        dict(spatial_dims=2, num_channels=128, num_head_channels=32, norm_num_groups=32, norm_eps=1e-6),
        (1, 128, 32, 32),
        (1, 128, 32, 32),
    ],
    [
        dict(spatial_dims=3, num_channels=16, num_head_channels=8, norm_num_groups=8, norm_eps=1e-6),
        (1, 16, 8, 8, 8),
        (1, 16, 8, 8, 8),
    ],
]
38
+
39
+
40
class TestBlock(unittest.TestCase):
    """Tests for ``SpatialAttentionBlock``."""

    @parameterized.expand(TEST_CASES)
    @skipUnless(has_einops, "Requires einops")
    def test_shape(self, input_param, input_shape, expected_shape):
        block = SpatialAttentionBlock(**input_param)
        inputs = torch.randn(input_shape)
        with eval_mode(block):
            outputs = block(inputs)
            self.assertEqual(outputs.shape, expected_shape)

    def test_attention_dim_not_multiple_of_heads(self):
        # 128 channels are not a multiple of 33 channels per head
        with self.assertRaises(ValueError):
            SpatialAttentionBlock(spatial_dims=2, num_channels=128, num_head_channels=33)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,87 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from parameterized import parameterized
19
+
20
+ from monai.networks import eval_mode
21
+ from monai.networks.blocks import SubpixelUpsample
22
+ from monai.networks.layers.factories import Conv
23
+ from tests.test_utils import SkipIfBeforePyTorchVersion, test_script_save
24
+
25
# Every combination of 1-4 input channels, 1-3 spatial dims and scale factor 1-2.
# Each case is [constructor kwargs, input shape, expected output shape]: an input with
# 8 elements per spatial dim is expected to grow to 8 * factor per spatial dim.
TEST_CASE_SUBPIXEL = [
    [
        {"spatial_dims": dim, "in_channels": inch, "scale_factor": factor},
        (2, inch) + (8,) * dim,
        (2, inch) + (8 * factor,) * dim,
    ]
    for inch in range(1, 5)
    for dim in range(1, 4)
    for factor in range(1, 3)
]
35
+
36
# Extra cases whose spatial sizes differ per axis.
TEST_CASE_SUBPIXEL_2D_EXTRA = [
    {"spatial_dims": 2, "in_channels": 2, "scale_factor": 3},
    (2, 2, 8, 4),  # H and W differ
    (2, 2, 24, 12),
]

TEST_CASE_SUBPIXEL_3D_EXTRA = [
    {"spatial_dims": 3, "in_channels": 1, "scale_factor": 2},
    (2, 1, 16, 8, 4),  # H, W and D all differ
    (2, 1, 32, 16, 8),
]

# User-supplied convolution block (1 -> 4 -> 8 channels) handed to SubpixelUpsample via ``conv_block``.
conv_block = nn.Sequential(
    Conv[Conv.CONV, 3](1, 4, kernel_size=1),
    Conv[Conv.CONV, 3](4, 8, kernel_size=3, stride=1, padding=1),
)

TEST_CASE_SUBPIXEL_CONV_BLOCK_EXTRA = [
    {"spatial_dims": 3, "in_channels": 1, "scale_factor": 2, "conv_block": conv_block},
    (2, 1, 16, 8, 4),  # H, W and D all differ
    (2, 1, 32, 16, 8),
]

TEST_CASE_SUBPIXEL.extend(
    [  # type: ignore
        TEST_CASE_SUBPIXEL_2D_EXTRA,
        TEST_CASE_SUBPIXEL_3D_EXTRA,
        TEST_CASE_SUBPIXEL_CONV_BLOCK_EXTRA,
    ]
)

# Repeat every case collected so far with apply_pad_pool=False (pad/pool component omitted).
# The list comprehension is fully built before ``extend`` runs, so only pre-existing cases are copied.
TEST_CASE_SUBPIXEL.extend(
    [[{**case[0], "apply_pad_pool": False}, case[1], case[2]] for case in TEST_CASE_SUBPIXEL]  # type: ignore
)
68
+
69
+
70
class TestSUBPIXEL(unittest.TestCase):
    """Output-shape and TorchScript tests for ``SubpixelUpsample``."""

    @parameterized.expand(TEST_CASE_SUBPIXEL)
    def test_subpixel_shape(self, input_param, input_shape, expected_shape):
        upsample = SubpixelUpsample(**input_param)
        with eval_mode(upsample):
            output = upsample.forward(torch.randn(input_shape))
            self.assertEqual(output.shape, expected_shape)

    @SkipIfBeforePyTorchVersion((1, 8, 1))
    def test_script(self):
        # the first generated case is enough to check that the module can be scripted and saved
        input_param, input_shape, _expected = TEST_CASE_SUBPIXEL[0]
        upsample = SubpixelUpsample(**input_param)
        test_script_save(upsample, torch.randn(input_shape))


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,49 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ from monai.networks.blocks.text_embedding import TextEncoder
17
+ from tests.test_utils import skip_if_downloading_fails
18
+
19
+
20
+ class TestTextEncoder(unittest.TestCase):
21
+ def test_test_encoding_shape(self):
22
+ with skip_if_downloading_fails():
23
+ # test 2D encoder
24
+ text_encoder = TextEncoder(
25
+ spatial_dims=2, out_channels=32, encoding="clip_encoding_universal_model_32", pretrained=True
26
+ )
27
+ text_encoding = text_encoder()
28
+ self.assertEqual(text_encoding.shape, (32, 256, 1, 1))
29
+
30
+ # test 3D encoder
31
+ text_encoder = TextEncoder(
32
+ spatial_dims=3, out_channels=32, encoding="clip_encoding_universal_model_32", pretrained=True
33
+ )
34
+ text_encoding = text_encoder()
35
+ self.assertEqual(text_encoding.shape, (32, 256, 1, 1, 1))
36
+
37
+ # test random enbedding 3D
38
+ text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="rand_embedding", pretrained=True)
39
+ text_encoding = text_encoder()
40
+ self.assertEqual(text_encoding.shape, (32, 256, 1, 1, 1))
41
+
42
+ # test random enbedding 2D
43
+ text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="rand_embedding", pretrained=True)
44
+ text_encoding = text_encoder()
45
+ self.assertEqual(text_encoding.shape, (32, 256, 1, 1))
46
+
47
+
48
+ if __name__ == "__main__":
49
+ unittest.main()
@@ -0,0 +1,90 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+ from unittest import skipUnless
16
+
17
+ import numpy as np
18
+ import torch
19
+ from parameterized import parameterized
20
+
21
+ from monai.networks import eval_mode
22
+ from monai.networks.blocks.transformerblock import TransformerBlock
23
+ from monai.utils import optional_import
24
+
25
+ einops, has_einops = optional_import("einops")
26
+ TEST_CASE_TRANSFORMERBLOCK = []
27
+ for dropout_rate in np.linspace(0, 1, 4):
28
+ for hidden_size in [360, 480, 600, 768]:
29
+ for num_heads in [4, 8, 12]:
30
+ for mlp_dim in [1024, 3072]:
31
+ for cross_attention in [False, True]:
32
+ test_case = [
33
+ {
34
+ "hidden_size": hidden_size,
35
+ "num_heads": num_heads,
36
+ "mlp_dim": mlp_dim,
37
+ "dropout_rate": dropout_rate,
38
+ "with_cross_attention": cross_attention,
39
+ },
40
+ (2, 512, hidden_size),
41
+ (2, 512, hidden_size),
42
+ ]
43
+ TEST_CASE_TRANSFORMERBLOCK.append(test_case)
44
+
45
+
46
+ class TestTransformerBlock(unittest.TestCase):
47
+
48
+ @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK)
49
+ @skipUnless(has_einops, "Requires einops")
50
+ def test_shape(self, input_param, input_shape, expected_shape):
51
+ net = TransformerBlock(**input_param)
52
+ with eval_mode(net):
53
+ result = net(torch.randn(input_shape))
54
+ self.assertEqual(result.shape, expected_shape)
55
+
56
+ def test_ill_arg(self):
57
+ with self.assertRaises(ValueError):
58
+ TransformerBlock(hidden_size=128, num_heads=12, mlp_dim=2048, dropout_rate=4.0)
59
+
60
+ with self.assertRaises(ValueError):
61
+ TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4)
62
+
63
+ @skipUnless(has_einops, "Requires einops")
64
+ def test_access_attn_matrix(self):
65
+ # input format
66
+ hidden_size = 128
67
+ mlp_dim = 12
68
+ num_heads = 2
69
+ dropout_rate = 0
70
+ input_shape = (2, 256, hidden_size)
71
+
72
+ # returns an empty attention matrix
73
+ no_matrix_acess_blk = TransformerBlock(
74
+ hidden_size=hidden_size, mlp_dim=mlp_dim, num_heads=num_heads, dropout_rate=dropout_rate
75
+ )
76
+ no_matrix_acess_blk(torch.randn(input_shape))
77
+ assert isinstance(no_matrix_acess_blk.attn.att_mat, torch.Tensor)
78
+ # no of elements is zero
79
+ assert no_matrix_acess_blk.attn.att_mat.nelement() == 0
80
+
81
+ # be able to acess the attention matrix
82
+ matrix_acess_blk = TransformerBlock(
83
+ hidden_size=hidden_size, mlp_dim=mlp_dim, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
84
+ )
85
+ matrix_acess_blk(torch.randn(input_shape))
86
+ assert matrix_acess_blk.attn.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1])
87
+
88
+
89
+ if __name__ == "__main__":
90
+ unittest.main()