monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2508__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (787)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/transforms.py +1 -4
  4. monai/data/utils.py +6 -13
  5. monai/handlers/__init__.py +1 -0
  6. monai/handlers/average_precision.py +53 -0
  7. monai/inferers/inferer.py +10 -7
  8. monai/inferers/utils.py +1 -2
  9. monai/losses/dice.py +2 -14
  10. monai/losses/ds_loss.py +1 -3
  11. monai/metrics/__init__.py +1 -0
  12. monai/metrics/average_precision.py +187 -0
  13. monai/networks/layers/simplelayers.py +2 -14
  14. monai/networks/utils.py +4 -16
  15. monai/transforms/compose.py +28 -11
  16. monai/transforms/croppad/array.py +1 -6
  17. monai/transforms/io/array.py +0 -1
  18. monai/transforms/transform.py +15 -6
  19. monai/transforms/utility/array.py +2 -12
  20. monai/transforms/utils.py +1 -2
  21. monai/transforms/utils_pytorch_numpy_unification.py +2 -4
  22. monai/utils/enums.py +3 -2
  23. monai/utils/module.py +6 -6
  24. monai/utils/tf32.py +0 -10
  25. monai/visualize/class_activation_maps.py +5 -8
  26. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/METADATA +21 -17
  27. monai_weekly-1.5.dev2508.dist-info/RECORD +1185 -0
  28. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/top_level.txt +1 -0
  29. tests/apps/__init__.py +10 -0
  30. tests/apps/deepedit/__init__.py +10 -0
  31. tests/apps/deepedit/test_deepedit_transforms.py +314 -0
  32. tests/apps/deepgrow/__init__.py +10 -0
  33. tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
  34. tests/apps/deepgrow/transforms/__init__.py +10 -0
  35. tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
  36. tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
  37. tests/apps/detection/__init__.py +10 -0
  38. tests/apps/detection/metrics/__init__.py +10 -0
  39. tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
  40. tests/apps/detection/networks/__init__.py +10 -0
  41. tests/apps/detection/networks/test_retinanet.py +210 -0
  42. tests/apps/detection/networks/test_retinanet_detector.py +203 -0
  43. tests/apps/detection/test_box_transform.py +370 -0
  44. tests/apps/detection/utils/__init__.py +10 -0
  45. tests/apps/detection/utils/test_anchor_box.py +88 -0
  46. tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
  47. tests/apps/detection/utils/test_box_coder.py +43 -0
  48. tests/apps/detection/utils/test_detector_boxselector.py +67 -0
  49. tests/apps/detection/utils/test_detector_utils.py +96 -0
  50. tests/apps/detection/utils/test_hardnegsampler.py +54 -0
  51. tests/apps/nuclick/__init__.py +10 -0
  52. tests/apps/nuclick/test_nuclick_transforms.py +259 -0
  53. tests/apps/pathology/__init__.py +10 -0
  54. tests/apps/pathology/handlers/__init__.py +10 -0
  55. tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
  56. tests/apps/pathology/test_lesion_froc.py +333 -0
  57. tests/apps/pathology/test_pathology_prob_nms.py +55 -0
  58. tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
  59. tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
  60. tests/apps/pathology/transforms/__init__.py +10 -0
  61. tests/apps/pathology/transforms/post/__init__.py +10 -0
  62. tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
  63. tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
  64. tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
  65. tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
  66. tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
  67. tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
  68. tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
  69. tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
  70. tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
  71. tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
  72. tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
  73. tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
  74. tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
  75. tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
  76. tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
  77. tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
  78. tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
  79. tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
  80. tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
  81. tests/apps/pathology/transforms/post/test_watershed.py +60 -0
  82. tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
  83. tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
  84. tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
  85. tests/apps/reconstruction/__init__.py +10 -0
  86. tests/apps/reconstruction/nets/__init__.py +10 -0
  87. tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
  88. tests/apps/reconstruction/test_complex_utils.py +77 -0
  89. tests/apps/reconstruction/test_fastmri_reader.py +82 -0
  90. tests/apps/reconstruction/test_mri_utils.py +37 -0
  91. tests/apps/reconstruction/transforms/__init__.py +10 -0
  92. tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
  93. tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
  94. tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
  95. tests/apps/test_auto3dseg_bundlegen.py +156 -0
  96. tests/apps/test_check_hash.py +53 -0
  97. tests/apps/test_cross_validation.py +74 -0
  98. tests/apps/test_decathlondataset.py +93 -0
  99. tests/apps/test_download_and_extract.py +70 -0
  100. tests/apps/test_download_url_yandex.py +45 -0
  101. tests/apps/test_mednistdataset.py +72 -0
  102. tests/apps/test_mmar_download.py +154 -0
  103. tests/apps/test_tciadataset.py +123 -0
  104. tests/apps/vista3d/__init__.py +10 -0
  105. tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
  106. tests/apps/vista3d/test_vista3d_sampler.py +100 -0
  107. tests/apps/vista3d/test_vista3d_transforms.py +94 -0
  108. tests/bundle/__init__.py +10 -0
  109. tests/bundle/test_bundle_ckpt_export.py +107 -0
  110. tests/bundle/test_bundle_download.py +435 -0
  111. tests/bundle/test_bundle_get_data.py +94 -0
  112. tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
  113. tests/bundle/test_bundle_trt_export.py +147 -0
  114. tests/bundle/test_bundle_utils.py +149 -0
  115. tests/bundle/test_bundle_verify_metadata.py +66 -0
  116. tests/bundle/test_bundle_verify_net.py +76 -0
  117. tests/bundle/test_bundle_workflow.py +272 -0
  118. tests/bundle/test_component_locator.py +38 -0
  119. tests/bundle/test_config_item.py +138 -0
  120. tests/bundle/test_config_parser.py +392 -0
  121. tests/bundle/test_reference_resolver.py +114 -0
  122. tests/config/__init__.py +10 -0
  123. tests/config/test_cv2_dist.py +53 -0
  124. tests/engines/__init__.py +10 -0
  125. tests/engines/test_ensemble_evaluator.py +94 -0
  126. tests/engines/test_prepare_batch_default.py +76 -0
  127. tests/engines/test_prepare_batch_default_dist.py +76 -0
  128. tests/engines/test_prepare_batch_diffusion.py +104 -0
  129. tests/engines/test_prepare_batch_extra_input.py +80 -0
  130. tests/fl/__init__.py +10 -0
  131. tests/fl/monai_algo/__init__.py +10 -0
  132. tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
  133. tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
  134. tests/fl/test_fl_monai_algo_stats.py +81 -0
  135. tests/fl/utils/__init__.py +10 -0
  136. tests/fl/utils/test_fl_exchange_object.py +63 -0
  137. tests/handlers/__init__.py +10 -0
  138. tests/handlers/test_handler_average_precision.py +79 -0
  139. tests/handlers/test_handler_checkpoint_loader.py +182 -0
  140. tests/handlers/test_handler_checkpoint_saver.py +233 -0
  141. tests/handlers/test_handler_classification_saver.py +64 -0
  142. tests/handlers/test_handler_classification_saver_dist.py +77 -0
  143. tests/handlers/test_handler_clearml_image.py +65 -0
  144. tests/handlers/test_handler_clearml_stats.py +65 -0
  145. tests/handlers/test_handler_confusion_matrix.py +104 -0
  146. tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
  147. tests/handlers/test_handler_decollate_batch.py +66 -0
  148. tests/handlers/test_handler_early_stop.py +68 -0
  149. tests/handlers/test_handler_garbage_collector.py +73 -0
  150. tests/handlers/test_handler_hausdorff_distance.py +111 -0
  151. tests/handlers/test_handler_ignite_metric.py +191 -0
  152. tests/handlers/test_handler_lr_scheduler.py +94 -0
  153. tests/handlers/test_handler_mean_dice.py +98 -0
  154. tests/handlers/test_handler_mean_iou.py +76 -0
  155. tests/handlers/test_handler_metrics_reloaded.py +149 -0
  156. tests/handlers/test_handler_metrics_saver.py +89 -0
  157. tests/handlers/test_handler_metrics_saver_dist.py +120 -0
  158. tests/handlers/test_handler_mlflow.py +296 -0
  159. tests/handlers/test_handler_nvtx.py +93 -0
  160. tests/handlers/test_handler_panoptic_quality.py +89 -0
  161. tests/handlers/test_handler_parameter_scheduler.py +136 -0
  162. tests/handlers/test_handler_post_processing.py +74 -0
  163. tests/handlers/test_handler_prob_map_producer.py +111 -0
  164. tests/handlers/test_handler_regression_metrics.py +160 -0
  165. tests/handlers/test_handler_regression_metrics_dist.py +245 -0
  166. tests/handlers/test_handler_rocauc.py +48 -0
  167. tests/handlers/test_handler_rocauc_dist.py +54 -0
  168. tests/handlers/test_handler_stats.py +281 -0
  169. tests/handlers/test_handler_surface_distance.py +113 -0
  170. tests/handlers/test_handler_tb_image.py +61 -0
  171. tests/handlers/test_handler_tb_stats.py +166 -0
  172. tests/handlers/test_handler_validation.py +59 -0
  173. tests/handlers/test_trt_compile.py +145 -0
  174. tests/handlers/test_write_metrics_reports.py +68 -0
  175. tests/inferers/__init__.py +10 -0
  176. tests/inferers/test_avg_merger.py +179 -0
  177. tests/inferers/test_controlnet_inferers.py +1388 -0
  178. tests/inferers/test_diffusion_inferer.py +236 -0
  179. tests/inferers/test_latent_diffusion_inferer.py +884 -0
  180. tests/inferers/test_patch_inferer.py +309 -0
  181. tests/inferers/test_saliency_inferer.py +55 -0
  182. tests/inferers/test_slice_inferer.py +57 -0
  183. tests/inferers/test_sliding_window_inference.py +377 -0
  184. tests/inferers/test_sliding_window_splitter.py +284 -0
  185. tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
  186. tests/inferers/test_zarr_avg_merger.py +326 -0
  187. tests/integration/__init__.py +10 -0
  188. tests/integration/test_auto3dseg_ensemble.py +211 -0
  189. tests/integration/test_auto3dseg_hpo.py +189 -0
  190. tests/integration/test_deepedit_interaction.py +122 -0
  191. tests/integration/test_downsample_block.py +50 -0
  192. tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
  193. tests/integration/test_integration_autorunner.py +201 -0
  194. tests/integration/test_integration_bundle_run.py +240 -0
  195. tests/integration/test_integration_classification_2d.py +282 -0
  196. tests/integration/test_integration_determinism.py +95 -0
  197. tests/integration/test_integration_fast_train.py +231 -0
  198. tests/integration/test_integration_gpu_customization.py +159 -0
  199. tests/integration/test_integration_lazy_samples.py +219 -0
  200. tests/integration/test_integration_nnunetv2_runner.py +96 -0
  201. tests/integration/test_integration_segmentation_3d.py +304 -0
  202. tests/integration/test_integration_sliding_window.py +100 -0
  203. tests/integration/test_integration_stn.py +133 -0
  204. tests/integration/test_integration_unet_2d.py +67 -0
  205. tests/integration/test_integration_workers.py +61 -0
  206. tests/integration/test_integration_workflows.py +365 -0
  207. tests/integration/test_integration_workflows_adversarial.py +173 -0
  208. tests/integration/test_integration_workflows_gan.py +158 -0
  209. tests/integration/test_loader_semaphore.py +48 -0
  210. tests/integration/test_mapping_filed.py +122 -0
  211. tests/integration/test_meta_affine.py +183 -0
  212. tests/integration/test_metatensor_integration.py +114 -0
  213. tests/integration/test_module_list.py +76 -0
  214. tests/integration/test_one_of.py +283 -0
  215. tests/integration/test_pad_collation.py +124 -0
  216. tests/integration/test_reg_loss_integration.py +107 -0
  217. tests/integration/test_retinanet_predict_utils.py +154 -0
  218. tests/integration/test_seg_loss_integration.py +159 -0
  219. tests/integration/test_spatial_combine_transforms.py +185 -0
  220. tests/integration/test_testtimeaugmentation.py +186 -0
  221. tests/integration/test_vis_gradbased.py +69 -0
  222. tests/integration/test_vista3d_utils.py +159 -0
  223. tests/losses/__init__.py +10 -0
  224. tests/losses/deform/__init__.py +10 -0
  225. tests/losses/deform/test_bending_energy.py +88 -0
  226. tests/losses/deform/test_diffusion_loss.py +117 -0
  227. tests/losses/image_dissimilarity/__init__.py +10 -0
  228. tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
  229. tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
  230. tests/losses/test_adversarial_loss.py +94 -0
  231. tests/losses/test_barlow_twins_loss.py +109 -0
  232. tests/losses/test_cldice_loss.py +51 -0
  233. tests/losses/test_contrastive_loss.py +86 -0
  234. tests/losses/test_dice_ce_loss.py +123 -0
  235. tests/losses/test_dice_focal_loss.py +124 -0
  236. tests/losses/test_dice_loss.py +227 -0
  237. tests/losses/test_ds_loss.py +189 -0
  238. tests/losses/test_focal_loss.py +379 -0
  239. tests/losses/test_generalized_dice_focal_loss.py +85 -0
  240. tests/losses/test_generalized_dice_loss.py +221 -0
  241. tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
  242. tests/losses/test_giou_loss.py +62 -0
  243. tests/losses/test_hausdorff_loss.py +264 -0
  244. tests/losses/test_masked_dice_loss.py +152 -0
  245. tests/losses/test_masked_loss.py +87 -0
  246. tests/losses/test_multi_scale.py +86 -0
  247. tests/losses/test_nacl_loss.py +167 -0
  248. tests/losses/test_perceptual_loss.py +122 -0
  249. tests/losses/test_spectral_loss.py +86 -0
  250. tests/losses/test_ssim_loss.py +59 -0
  251. tests/losses/test_sure_loss.py +72 -0
  252. tests/losses/test_tversky_loss.py +198 -0
  253. tests/losses/test_unified_focal_loss.py +66 -0
  254. tests/metrics/__init__.py +10 -0
  255. tests/metrics/test_compute_average_precision.py +162 -0
  256. tests/metrics/test_compute_confusion_matrix.py +294 -0
  257. tests/metrics/test_compute_f_beta.py +80 -0
  258. tests/metrics/test_compute_fid_metric.py +40 -0
  259. tests/metrics/test_compute_froc.py +143 -0
  260. tests/metrics/test_compute_generalized_dice.py +240 -0
  261. tests/metrics/test_compute_meandice.py +306 -0
  262. tests/metrics/test_compute_meaniou.py +223 -0
  263. tests/metrics/test_compute_mmd_metric.py +56 -0
  264. tests/metrics/test_compute_multiscalessim_metric.py +83 -0
  265. tests/metrics/test_compute_panoptic_quality.py +113 -0
  266. tests/metrics/test_compute_regression_metrics.py +196 -0
  267. tests/metrics/test_compute_roc_auc.py +155 -0
  268. tests/metrics/test_compute_variance.py +147 -0
  269. tests/metrics/test_cumulative.py +63 -0
  270. tests/metrics/test_cumulative_average.py +74 -0
  271. tests/metrics/test_cumulative_average_dist.py +48 -0
  272. tests/metrics/test_hausdorff_distance.py +209 -0
  273. tests/metrics/test_label_quality_score.py +134 -0
  274. tests/metrics/test_loss_metric.py +57 -0
  275. tests/metrics/test_metrics_reloaded.py +96 -0
  276. tests/metrics/test_ssim_metric.py +78 -0
  277. tests/metrics/test_surface_dice.py +416 -0
  278. tests/metrics/test_surface_distance.py +186 -0
  279. tests/networks/__init__.py +10 -0
  280. tests/networks/blocks/__init__.py +10 -0
  281. tests/networks/blocks/dints_block/__init__.py +10 -0
  282. tests/networks/blocks/dints_block/test_acn_block.py +41 -0
  283. tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
  284. tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
  285. tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
  286. tests/networks/blocks/test_adn.py +86 -0
  287. tests/networks/blocks/test_convolutions.py +156 -0
  288. tests/networks/blocks/test_crf_cpu.py +513 -0
  289. tests/networks/blocks/test_crf_cuda.py +528 -0
  290. tests/networks/blocks/test_crossattention.py +185 -0
  291. tests/networks/blocks/test_denseblock.py +105 -0
  292. tests/networks/blocks/test_dynunet_block.py +116 -0
  293. tests/networks/blocks/test_fpn_block.py +88 -0
  294. tests/networks/blocks/test_localnet_block.py +121 -0
  295. tests/networks/blocks/test_mlp.py +78 -0
  296. tests/networks/blocks/test_patchembedding.py +212 -0
  297. tests/networks/blocks/test_regunet_block.py +103 -0
  298. tests/networks/blocks/test_se_block.py +85 -0
  299. tests/networks/blocks/test_se_blocks.py +78 -0
  300. tests/networks/blocks/test_segresnet_block.py +57 -0
  301. tests/networks/blocks/test_selfattention.py +232 -0
  302. tests/networks/blocks/test_simple_aspp.py +87 -0
  303. tests/networks/blocks/test_spatialattention.py +55 -0
  304. tests/networks/blocks/test_subpixel_upsample.py +87 -0
  305. tests/networks/blocks/test_text_encoding.py +49 -0
  306. tests/networks/blocks/test_transformerblock.py +90 -0
  307. tests/networks/blocks/test_unetr_block.py +158 -0
  308. tests/networks/blocks/test_upsample_block.py +134 -0
  309. tests/networks/blocks/warp/__init__.py +10 -0
  310. tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
  311. tests/networks/blocks/warp/test_warp.py +250 -0
  312. tests/networks/layers/__init__.py +10 -0
  313. tests/networks/layers/filtering/__init__.py +10 -0
  314. tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
  315. tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
  316. tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
  317. tests/networks/layers/filtering/test_phl_cpu.py +259 -0
  318. tests/networks/layers/filtering/test_phl_cuda.py +167 -0
  319. tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
  320. tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
  321. tests/networks/layers/test_affine_transform.py +385 -0
  322. tests/networks/layers/test_apply_filter.py +89 -0
  323. tests/networks/layers/test_channel_pad.py +51 -0
  324. tests/networks/layers/test_conjugate_gradient.py +56 -0
  325. tests/networks/layers/test_drop_path.py +46 -0
  326. tests/networks/layers/test_gaussian.py +317 -0
  327. tests/networks/layers/test_gaussian_filter.py +206 -0
  328. tests/networks/layers/test_get_layers.py +65 -0
  329. tests/networks/layers/test_gmm.py +314 -0
  330. tests/networks/layers/test_grid_pull.py +93 -0
  331. tests/networks/layers/test_hilbert_transform.py +131 -0
  332. tests/networks/layers/test_lltm.py +62 -0
  333. tests/networks/layers/test_median_filter.py +52 -0
  334. tests/networks/layers/test_polyval.py +55 -0
  335. tests/networks/layers/test_preset_filters.py +136 -0
  336. tests/networks/layers/test_savitzky_golay_filter.py +141 -0
  337. tests/networks/layers/test_separable_filter.py +87 -0
  338. tests/networks/layers/test_skip_connection.py +48 -0
  339. tests/networks/layers/test_vector_quantizer.py +89 -0
  340. tests/networks/layers/test_weight_init.py +50 -0
  341. tests/networks/nets/__init__.py +10 -0
  342. tests/networks/nets/dints/__init__.py +10 -0
  343. tests/networks/nets/dints/test_dints_cell.py +110 -0
  344. tests/networks/nets/dints/test_dints_mixop.py +84 -0
  345. tests/networks/nets/regunet/__init__.py +10 -0
  346. tests/networks/nets/regunet/test_localnet.py +86 -0
  347. tests/networks/nets/regunet/test_regunet.py +88 -0
  348. tests/networks/nets/test_ahnet.py +224 -0
  349. tests/networks/nets/test_attentionunet.py +88 -0
  350. tests/networks/nets/test_autoencoder.py +95 -0
  351. tests/networks/nets/test_autoencoderkl.py +337 -0
  352. tests/networks/nets/test_basic_unet.py +102 -0
  353. tests/networks/nets/test_basic_unetplusplus.py +109 -0
  354. tests/networks/nets/test_bundle_init_bundle.py +55 -0
  355. tests/networks/nets/test_cell_sam_wrapper.py +58 -0
  356. tests/networks/nets/test_controlnet.py +215 -0
  357. tests/networks/nets/test_daf3d.py +62 -0
  358. tests/networks/nets/test_densenet.py +121 -0
  359. tests/networks/nets/test_diffusion_model_unet.py +585 -0
  360. tests/networks/nets/test_dints_network.py +168 -0
  361. tests/networks/nets/test_discriminator.py +59 -0
  362. tests/networks/nets/test_dynunet.py +181 -0
  363. tests/networks/nets/test_efficientnet.py +400 -0
  364. tests/networks/nets/test_flexible_unet.py +341 -0
  365. tests/networks/nets/test_fullyconnectednet.py +69 -0
  366. tests/networks/nets/test_generator.py +59 -0
  367. tests/networks/nets/test_globalnet.py +103 -0
  368. tests/networks/nets/test_highresnet.py +67 -0
  369. tests/networks/nets/test_hovernet.py +218 -0
  370. tests/networks/nets/test_mednext.py +122 -0
  371. tests/networks/nets/test_milmodel.py +92 -0
  372. tests/networks/nets/test_net_adapter.py +68 -0
  373. tests/networks/nets/test_network_consistency.py +86 -0
  374. tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
  375. tests/networks/nets/test_quicknat.py +57 -0
  376. tests/networks/nets/test_resnet.py +340 -0
  377. tests/networks/nets/test_segresnet.py +120 -0
  378. tests/networks/nets/test_segresnet_ds.py +156 -0
  379. tests/networks/nets/test_senet.py +151 -0
  380. tests/networks/nets/test_spade_autoencoderkl.py +295 -0
  381. tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
  382. tests/networks/nets/test_spade_vaegan.py +140 -0
  383. tests/networks/nets/test_swin_unetr.py +139 -0
  384. tests/networks/nets/test_torchvision_fc_model.py +201 -0
  385. tests/networks/nets/test_transchex.py +84 -0
  386. tests/networks/nets/test_transformer.py +108 -0
  387. tests/networks/nets/test_unet.py +208 -0
  388. tests/networks/nets/test_unetr.py +137 -0
  389. tests/networks/nets/test_varautoencoder.py +127 -0
  390. tests/networks/nets/test_vista3d.py +84 -0
  391. tests/networks/nets/test_vit.py +139 -0
  392. tests/networks/nets/test_vitautoenc.py +112 -0
  393. tests/networks/nets/test_vnet.py +81 -0
  394. tests/networks/nets/test_voxelmorph.py +280 -0
  395. tests/networks/nets/test_vqvae.py +274 -0
  396. tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
  397. tests/networks/schedulers/__init__.py +10 -0
  398. tests/networks/schedulers/test_scheduler_ddim.py +83 -0
  399. tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
  400. tests/networks/schedulers/test_scheduler_pndm.py +108 -0
  401. tests/networks/test_bundle_onnx_export.py +71 -0
  402. tests/networks/test_convert_to_onnx.py +106 -0
  403. tests/networks/test_convert_to_torchscript.py +46 -0
  404. tests/networks/test_convert_to_trt.py +79 -0
  405. tests/networks/test_save_state.py +73 -0
  406. tests/networks/test_to_onehot.py +63 -0
  407. tests/networks/test_varnet.py +63 -0
  408. tests/networks/utils/__init__.py +10 -0
  409. tests/networks/utils/test_copy_model_state.py +187 -0
  410. tests/networks/utils/test_eval_mode.py +34 -0
  411. tests/networks/utils/test_freeze_layers.py +61 -0
  412. tests/networks/utils/test_replace_module.py +98 -0
  413. tests/networks/utils/test_train_mode.py +34 -0
  414. tests/optimizers/__init__.py +10 -0
  415. tests/optimizers/test_generate_param_groups.py +105 -0
  416. tests/optimizers/test_lr_finder.py +108 -0
  417. tests/optimizers/test_lr_scheduler.py +71 -0
  418. tests/optimizers/test_optim_novograd.py +100 -0
  419. tests/profile_subclass/__init__.py +10 -0
  420. tests/profile_subclass/cprofile_profiling.py +29 -0
  421. tests/profile_subclass/min_classes.py +30 -0
  422. tests/profile_subclass/profiling.py +73 -0
  423. tests/profile_subclass/pyspy_profiling.py +41 -0
  424. tests/transforms/__init__.py +10 -0
  425. tests/transforms/compose/__init__.py +10 -0
  426. tests/transforms/compose/test_compose.py +758 -0
  427. tests/transforms/compose/test_some_of.py +258 -0
  428. tests/transforms/croppad/__init__.py +10 -0
  429. tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
  430. tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
  431. tests/transforms/functional/__init__.py +10 -0
  432. tests/transforms/functional/test_apply.py +75 -0
  433. tests/transforms/functional/test_resample.py +50 -0
  434. tests/transforms/intensity/__init__.py +10 -0
  435. tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
  436. tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
  437. tests/transforms/intensity/test_foreground_mask.py +98 -0
  438. tests/transforms/intensity/test_foreground_maskd.py +106 -0
  439. tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
  440. tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
  441. tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
  442. tests/transforms/inverse/__init__.py +10 -0
  443. tests/transforms/inverse/test_inverse_array.py +76 -0
  444. tests/transforms/inverse/test_traceable_transform.py +59 -0
  445. tests/transforms/post/__init__.py +10 -0
  446. tests/transforms/post/test_label_filterd.py +78 -0
  447. tests/transforms/post/test_probnms.py +72 -0
  448. tests/transforms/post/test_probnmsd.py +79 -0
  449. tests/transforms/post/test_remove_small_objects.py +102 -0
  450. tests/transforms/spatial/__init__.py +10 -0
  451. tests/transforms/spatial/test_convert_box_points.py +119 -0
  452. tests/transforms/spatial/test_grid_patch.py +134 -0
  453. tests/transforms/spatial/test_grid_patchd.py +102 -0
  454. tests/transforms/spatial/test_rand_grid_patch.py +150 -0
  455. tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
  456. tests/transforms/spatial/test_spatial_resampled.py +124 -0
  457. tests/transforms/test_activations.py +120 -0
  458. tests/transforms/test_activationsd.py +64 -0
  459. tests/transforms/test_adaptors.py +160 -0
  460. tests/transforms/test_add_coordinate_channels.py +53 -0
  461. tests/transforms/test_add_coordinate_channelsd.py +67 -0
  462. tests/transforms/test_add_extreme_points_channel.py +80 -0
  463. tests/transforms/test_add_extreme_points_channeld.py +77 -0
  464. tests/transforms/test_adjust_contrast.py +70 -0
  465. tests/transforms/test_adjust_contrastd.py +64 -0
  466. tests/transforms/test_affine.py +245 -0
  467. tests/transforms/test_affine_grid.py +152 -0
  468. tests/transforms/test_affined.py +190 -0
  469. tests/transforms/test_as_channel_last.py +38 -0
  470. tests/transforms/test_as_channel_lastd.py +44 -0
  471. tests/transforms/test_as_discrete.py +81 -0
  472. tests/transforms/test_as_discreted.py +82 -0
  473. tests/transforms/test_border_pad.py +49 -0
  474. tests/transforms/test_border_padd.py +45 -0
  475. tests/transforms/test_bounding_rect.py +54 -0
  476. tests/transforms/test_bounding_rectd.py +53 -0
  477. tests/transforms/test_cast_to_type.py +63 -0
  478. tests/transforms/test_cast_to_typed.py +74 -0
  479. tests/transforms/test_center_scale_crop.py +55 -0
  480. tests/transforms/test_center_scale_cropd.py +56 -0
  481. tests/transforms/test_center_spatial_crop.py +56 -0
  482. tests/transforms/test_center_spatial_cropd.py +63 -0
  483. tests/transforms/test_classes_to_indices.py +93 -0
  484. tests/transforms/test_classes_to_indicesd.py +110 -0
  485. tests/transforms/test_clip_intensity_percentiles.py +196 -0
  486. tests/transforms/test_clip_intensity_percentilesd.py +193 -0
  487. tests/transforms/test_compose_get_number_conversions.py +127 -0
  488. tests/transforms/test_concat_itemsd.py +82 -0
  489. tests/transforms/test_convert_to_multi_channel.py +59 -0
  490. tests/transforms/test_convert_to_multi_channeld.py +37 -0
  491. tests/transforms/test_copy_itemsd.py +86 -0
  492. tests/transforms/test_create_grid_and_affine.py +274 -0
  493. tests/transforms/test_crop_foreground.py +164 -0
  494. tests/transforms/test_crop_foregroundd.py +205 -0
  495. tests/transforms/test_cucim_dict_transform.py +142 -0
  496. tests/transforms/test_cucim_transform.py +141 -0
  497. tests/transforms/test_data_stats.py +221 -0
  498. tests/transforms/test_data_statsd.py +249 -0
  499. tests/transforms/test_delete_itemsd.py +58 -0
  500. tests/transforms/test_detect_envelope.py +159 -0
  501. tests/transforms/test_distance_transform_edt.py +202 -0
  502. tests/transforms/test_divisible_pad.py +49 -0
  503. tests/transforms/test_divisible_padd.py +42 -0
  504. tests/transforms/test_ensure_channel_first.py +113 -0
  505. tests/transforms/test_ensure_channel_firstd.py +85 -0
  506. tests/transforms/test_ensure_type.py +94 -0
  507. tests/transforms/test_ensure_typed.py +110 -0
  508. tests/transforms/test_fg_bg_to_indices.py +83 -0
  509. tests/transforms/test_fg_bg_to_indicesd.py +78 -0
  510. tests/transforms/test_fill_holes.py +207 -0
  511. tests/transforms/test_fill_holesd.py +209 -0
  512. tests/transforms/test_flatten_sub_keysd.py +64 -0
  513. tests/transforms/test_flip.py +83 -0
  514. tests/transforms/test_flipd.py +90 -0
  515. tests/transforms/test_fourier.py +70 -0
  516. tests/transforms/test_gaussian_sharpen.py +92 -0
  517. tests/transforms/test_gaussian_sharpend.py +92 -0
  518. tests/transforms/test_gaussian_smooth.py +96 -0
  519. tests/transforms/test_gaussian_smoothd.py +96 -0
  520. tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
  521. tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
  522. tests/transforms/test_generate_spatial_bounding_box.py +114 -0
  523. tests/transforms/test_get_extreme_points.py +57 -0
  524. tests/transforms/test_gibbs_noise.py +73 -0
  525. tests/transforms/test_gibbs_noised.py +88 -0
  526. tests/transforms/test_grid_distortion.py +113 -0
  527. tests/transforms/test_grid_distortiond.py +87 -0
  528. tests/transforms/test_grid_split.py +88 -0
  529. tests/transforms/test_grid_splitd.py +96 -0
  530. tests/transforms/test_histogram_normalize.py +59 -0
  531. tests/transforms/test_histogram_normalized.py +59 -0
  532. tests/transforms/test_image_filter.py +259 -0
  533. tests/transforms/test_intensity_stats.py +73 -0
  534. tests/transforms/test_intensity_statsd.py +90 -0
  535. tests/transforms/test_inverse.py +521 -0
  536. tests/transforms/test_inverse_collation.py +147 -0
  537. tests/transforms/test_invert.py +105 -0
  538. tests/transforms/test_invertd.py +142 -0
  539. tests/transforms/test_k_space_spike_noise.py +81 -0
  540. tests/transforms/test_k_space_spike_noised.py +98 -0
  541. tests/transforms/test_keep_largest_connected_component.py +419 -0
  542. tests/transforms/test_keep_largest_connected_componentd.py +348 -0
  543. tests/transforms/test_label_filter.py +78 -0
  544. tests/transforms/test_label_to_contour.py +179 -0
  545. tests/transforms/test_label_to_contourd.py +182 -0
  546. tests/transforms/test_label_to_mask.py +69 -0
  547. tests/transforms/test_label_to_maskd.py +70 -0
  548. tests/transforms/test_load_image.py +502 -0
  549. tests/transforms/test_load_imaged.py +198 -0
  550. tests/transforms/test_load_spacing_orientation.py +149 -0
  551. tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
  552. tests/transforms/test_map_binary_to_indices.py +75 -0
  553. tests/transforms/test_map_classes_to_indices.py +135 -0
  554. tests/transforms/test_map_label_value.py +89 -0
  555. tests/transforms/test_map_label_valued.py +85 -0
  556. tests/transforms/test_map_transform.py +45 -0
  557. tests/transforms/test_mask_intensity.py +74 -0
  558. tests/transforms/test_mask_intensityd.py +68 -0
  559. tests/transforms/test_mean_ensemble.py +77 -0
  560. tests/transforms/test_mean_ensembled.py +91 -0
  561. tests/transforms/test_median_smooth.py +41 -0
  562. tests/transforms/test_median_smoothd.py +65 -0
  563. tests/transforms/test_morphological_ops.py +101 -0
  564. tests/transforms/test_nifti_endianness.py +107 -0
  565. tests/transforms/test_normalize_intensity.py +143 -0
  566. tests/transforms/test_normalize_intensityd.py +81 -0
  567. tests/transforms/test_nvtx_decorator.py +289 -0
  568. tests/transforms/test_nvtx_transform.py +143 -0
  569. tests/transforms/test_orientation.py +247 -0
  570. tests/transforms/test_orientationd.py +112 -0
  571. tests/transforms/test_rand_adjust_contrast.py +45 -0
  572. tests/transforms/test_rand_adjust_contrastd.py +44 -0
  573. tests/transforms/test_rand_affine.py +201 -0
  574. tests/transforms/test_rand_affine_grid.py +212 -0
  575. tests/transforms/test_rand_affined.py +281 -0
  576. tests/transforms/test_rand_axis_flip.py +50 -0
  577. tests/transforms/test_rand_axis_flipd.py +50 -0
  578. tests/transforms/test_rand_bias_field.py +69 -0
  579. tests/transforms/test_rand_bias_fieldd.py +65 -0
  580. tests/transforms/test_rand_coarse_dropout.py +110 -0
  581. tests/transforms/test_rand_coarse_dropoutd.py +107 -0
  582. tests/transforms/test_rand_coarse_shuffle.py +65 -0
  583. tests/transforms/test_rand_coarse_shuffled.py +59 -0
  584. tests/transforms/test_rand_crop_by_label_classes.py +170 -0
  585. tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
  586. tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
  587. tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
  588. tests/transforms/test_rand_cucim_dict_transform.py +162 -0
  589. tests/transforms/test_rand_cucim_transform.py +162 -0
  590. tests/transforms/test_rand_deform_grid.py +138 -0
  591. tests/transforms/test_rand_elastic_2d.py +127 -0
  592. tests/transforms/test_rand_elastic_3d.py +104 -0
  593. tests/transforms/test_rand_elasticd_2d.py +177 -0
  594. tests/transforms/test_rand_elasticd_3d.py +156 -0
  595. tests/transforms/test_rand_flip.py +60 -0
  596. tests/transforms/test_rand_flipd.py +55 -0
  597. tests/transforms/test_rand_gaussian_noise.py +48 -0
  598. tests/transforms/test_rand_gaussian_noised.py +54 -0
  599. tests/transforms/test_rand_gaussian_sharpen.py +140 -0
  600. tests/transforms/test_rand_gaussian_sharpend.py +143 -0
  601. tests/transforms/test_rand_gaussian_smooth.py +98 -0
  602. tests/transforms/test_rand_gaussian_smoothd.py +98 -0
  603. tests/transforms/test_rand_gibbs_noise.py +103 -0
  604. tests/transforms/test_rand_gibbs_noised.py +117 -0
  605. tests/transforms/test_rand_grid_distortion.py +99 -0
  606. tests/transforms/test_rand_grid_distortiond.py +90 -0
  607. tests/transforms/test_rand_histogram_shift.py +92 -0
  608. tests/transforms/test_rand_k_space_spike_noise.py +92 -0
  609. tests/transforms/test_rand_k_space_spike_noised.py +76 -0
  610. tests/transforms/test_rand_rician_noise.py +52 -0
  611. tests/transforms/test_rand_rician_noised.py +52 -0
  612. tests/transforms/test_rand_rotate.py +166 -0
  613. tests/transforms/test_rand_rotate90.py +100 -0
  614. tests/transforms/test_rand_rotate90d.py +112 -0
  615. tests/transforms/test_rand_rotated.py +187 -0
  616. tests/transforms/test_rand_scale_crop.py +78 -0
  617. tests/transforms/test_rand_scale_cropd.py +98 -0
  618. tests/transforms/test_rand_scale_intensity.py +54 -0
  619. tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
  620. tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
  621. tests/transforms/test_rand_scale_intensityd.py +53 -0
  622. tests/transforms/test_rand_shift_intensity.py +52 -0
  623. tests/transforms/test_rand_shift_intensityd.py +67 -0
  624. tests/transforms/test_rand_simulate_low_resolution.py +83 -0
  625. tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
  626. tests/transforms/test_rand_spatial_crop.py +107 -0
  627. tests/transforms/test_rand_spatial_crop_samples.py +128 -0
  628. tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
  629. tests/transforms/test_rand_spatial_cropd.py +112 -0
  630. tests/transforms/test_rand_std_shift_intensity.py +43 -0
  631. tests/transforms/test_rand_std_shift_intensityd.py +38 -0
  632. tests/transforms/test_rand_zoom.py +105 -0
  633. tests/transforms/test_rand_zoomd.py +108 -0
  634. tests/transforms/test_randidentity.py +49 -0
  635. tests/transforms/test_random_order.py +144 -0
  636. tests/transforms/test_randtorchvisiond.py +65 -0
  637. tests/transforms/test_regularization.py +139 -0
  638. tests/transforms/test_remove_repeated_channel.py +34 -0
  639. tests/transforms/test_remove_repeated_channeld.py +44 -0
  640. tests/transforms/test_repeat_channel.py +34 -0
  641. tests/transforms/test_repeat_channeld.py +41 -0
  642. tests/transforms/test_resample_backends.py +65 -0
  643. tests/transforms/test_resample_to_match.py +110 -0
  644. tests/transforms/test_resample_to_matchd.py +93 -0
  645. tests/transforms/test_resampler.py +165 -0
  646. tests/transforms/test_resize.py +140 -0
  647. tests/transforms/test_resize_with_pad_or_crop.py +91 -0
  648. tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
  649. tests/transforms/test_resized.py +163 -0
  650. tests/transforms/test_rotate.py +160 -0
  651. tests/transforms/test_rotate90.py +212 -0
  652. tests/transforms/test_rotate90d.py +106 -0
  653. tests/transforms/test_rotated.py +179 -0
  654. tests/transforms/test_save_classificationd.py +109 -0
  655. tests/transforms/test_save_image.py +80 -0
  656. tests/transforms/test_save_imaged.py +130 -0
  657. tests/transforms/test_savitzky_golay_smooth.py +73 -0
  658. tests/transforms/test_savitzky_golay_smoothd.py +73 -0
  659. tests/transforms/test_scale_intensity.py +76 -0
  660. tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
  661. tests/transforms/test_scale_intensity_range.py +41 -0
  662. tests/transforms/test_scale_intensity_ranged.py +40 -0
  663. tests/transforms/test_scale_intensityd.py +57 -0
  664. tests/transforms/test_select_itemsd.py +41 -0
  665. tests/transforms/test_shift_intensity.py +31 -0
  666. tests/transforms/test_shift_intensityd.py +44 -0
  667. tests/transforms/test_signal_continuouswavelet.py +44 -0
  668. tests/transforms/test_signal_fillempty.py +52 -0
  669. tests/transforms/test_signal_fillemptyd.py +60 -0
  670. tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
  671. tests/transforms/test_signal_rand_add_sine.py +52 -0
  672. tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
  673. tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
  674. tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
  675. tests/transforms/test_signal_rand_drop.py +50 -0
  676. tests/transforms/test_signal_rand_scale.py +52 -0
  677. tests/transforms/test_signal_rand_shift.py +55 -0
  678. tests/transforms/test_signal_remove_frequency.py +71 -0
  679. tests/transforms/test_smooth_field.py +177 -0
  680. tests/transforms/test_sobel_gradient.py +189 -0
  681. tests/transforms/test_sobel_gradientd.py +212 -0
  682. tests/transforms/test_spacing.py +381 -0
  683. tests/transforms/test_spacingd.py +178 -0
  684. tests/transforms/test_spatial_crop.py +82 -0
  685. tests/transforms/test_spatial_cropd.py +74 -0
  686. tests/transforms/test_spatial_pad.py +57 -0
  687. tests/transforms/test_spatial_padd.py +43 -0
  688. tests/transforms/test_spatial_resample.py +235 -0
  689. tests/transforms/test_squeezedim.py +62 -0
  690. tests/transforms/test_squeezedimd.py +98 -0
  691. tests/transforms/test_std_shift_intensity.py +76 -0
  692. tests/transforms/test_std_shift_intensityd.py +74 -0
  693. tests/transforms/test_threshold_intensity.py +38 -0
  694. tests/transforms/test_threshold_intensityd.py +58 -0
  695. tests/transforms/test_to_contiguous.py +47 -0
  696. tests/transforms/test_to_cupy.py +112 -0
  697. tests/transforms/test_to_cupyd.py +76 -0
  698. tests/transforms/test_to_device.py +42 -0
  699. tests/transforms/test_to_deviced.py +37 -0
  700. tests/transforms/test_to_numpy.py +85 -0
  701. tests/transforms/test_to_numpyd.py +68 -0
  702. tests/transforms/test_to_pil.py +52 -0
  703. tests/transforms/test_to_pild.py +50 -0
  704. tests/transforms/test_to_tensor.py +60 -0
  705. tests/transforms/test_to_tensord.py +71 -0
  706. tests/transforms/test_torchvision.py +66 -0
  707. tests/transforms/test_torchvisiond.py +63 -0
  708. tests/transforms/test_transform.py +62 -0
  709. tests/transforms/test_transpose.py +41 -0
  710. tests/transforms/test_transposed.py +52 -0
  711. tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
  712. tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
  713. tests/transforms/test_vote_ensemble.py +84 -0
  714. tests/transforms/test_vote_ensembled.py +107 -0
  715. tests/transforms/test_with_allow_missing_keys.py +76 -0
  716. tests/transforms/test_zoom.py +120 -0
  717. tests/transforms/test_zoomd.py +94 -0
  718. tests/transforms/transform/__init__.py +10 -0
  719. tests/transforms/transform/test_randomizable.py +52 -0
  720. tests/transforms/transform/test_randomizable_transform_type.py +37 -0
  721. tests/transforms/utility/__init__.py +10 -0
  722. tests/transforms/utility/test_apply_transform_to_points.py +81 -0
  723. tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
  724. tests/transforms/utility/test_identity.py +29 -0
  725. tests/transforms/utility/test_identityd.py +30 -0
  726. tests/transforms/utility/test_lambda.py +71 -0
  727. tests/transforms/utility/test_lambdad.py +83 -0
  728. tests/transforms/utility/test_rand_lambda.py +87 -0
  729. tests/transforms/utility/test_rand_lambdad.py +77 -0
  730. tests/transforms/utility/test_simulatedelay.py +36 -0
  731. tests/transforms/utility/test_simulatedelayd.py +36 -0
  732. tests/transforms/utility/test_splitdim.py +52 -0
  733. tests/transforms/utility/test_splitdimd.py +96 -0
  734. tests/transforms/utils/__init__.py +10 -0
  735. tests/transforms/utils/test_correct_crop_centers.py +36 -0
  736. tests/transforms/utils/test_get_unique_labels.py +45 -0
  737. tests/transforms/utils/test_print_transform_backends.py +29 -0
  738. tests/transforms/utils/test_soft_clip.py +125 -0
  739. tests/utils/__init__.py +10 -0
  740. tests/utils/enums/__init__.py +10 -0
  741. tests/utils/enums/test_hovernet_loss.py +190 -0
  742. tests/utils/enums/test_ordering.py +289 -0
  743. tests/utils/enums/test_wsireader.py +663 -0
  744. tests/utils/misc/__init__.py +10 -0
  745. tests/utils/misc/test_ensure_tuple.py +53 -0
  746. tests/utils/misc/test_monai_env_vars.py +44 -0
  747. tests/utils/misc/test_monai_utils_misc.py +103 -0
  748. tests/utils/misc/test_str2bool.py +34 -0
  749. tests/utils/misc/test_str2list.py +33 -0
  750. tests/utils/test_alias.py +44 -0
  751. tests/utils/test_component_store.py +73 -0
  752. tests/utils/test_deprecated.py +455 -0
  753. tests/utils/test_enum_bound_interp.py +75 -0
  754. tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
  755. tests/utils/test_get_package_version.py +34 -0
  756. tests/utils/test_handler_logfile.py +84 -0
  757. tests/utils/test_handler_metric_logger.py +62 -0
  758. tests/utils/test_list_to_dict.py +43 -0
  759. tests/utils/test_look_up_option.py +87 -0
  760. tests/utils/test_optional_import.py +80 -0
  761. tests/utils/test_pad_mode.py +39 -0
  762. tests/utils/test_profiling.py +208 -0
  763. tests/utils/test_rankfilter_dist.py +77 -0
  764. tests/utils/test_require_pkg.py +83 -0
  765. tests/utils/test_sample_slices.py +43 -0
  766. tests/utils/test_set_determinism.py +74 -0
  767. tests/utils/test_squeeze_unsqueeze.py +71 -0
  768. tests/utils/test_state_cacher.py +67 -0
  769. tests/utils/test_torchscript_utils.py +113 -0
  770. tests/utils/test_version.py +91 -0
  771. tests/utils/test_version_after.py +65 -0
  772. tests/utils/type_conversion/__init__.py +10 -0
  773. tests/utils/type_conversion/test_convert_data_type.py +152 -0
  774. tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
  775. tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
  776. tests/visualize/__init__.py +10 -0
  777. tests/visualize/test_img2tensorboard.py +46 -0
  778. tests/visualize/test_occlusion_sensitivity.py +128 -0
  779. tests/visualize/test_plot_2d_or_3d_image.py +74 -0
  780. tests/visualize/test_vis_cam.py +98 -0
  781. tests/visualize/test_vis_gradcam.py +211 -0
  782. tests/visualize/utils/__init__.py +10 -0
  783. tests/visualize/utils/test_blend_images.py +63 -0
  784. tests/visualize/utils/test_matshow3d.py +133 -0
  785. monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
  786. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/LICENSE +0 -0
  787. {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/WHEEL +0 -0
@@ -0,0 +1,154 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import os
15
+ import tempfile
16
+ import unittest
17
+ from pathlib import Path
18
+
19
+ import numpy as np
20
+ import torch
21
+ from parameterized import parameterized
22
+
23
+ from monai import __version__
24
+ from monai.apps import download_mmar, load_from_mmar
25
+ from monai.apps.mmars import MODEL_DESC
26
+ from monai.apps.mmars.mmars import _get_val
27
+ from monai.utils import version_leq
28
+ from tests.test_utils import skip_if_downloading_fails, skip_if_quick
29
+
30
# Device string used as ``map_location`` for every checkpoint below.
_MAP_LOCATION = "cuda" if torch.cuda.is_available() else "cpu"

# One single-element parameter list per MMAR id exercised by ``test_download``.
TEST_CASES = [[mmar_id] for mmar_id in ("clara_pt_prostate_mri_segmentation", "clara_pt_covid19_ct_lesion_segmentation")]

# Each case: (kwargs for ``load_from_mmar``, expected network class name, expected value of the first
# parameter tensor indexed at ``[0][0]``).
TEST_EXTRACT_CASES = [
    (
        {"item": "clara_pt_prostate_mri_segmentation", "map_location": _MAP_LOCATION},
        "UNet",
        np.array(
            [
                [[-0.0838, 0.0116, -0.0861], [-0.0792, 0.2216, -0.0301], [-0.0379, 0.0006, -0.0399]],
                [[-0.0347, 0.0979, 0.0754], [0.1689, 0.3759, 0.2584], [-0.0698, 0.2740, 0.1414]],
                [[-0.0772, 0.1046, -0.0103], [0.0917, 0.1942, 0.0284], [-0.0165, -0.0181, 0.0247]],
            ]
        ),
    ),
    (
        {"item": "clara_pt_covid19_ct_lesion_segmentation", "map_location": _MAP_LOCATION},
        "SegResNet",
        np.array(
            [
                [
                    [0.01671106, 0.08502351, -0.1766469],
                    [-0.13039736, -0.06137804, 0.03924942],
                    [0.02268324, 0.159056, -0.03485069],
                ],
                [
                    [0.04788467, -0.09365353, -0.05802464],
                    [-0.19500689, -0.13514304, -0.08191573],
                    [0.0238207, 0.08029253, 0.10818923],
                ],
                [
                    [-0.11541673, -0.10622888, 0.039689],
                    [0.18462701, -0.0499289, 0.14309818],
                    [0.00528282, 0.02152331, 0.1698219],
                ],
            ]
        ),
    ),
    (
        {
            "item": "clara_pt_fed_learning_brain_tumor_mri_segmentation",
            "map_location": _MAP_LOCATION,
            "model_file": os.path.join("models", "server", "best_FL_global_model.pt"),
        },
        "SegResNet",
        np.array(
            [
                [
                    [0.01874463, 0.12237817, 0.09269974],
                    [0.07691482, 0.00621202, -0.06682577],
                    [-0.07718472, 0.08637864, -0.03222707],
                ],
                [
                    [0.05117761, 0.07428649, -0.03053505],
                    [0.11045473, 0.07083791, 0.06547518],
                    [0.09555705, -0.03950734, -0.00819483],
                ],
                [
                    [0.03704128, 0.062543, 0.0380853],
                    [-0.02814676, -0.03078287, -0.01383446],
                    [-0.08137762, 0.01385882, 0.01229484],
                ],
            ]
        ),
    ),
    (
        {"item": "clara_pt_pathology_metastasis_detection", "map_location": _MAP_LOCATION},
        "TorchVisionFCModel",
        np.array(
            [
                [-0.00540746, -0.00274996, -0.00837622, 0.05415914, 0.03555066, -0.00071636, -0.02325751],
                [0.00564625, 0.00674562, -0.1098334, -0.2936509, -0.28384757, -0.13580588, -0.00737865],
                [-0.02159783, 0.04615543, 0.29717407, 0.6001161, 0.53496915, 0.2528417, 0.04530451],
                [0.0225903, -0.07556137, -0.3070122, -0.43984795, -0.26286602, -0.00172576, 0.05003437],
                [-0.0320133, 0.00855468, 0.06824744, -0.04786247, -0.30358723, -0.3960023, -0.24895012],
                [0.02412516, 0.03411723, 0.06513759, 0.24332047, 0.41664436, 0.38999054, 0.15957521],
                [-0.01303542, -0.00166874, -0.01965466, -0.06620175, -0.15635538, -0.10023144, -0.01698002],
            ]
        ),
    ),
]
115
+
116
+
117
@unittest.skip("deprecating mmar tests")
class TestMMMARDownload(unittest.TestCase):
    """Download, caching and checkpoint-loading checks for the MMAR helpers (currently skipped)."""

    @parameterized.expand(TEST_CASES)
    @skip_if_quick
    def test_download(self, idx):
        with skip_if_downloading_fails():
            # The first download is expected to log at INFO level on the ``monai.apps`` logger.
            with self.assertLogs(level="INFO", logger="monai.apps"):
                download_mmar(idx)
            download_mmar(idx, progress=False)  # repeated to check caching
            with tempfile.TemporaryDirectory() as tmp_dir:
                download_mmar(idx, mmar_dir=tmp_dir, progress=False)
                # Same target given as a Path with an explicit version: repeated to check caching.
                download_mmar(idx, mmar_dir=Path(tmp_dir), progress=False, version=1)
                # Checked before the temporary directory is removed.
                self.assertTrue(os.path.exists(os.path.join(tmp_dir, idx)))

    @parameterized.expand(TEST_EXTRACT_CASES)
    @skip_if_quick
    @unittest.skipIf(version_leq(__version__, "0.6"), "requires newer monai")
    def test_load_ckpt(self, input_args, expected_name, expected_val):
        with skip_if_downloading_fails():
            model = load_from_mmar(**input_args)
        self.assertEqual(type(model).__name__, expected_name)
        # Verify the leading slice of the first parameter tensor only.
        first_param = next(iter(model.parameters()))
        actual = first_param[0][0].detach().cpu().numpy()
        np.testing.assert_allclose(actual, expected_val, rtol=1e-3, atol=1e-3)

    def test_unique(self):
        # model ids are unique
        ids = [desc["id"] for desc in MODEL_DESC]
        self.assertEqual(sorted(ids), sorted(set(ids)))

    def test_search(self):
        # (nested dict, key to look up, value ``_get_val`` should return)
        cases = (
            ({"a": 1, "b": 2}, "b", 2),
            ({"a": {"c": {"c": 4}}, "b": {"c": 2}}, "b", {"c": 2}),
            ({"a": {"c": 4}, "b": {"c": 2}}, "c", 4),
            ({"a": {"c": None}, "b": {"c": 2}}, "c", 2),
        )
        for tree, key, expected in cases:
            self.assertEqual(_get_val(tree, key=key), expected)
+
152
+
153
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
@@ -0,0 +1,123 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import os
15
+ import shutil
16
+ import unittest
17
+ from pathlib import Path
18
+
19
+ from monai.apps import TciaDataset
20
+ from monai.apps.tcia import DCM_FILENAME_REGEX, TCIA_LABEL_DICT
21
+ from monai.data import MetaTensor
22
+ from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ScaleIntensityd
23
+ from tests.test_utils import skip_if_downloading_fails, skip_if_quick
24
+
25
+
26
class TestTciaDataset(unittest.TestCase):
    """End-to-end checks for ``TciaDataset``: download, reload from disk, caching and error handling."""

    @skip_if_quick
    def test_values(self):
        testing_dir = Path(__file__).parents[1] / "testing_data"
        download_len = 1
        val_frac = 1.0
        collection = "QIN-PROSTATE-Repeatability"
        # Number of items expected in the "validation" section of every dataset built below.
        expected_len = int(download_len * val_frac)

        transform = Compose(
            [
                LoadImaged(
                    keys=["image", "seg"],
                    reader="PydicomReader",
                    fname_regex=DCM_FILENAME_REGEX,
                    label_dict=TCIA_LABEL_DICT[collection],
                ),
                EnsureChannelFirstd(keys="image", channel_dim="no_channel"),
                ScaleIntensityd(keys="image"),
            ]
        )

        def _test_dataset(dataset):
            """Check the length of ``dataset`` and its first transformed item; return that item."""
            self.assertEqual(len(dataset), expected_len)
            # Every ``dataset[0]`` re-runs the load + transform chain, so fetch the item once and reuse it.
            item = dataset[0]
            self.assertIn("image", item)
            self.assertIn("seg", item)
            self.assertIsInstance(item["image"], MetaTensor)
            self.assertTupleEqual(item["image"].shape, (1, 256, 256, 24))
            self.assertTupleEqual(item["seg"].shape, (256, 256, 24, 4))
            return item

        with skip_if_downloading_fails():
            data = TciaDataset(
                root_dir=testing_dir,
                collection=collection,
                transform=transform,
                section="validation",
                download=True,
                download_len=download_len,
                copy_cache=False,
                val_frac=val_frac,
            )

        _test_dataset(data)

        # Reload the already-downloaded data from disk, with the runtime cache enabled.
        data = TciaDataset(
            root_dir=testing_dir,
            collection=collection,
            transform=transform,
            section="validation",
            download=False,
            val_frac=val_frac,
            runtime_cache=True,
        )
        item = _test_dataset(data)
        self.assertTrue(
            item["image"].meta["filename_or_obj"].endswith("QIN-PROSTATE-Repeatability/PCAMPMRI-00015/1901/image")
        )
        self.assertTrue(
            item["seg"].meta["filename_or_obj"].endswith("QIN-PROSTATE-Repeatability/PCAMPMRI-00015/1901/seg")
        )

        # test validation without transforms
        data = TciaDataset(
            root_dir=testing_dir, collection=collection, section="validation", download=False, val_frac=val_frac
        )
        self.assertTupleEqual(data[0]["image"].shape, (256, 256, 24))
        self.assertEqual(len(data), expected_len)

        data = TciaDataset(
            root_dir=testing_dir,
            collection=collection,
            section="validation",
            download=False,
            fname_regex=DCM_FILENAME_REGEX,
            val_frac=val_frac,
        )
        self.assertTupleEqual(data[0]["image"].shape, (256, 256, 24))
        self.assertEqual(len(data), expected_len)

        # Constructing the dataset and reading its first item with a catch-all regex must emit a UserWarning.
        with self.assertWarns(UserWarning):
            _ = TciaDataset(
                root_dir=testing_dir,
                collection=collection,
                section="validation",
                fname_regex=".*",  # matches all files, including 'LICENSE', which is not a valid input
                download=False,
                val_frac=val_frac,
            )[0]

        # Once the collection directory is removed, building the dataset without download must fail.
        shutil.rmtree(os.path.join(testing_dir, collection))
        with self.assertRaisesRegex(RuntimeError, "^Cannot find dataset directory"):
            TciaDataset(
                root_dir=testing_dir,
                collection=collection,
                transform=transform,
                section="validation",
                download=False,
                val_frac=val_frac,
            )
120
+
121
+
122
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
@@ -0,0 +1,10 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
@@ -0,0 +1,77 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import torch
17
+ from parameterized import parameterized
18
+
19
+ from monai.apps.vista3d.inferer import point_based_window_inferer
20
+ from monai.networks import eval_mode
21
+ from monai.networks.nets.vista3d import vista3d132
22
+ from monai.utils import optional_import
23
+ from tests.test_utils import SkipIfBeforePyTorchVersion, skip_if_quick
24
+
25
# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): ``has_tqdm`` is not referenced anywhere else in this module — confirm whether it can be dropped.
_, has_tqdm = optional_import("tqdm")
28
+
29
def _inferer_case(**extra_inferer_kwargs):
    """Build one ``[vista3d kwargs, input shape, inferer kwargs]`` case; extra kwargs are appended last."""
    inferer_kwargs = {
        "roi_size": [32, 32, 32],
        "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device),
        "point_labels": torch.tensor([[1, 0]], device=device),
    }
    inferer_kwargs.update(extra_inferer_kwargs)
    return [{"encoder_embed_dim": 48, "in_channels": 1}, (1, 1, 64, 64, 64), inferer_kwargs]


TEST_CASES = [
    # points only
    _inferer_case(),
    # points plus a class prompt
    _inferer_case(class_vector=torch.tensor([1], device=device)),
    # points plus a class prompt, with ``point_start`` set
    _inferer_case(class_vector=torch.tensor([1], device=device), point_start=1),
]
61
+
62
+
63
@SkipIfBeforePyTorchVersion((1, 11))
@skip_if_quick
class TestPointBasedWindowInferer(unittest.TestCase):
    """Smoke test: ``point_based_window_inferer`` with a small VISTA3D returns an output shaped like its input."""

    @parameterized.expand(TEST_CASES)
    def test_vista3d(self, vista3d_params, inputs_shape, inferer_params):
        vista3d = vista3d132(**vista3d_params).to(device)
        with eval_mode(vista3d):
            # Build a fresh kwargs dict instead of writing into ``inferer_params``. That dict is the one stored
            # in the module-level TEST_CASES, so mutating it would change the fixture and keep the model and
            # the input tensor alive after the test finishes.
            params = {**inferer_params, "predictor": vista3d, "inputs": torch.randn(*inputs_shape).to(device)}
            stitched_output = point_based_window_inferer(**params)
            self.assertEqual(stitched_output.shape, inputs_shape)
74
+
75
+
76
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
@@ -0,0 +1,100 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import torch
17
+ from parameterized import parameterized
18
+
19
+ from monai.apps.vista3d.sampler import sample_prompt_pairs
20
+
21
def _cube_label(cubes):
    """Return a zero (1, 1, 64, 64, 64) volume with a 10-voxel cube of ``value`` at each ``(start, value)``."""
    volume = torch.zeros([1, 1, 64, 64, 64])
    for start, value in cubes:
        edge = slice(start, start + 10)
        volume[:, :, edge, edge, edge] = value
    return volume


# Three disjoint cubes labelled 1, 2 and 3, plus an all-background volume.
label = _cube_label(((0, 1), (20, 2), (30, 3)))
label1 = _cube_label(())


def _prompt_case(labels, label_set, drop_label_prob, drop_point_prob, max_backprompt=1):
    """Build the kwargs for one ``sample_prompt_pairs`` call."""
    return {
        "labels": labels,
        "label_set": label_set,
        "max_prompt": 5,
        "max_foreprompt": 4,
        "max_backprompt": max_backprompt,
        "drop_label_prob": drop_label_prob,
        "drop_point_prob": drop_point_prob,
    }


# Each case: (kwargs, expected leading dimension of every returned item, or None where the item is None).
TEST_VISTA_SAMPLE_PROMPT = [
    [_prompt_case(label, [0, 1, 2, 3, 4], drop_label_prob=0, drop_point_prob=0), [4, 4, 4, 4]],
    [_prompt_case(label, [0, 1], drop_label_prob=0, drop_point_prob=1), [2, None, None, 2]],
    [_prompt_case(label, [0, 1, 2, 3, 4], drop_label_prob=1, drop_point_prob=0), [None, 3, 3, 3]],
    [_prompt_case(label1, [0, 1], drop_label_prob=0, drop_point_prob=1), [1, None, None, 1]],
    [_prompt_case(label1, [0, 1], drop_label_prob=0, drop_point_prob=1, max_backprompt=0), [None] * 4],
]
89
+
90
+
91
class TestGeneratePrompt(unittest.TestCase):
    """Check how many prompts of each kind ``sample_prompt_pairs`` produces."""

    @parameterized.expand(TEST_VISTA_SAMPLE_PROMPT)
    def test_result(self, input_data, expected):
        def _leading_dim(item):
            # ``None`` entries stand for prompt kinds that were not sampled.
            return None if item is None else item.shape[0]

        self.assertEqual([_leading_dim(item) for item in sample_prompt_pairs(**input_data)], expected)
97
+
98
+
99
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
@@ -0,0 +1,94 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+ from unittest.case import skipUnless
16
+
17
+ import torch
18
+ from parameterized import parameterized
19
+
20
+ from monai.apps.vista3d.transforms import VistaPostTransformd, VistaPreTransformd
21
+ from monai.utils import min_version
22
+ from monai.utils.module import optional_import
23
+
24
# skimage.measure (>= 0.14.2) is optional; ``has_measure`` is used to skip TestVistaPostTransformd when missing.
measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version)
25
+
26
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _pre_case(label_prompt, point_labels):
    """Build one ``VistaPreTransformd`` input/output dict with a single point at the origin."""
    return {"label_prompt": label_prompt, "points": [[0, 0, 0]], "point_labels": point_labels}


# Each case: (transform input, expected transform output).
TEST_VISTA_PRETRANSFORM = [
    [_pre_case([1], [1]), _pre_case([1], [3])],
    [_pre_case([2], [0]), _pre_case([2], [2])],
    [_pre_case([3], [0]), _pre_case([4, 5], [0])],
    [_pre_case([6], [0]), _pre_case([7, 8], [0])],
]

# Two 10-voxel cubes along the main diagonal.
_CUBE_A = slice(0, 10)
_CUBE_B = slice(20, 30)

# Two-channel prediction (one cube per channel) and the label volume it should map to for prompts [2, 3].
pred1 = torch.zeros([2, 64, 64, 64])
pred1[0, _CUBE_A, _CUBE_A, _CUBE_A] = 1
pred1[1, _CUBE_B, _CUBE_B, _CUBE_B] = 1
output1 = torch.zeros([1, 64, 64, 64])
output1[:, _CUBE_A, _CUBE_A, _CUBE_A] = 2
output1[:, _CUBE_B, _CUBE_B, _CUBE_B] = 3

# Background is -1 because pred is taken before the sigmoid. Only cube B contains the positive point
# (25, 25, 25) used below, so presumably only that component should survive.
pred2 = torch.full([1, 64, 64, 64], -1.0)
pred2[:, _CUBE_A, _CUBE_A, _CUBE_A] = 1
pred2[:, _CUBE_B, _CUBE_B, _CUBE_B] = 1
output2 = torch.zeros([1, 64, 64, 64])
output2[:, _CUBE_B, _CUBE_B, _CUBE_B] = 1

label_prompt_case = {"pred": pred1.to(device), "label_prompt": torch.tensor([2, 3]).to(device)}
point_prompt_case = {
    "pred": pred2.to(device),
    "points": torch.tensor([[25, 25, 25]]).to(device),
    "point_labels": torch.tensor([1]).to(device),
}
TEST_VISTA_POSTTRANSFORM = [
    [label_prompt_case, output1.to(device)],
    [point_prompt_case, output2.to(device)],
]
del label_prompt_case, point_prompt_case
74
+
75
+
76
class TestVistaPreTransformd(unittest.TestCase):
    """Check the prompt remapping performed by ``VistaPreTransformd``."""

    @parameterized.expand(TEST_VISTA_PRETRANSFORM)
    def test_result(self, input_data, expected):
        pre_transform = VistaPreTransformd(
            keys="image", subclass={"3": [4, 5], "6": [7, 8]}, special_index=[1, 2]
        )
        self.assertEqual(pre_transform(input_data), expected)
82
+
83
+
84
@skipUnless(has_measure, "skimage.measure required")
class TestVistaPostTransformd(unittest.TestCase):
    """Check that ``VistaPostTransformd`` maps raw predictions to the expected label volume."""

    @parameterized.expand(TEST_VISTA_POSTTRANSFORM)
    def test_result(self, input_data, expected):
        transform = VistaPostTransformd(keys="pred")
        # The case dict and its tensors are module-level objects shared across runs, and the transform may
        # modify what it is given, so hand it clones instead.
        data = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in input_data.items()}
        result = transform(data)
        # Check the shape explicitly: an elementwise ``==`` would broadcast and could hide a mis-shaped output.
        self.assertEqual(result["pred"].shape, expected.shape)
        self.assertTrue(torch.all(result["pred"] == expected))
91
+
92
+
93
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
@@ -0,0 +1,10 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
@@ -0,0 +1,107 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ import os
16
+ import tempfile
17
+ import unittest
18
+ from pathlib import Path
19
+
20
+ from parameterized import parameterized
21
+
22
+ from monai.bundle import ConfigParser
23
+ from monai.data import load_net_with_metadata
24
+ from monai.networks import save_state
25
+ from tests.test_utils import command_line_tests, skip_if_windows
26
+
27
# the directory one level above the one holding this file; it contains ``testing_data``
TESTS_PATH = Path(__file__).parents[1]

# each case is ``[key_in_ckpt, use_trace]``: an empty key saves the network without nesting,
# and ``use_trace == "True"`` adds the tracing options to the export command
TEST_CASE_1, TEST_CASE_2, TEST_CASE_3 = ["", ""], ["model", ""], ["model", "True"]
34
+
35
+
36
@skip_if_windows
class TestCKPTExport(unittest.TestCase):
    """Command-line tests for ``python -m monai.bundle ckpt_export``.

    Every case saves the ``network_def`` parsed from ``testing_data/inference.json``
    as a checkpoint. It then runs the export command through ``command_line_tests``
    and checks that the exported ``model.ts`` file exists.
    """

    def setUp(self):
        # remember the caller's value (None when unset) so tearDown can restore it
        self.device = os.environ.get("CUDA_VISIBLE_DEVICES")
        if not self.device:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # default

    def tearDown(self):
        if self.device is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = self.device
        else:
            # previously unset; pop() does not raise if the variable was already removed
            os.environ.pop("CUDA_VISIBLE_DEVICES", None)

    @staticmethod
    def _save_checkpoint(config_file, tempdir, ckpt_file, key_in_ckpt):
        """Write a placeholder ``def_args.yaml`` and save the config's ``network_def`` as a checkpoint.

        Args:
            config_file: bundle config that defines ``network_def``.
            tempdir: directory that receives ``def_args.yaml``.
            ckpt_file: destination path of the saved checkpoint.
            key_in_ckpt: ``""`` saves the network directly; otherwise it is nested under this key.

        Returns:
            Path of the written ``def_args.yaml``.
        """
        def_args = {"meta_file": "will be replaced by `meta_file` arg"}
        def_args_file = os.path.join(tempdir, "def_args.yaml")
        parser = ConfigParser()
        parser.export_config_file(config=def_args, filepath=def_args_file)
        parser.read_config(config_file)
        net = parser.get_parsed_content("network_def")
        save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file)
        return def_args_file

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
    def test_export(self, key_in_ckpt, use_trace):
        """Export with explicit file arguments and inspect the extra files stored with ``model.ts``."""
        meta_file = os.path.join(TESTS_PATH, "testing_data", "metadata.json")
        config_file = os.path.join(TESTS_PATH, "testing_data", "inference.json")
        with tempfile.TemporaryDirectory() as tempdir:
            ckpt_file = os.path.join(tempdir, "model.pt")
            ts_file = os.path.join(tempdir, "model.ts")
            def_args_file = self._save_checkpoint(config_file, tempdir, ckpt_file, key_in_ckpt)

            cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", ts_file]
            cmd += ["--meta_file", meta_file, "--config_file", f"['{config_file}','{def_args_file}']", "--ckpt_file"]
            cmd += [ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file]
            if use_trace == "True":
                cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
            command_line_tests(cmd)
            self.assertTrue(os.path.exists(ts_file))

            _, metadata, extra_files = load_net_with_metadata(
                ts_file, more_extra_files=["inference.json", "def_args.json"]
            )
            self.assertIn("schema", metadata)
            self.assertIn("meta_file", json.loads(extra_files["def_args.json"]))
            self.assertIn("network_def", json.loads(extra_files["inference.json"]))

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
    def test_default_value(self, key_in_ckpt, use_trace):
        """Export passing only ``--bundle_root`` (no ``--ckpt_file``/``--filepath``) and expect ``models/model.ts``."""
        config_file = os.path.join(TESTS_PATH, "testing_data", "inference.json")
        with tempfile.TemporaryDirectory() as tempdir:
            ckpt_file = os.path.join(tempdir, "models/model.pt")
            ts_file = os.path.join(tempdir, "models/model.ts")
            # the args file written by the helper is not passed to the command below
            self._save_checkpoint(config_file, tempdir, ckpt_file, key_in_ckpt)

            # check with default value
            cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "--key_in_ckpt", key_in_ckpt]
            cmd += ["--config_file", config_file, "--bundle_root", tempdir]
            if use_trace == "True":
                cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
            command_line_tests(cmd)
            self.assertTrue(os.path.exists(ts_file))
104
+
105
+
106
if __name__ == "__main__":
    # allow running this test module directly as a script
    unittest.main()