monai-weekly 1.5.dev2512__py3-none-any.whl → 1.5.dev2513__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/metrics/meandice.py +132 -76
- monai/networks/blocks/__init__.py +2 -1
- monai/networks/blocks/cablock.py +182 -0
- monai/networks/blocks/downsample.py +241 -2
- monai/networks/nets/restormer.py +337 -0
- monai/networks/utils.py +44 -1
- monai/utils/__init__.py +1 -0
- monai/utils/enums.py +13 -0
- {monai_weekly-1.5.dev2512.dist-info → monai_weekly-1.5.dev2513.dist-info}/METADATA +1 -1
- {monai_weekly-1.5.dev2512.dist-info → monai_weekly-1.5.dev2513.dist-info}/RECORD +20 -15
- {monai_weekly-1.5.dev2512.dist-info → monai_weekly-1.5.dev2513.dist-info}/WHEEL +1 -1
- tests/metrics/test_compute_meandice.py +3 -3
- tests/networks/blocks/test_CABlock.py +150 -0
- tests/networks/blocks/test_downsample_block.py +184 -0
- tests/networks/nets/test_restormer.py +147 -0
- tests/networks/utils/test_pixelunshuffle.py +51 -0
- tests/integration/test_downsample_block.py +0 -50
- {monai_weekly-1.5.dev2512.dist-info → monai_weekly-1.5.dev2513.dist-info}/licenses/LICENSE +0 -0
- {monai_weekly-1.5.dev2512.dist-info → monai_weekly-1.5.dev2513.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
|
|
1
|
-
monai/__init__.py,sha256=
|
2
|
-
monai/_version.py,sha256=
|
1
|
+
monai/__init__.py,sha256=lOVClDZxEDXvMXltT5kiFD_6UM2TYyLqp6gJRLLSWY8,4095
|
2
|
+
monai/_version.py,sha256=TaZfkTv96TjOsFZhWyZwFaEDQWRV0aCQUaMZaKfJkOo,503
|
3
3
|
monai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
monai/_extensions/__init__.py,sha256=NEBPreRhQ8H9gVvgrLr_y52_TmqB96u_u4VQmeNT93I,642
|
5
5
|
monai/_extensions/loader.py,sha256=7SiKw36q-nOzH8CRbBurFrz7GM40GCu7rc93Tm8XpnI,3643
|
@@ -232,7 +232,7 @@ monai/metrics/froc.py,sha256=q7MAFsHHIp5EHBHwa5UbF5PRApjUonw-hUXax9k1WxQ,7981
|
|
232
232
|
monai/metrics/generalized_dice.py,sha256=9ZiEmGfMZLxFAF6AmdrbKOc8A_QOUMUmIZ6ILm-h01A,8939
|
233
233
|
monai/metrics/hausdorff_distance.py,sha256=4_ZJZ2gV1bPhOR5Mxz0PyN6Y_X1mTZ6U6T4gSRwjfDE,11844
|
234
234
|
monai/metrics/loss_metric.py,sha256=m9jXobVHKLeDY_8yrA9m7FwfapSAb-kYIdUJOsbvBvY,4907
|
235
|
-
monai/metrics/meandice.py,sha256=
|
235
|
+
monai/metrics/meandice.py,sha256=Q2Fp_YfZrlsx4cxR_h40zpeeGoIkKWQN78qzCShnbro,16237
|
236
236
|
monai/metrics/meaniou.py,sha256=cGoW1re7v4hxXJfjyEVEFNsuzEupgJaIe6ZK_qrbIjw,7004
|
237
237
|
monai/metrics/metric.py,sha256=VtIMNudwFkEhGAX1n0aYMaj18yKtmENKpo0JuWoVFvQ,15203
|
238
238
|
monai/metrics/mmd.py,sha256=a_O0WlUPrtegG16eBnEaf1HngPN4s4nAH1WtvGo-8BU,3299
|
@@ -245,19 +245,20 @@ monai/metrics/utils.py,sha256=eQ9QGGvuNmYFrgtVFNiA44pBhaHLCkmpyeK2FcK_2Pc,46941
|
|
245
245
|
monai/metrics/wrapper.py,sha256=c1zg-xcypQyZ840TEuhhLgr4sClYMWTxlv1OieJTtvE,11781
|
246
246
|
monai/networks/__init__.py,sha256=ZzU2Qo8gDXNiRBF0JapIo3xlecZHjXsJuarF0IKVKKY,1086
|
247
247
|
monai/networks/trt_compiler.py,sha256=IFfsM1qFZvmCUBbEvbHnZe6_zmMcXghkpkzmP43dZbk,27535
|
248
|
-
monai/networks/utils.py,sha256=
|
249
|
-
monai/networks/blocks/__init__.py,sha256=
|
248
|
+
monai/networks/utils.py,sha256=eFvKh-GcUGhhhM3sAQ69vn5zJAXJxRfT-uBKN0YCMp8,58226
|
249
|
+
monai/networks/blocks/__init__.py,sha256=oZHzxMiOOpLbgqWUWRAUY64FLkyx6yrDv8pZFq5iCTE,2481
|
250
250
|
monai/networks/blocks/acti_norm.py,sha256=bVGXbTZ_ssRvmED5R7LOQ7jj4V6WbVFl8JMO-4iZ2Dk,4275
|
251
251
|
monai/networks/blocks/activation.py,sha256=S5k3zcP2PsHBkeIxgWgNg8ppW80tTResVP2j9ZsvTFw,5839
|
252
252
|
monai/networks/blocks/aspp.py,sha256=GGGE7NfWj77RkaWHbcLuUP4Aff-WeiDrtgtFuSoekQk,4380
|
253
253
|
monai/networks/blocks/attention_utils.py,sha256=UAlttLpn8vJCIiYyWXEUF-NzVTQBOK-aTieGtR5WrXk,4951
|
254
254
|
monai/networks/blocks/backbone_fpn_utils.py,sha256=mdXFwtnRgwuaisTlY-c7OkY1ZZBY3I82dAjpXFAZFbg,7488
|
255
|
+
monai/networks/blocks/cablock.py,sha256=q-wBpW10Qm1dEhUW-SNvCC73YQwlQdbtscaAEeUEgN4,6914
|
255
256
|
monai/networks/blocks/convolutions.py,sha256=gRmbYfy3IR4taiXuxeH5KGOFjP55FoVWfP4e1L6ai0s,11686
|
256
257
|
monai/networks/blocks/crf.py,sha256=gHyRgBWD9DmmbCJnXwsMa6WN7N9fDLuT_SwH8MnHhXE,5009
|
257
258
|
monai/networks/blocks/crossattention.py,sha256=8rb1n41NRGjMHDegWXm9jlBHTaXFxEqgNLN8xsxXQzI,8348
|
258
259
|
monai/networks/blocks/denseblock.py,sha256=hs1rcBp95euZT5ULjgefPApZH75-hqSaVKKNtHdGt10,4747
|
259
260
|
monai/networks/blocks/dints_block.py,sha256=-JWz4-nnAjrOxU2oJ86-qN8Krb8FayKS8Zpbp1wLXzc,9255
|
260
|
-
monai/networks/blocks/downsample.py,sha256=
|
261
|
+
monai/networks/blocks/downsample.py,sha256=VX8LRINwyrPXUckA_HrjBnP8eyDW_G9WRMa88XAoCYQ,13495
|
261
262
|
monai/networks/blocks/dynunet_block.py,sha256=kg8NNTL4nBqsy6gBcxmS5ZCPxlhWM_iB0ByyTQ4AfPs,11063
|
262
263
|
monai/networks/blocks/encoder.py,sha256=NwH5VSQLwamJqrSbZSdQqMwZCk80CPbSpMGEE0r0Cwo,3669
|
263
264
|
monai/networks/blocks/fcn.py,sha256=mnCMrxhUdj2yZ0DPIj0Xf9OKVdv-qhG1BpnAg5j7q6c,9024
|
@@ -322,6 +323,7 @@ monai/networks/nets/quicknat.py,sha256=ko1BO9l4i4BVYG5V4ohkwUEyoRrPPPzmqNqnFhLTZ
|
|
322
323
|
monai/networks/nets/regressor.py,sha256=6Nz5yJuQDJJOr5R0rhot_mHu7_MDCA4ybV48wS1HS1M,6482
|
323
324
|
monai/networks/nets/regunet.py,sha256=-A6ygR7lVyAflFyqWkVVOsY94uMXWol1f2xr_HmsU1c,18664
|
324
325
|
monai/networks/nets/resnet.py,sha256=owsWu9lK26ijhRHDCLEBLf03t681TyehVCflcPqGIec,28179
|
326
|
+
monai/networks/nets/restormer.py,sha256=53VFC15JBZzhVmYEKY5F2ddXcL6NUQ0pDFtpvRVZ8hw,12897
|
325
327
|
monai/networks/nets/segresnet.py,sha256=xNkSIvdk7kAyc3eVn-U_gGj8MoGVc5nklFKc_fkgOUs,13994
|
326
328
|
monai/networks/nets/segresnet_ds.py,sha256=XFF7HKMt9Wbfc9fZSgfjVdfYQcP0d19ygp3VT7OHzJg,20644
|
327
329
|
monai/networks/nets/senet.py,sha256=yLhP9gDPoa-h9UwJZJm5qxPdPvF9calY95lButXJESs,19308
|
@@ -401,12 +403,12 @@ monai/transforms/spatial/functional.py,sha256=IwS0witCqbGkyuxzu_R4Ztp90S0pg9hY1i
|
|
401
403
|
monai/transforms/utility/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
402
404
|
monai/transforms/utility/array.py,sha256=Du3QA6m0io7mR51gUgaMwHBFNStdFmRxhaYmBCVy7BY,81215
|
403
405
|
monai/transforms/utility/dictionary.py,sha256=iOFdTSekvkAsBbbfHeffcRsOKRtNcnt3N1cVuUarZ1s,80549
|
404
|
-
monai/utils/__init__.py,sha256=
|
406
|
+
monai/utils/__init__.py,sha256=LAOUinb7tHhBy_hygIoxw9WVu7W0C-oJ1lnfPAx1GI0,3813
|
405
407
|
monai/utils/component_store.py,sha256=Fe9jbHgwwBBAeJAw0nI02Ae13v17wlwF6N9uUue8tJg,4525
|
406
408
|
monai/utils/decorators.py,sha256=qhhdmJMjMfZIUM6x_VGUGF7kaq2cBUAam8WymAU_mhw,3156
|
407
409
|
monai/utils/deprecate_utils.py,sha256=gKeEV4MsI51qeQ5gci2me_C-0e-tDwa3VZzd3XPQqLk,14759
|
408
410
|
monai/utils/dist.py,sha256=7brB42CvdS8Jvr8Y7hfqov1uk6NNnYea9dYfgMYy0BY,8578
|
409
|
-
monai/utils/enums.py,sha256=
|
411
|
+
monai/utils/enums.py,sha256=aupxnORUHqVPF2Ac5nxstsP5aIyewMoqgGb88D62yxg,19931
|
410
412
|
monai/utils/jupyter_utils.py,sha256=BYtj80LWQAYg5RWPj5g4j2AMCzLECvAcnZdXns0Ruw8,15651
|
411
413
|
monai/utils/misc.py,sha256=9-5zBIDSUYewzoQBkiBm0G_HR8hmwQCT-I15RYOQqEQ,31759
|
412
414
|
monai/utils/module.py,sha256=R37PpCNCcHQvjjZFbNjNyzWb3FURaKLxQucjhzQk0eU,26087
|
@@ -423,7 +425,7 @@ monai/visualize/img2tensorboard.py,sha256=n4ztSa5BQAUxSTGvi2tp45v-F7-RNgSlbsdy-9
|
|
423
425
|
monai/visualize/occlusion_sensitivity.py,sha256=OQHEJLyIhB8zWqQsfKaX-1kvCjWFVYtLfS4dFC0nKFI,18160
|
424
426
|
monai/visualize/utils.py,sha256=B-MhTVs7sQbIqYS3yPnpBwPw2K82rE2PBtGIfpwZtWM,9894
|
425
427
|
monai/visualize/visualizer.py,sha256=qckyaMZCbezYUwE20k5yc-Pb7UozVavMDbrmyQwfYHY,1377
|
426
|
-
monai_weekly-1.5.
|
428
|
+
monai_weekly-1.5.dev2513.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
427
429
|
tests/apps/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
428
430
|
tests/apps/test_auto3dseg_bundlegen.py,sha256=FpTJo9Lfe8vdhGuWeZ9y1BQmqYwTt-s8mDVtoLGAz_I,5594
|
429
431
|
tests/apps/test_check_hash.py,sha256=MuZslW2DDCxHKEo6-PiL7hnbxGuZRRYf6HOh3ZQv1qQ,1761
|
@@ -586,7 +588,6 @@ tests/integration/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0
|
|
586
588
|
tests/integration/test_auto3dseg_ensemble.py,sha256=CIdgRSNX1VBfSmSvx_8HgsIfFsJpoLAC5sg8_HrRP_c,7570
|
587
589
|
tests/integration/test_auto3dseg_hpo.py,sha256=J5Us-7iVE5UFbbsZdzNVcvPPQENTlPFarVsSt7KR3Yo,7286
|
588
590
|
tests/integration/test_deepedit_interaction.py,sha256=tmryp1cP_QlI_tgguZybRZc7-FIEADTUGlUz3tmhEeY,4580
|
589
|
-
tests/integration/test_downsample_block.py,sha256=qvqSeTwFQHwiJ0y8uwWE8U_9ffhltJ_4U5Zg5rBnQ6M,1794
|
590
591
|
tests/integration/test_hovernet_nuclear_type_post_processingd.py,sha256=yTRmYdQBXEMMmXJjPDBPMxPSkLWj2U3bdRhaAfDXrpE,2661
|
591
592
|
tests/integration/test_integration_autorunner.py,sha256=tDK1XkMZp4hehfuzMr2LQIgavP36L_vkFcOcI1Z68Lk,7571
|
592
593
|
tests/integration/test_integration_bundle_run.py,sha256=uO87WnnG3EYnAxhudpfHy7fyxHNNzifFTw2rrMm_6XU,10734
|
@@ -656,7 +657,7 @@ tests/metrics/test_compute_f_beta.py,sha256=xbCipeICoAXWZLgDFeDAa1KjDQxDTMVArNbt
|
|
656
657
|
tests/metrics/test_compute_fid_metric.py,sha256=B9OZECl3CT1JKzG-2C_YaPFjgfvlFoS9vI1j8vBzWZg,1328
|
657
658
|
tests/metrics/test_compute_froc.py,sha256=IdP8JXI7SCXJQ4oxqpiGInIjNgApIbZ7JtdMTX1lc-U,4522
|
658
659
|
tests/metrics/test_compute_generalized_dice.py,sha256=m5468hRvCYdfEF4B459e2LW3gDXH1PZSrBM3FOHHOsk,9614
|
659
|
-
tests/metrics/test_compute_meandice.py,sha256=
|
660
|
+
tests/metrics/test_compute_meandice.py,sha256=gRqSuys0MDktMQkCR_McceY4EiyYOkxYCuZKjz3VaHs,11390
|
660
661
|
tests/metrics/test_compute_meaniou.py,sha256=hphLbY6S-DA3CQiKOug-DblzqwPK0F7aF3Pujz6H0vk,8020
|
661
662
|
tests/metrics/test_compute_mmd_metric.py,sha256=9rwvmZaj4wQKLY3xfuF85gFvZrnyWSXXDd6m7zy63sg,2025
|
662
663
|
tests/metrics/test_compute_multiscalessim_metric.py,sha256=bLL6eNE_bhL4tL4EJO5XcaGurbE5utemc4b6PmJ766k,3080
|
@@ -683,12 +684,14 @@ tests/networks/test_save_state.py,sha256=OnUJEX6vqWoIAIEvVXHbAL4Yrv1GeY0YHw2Dpos
|
|
683
684
|
tests/networks/test_to_onehot.py,sha256=QlT6RkkG7CJeh0gppSohl4kb0bmhISdx_19IybYES0Q,2224
|
684
685
|
tests/networks/test_varnet.py,sha256=-9Ew5epHVvRLc34VCFwKNpsKKoAdudpBRlqDAShpIio,2800
|
685
686
|
tests/networks/blocks/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
687
|
+
tests/networks/blocks/test_CABlock.py,sha256=PLKDIIn185l6xD_c9aYr-uuOjENtIRdlIDq3xmijb5c,5938
|
686
688
|
tests/networks/blocks/test_adn.py,sha256=lTQWChKSNkQhIN3qZIIytOcox8GAoCFGpQUEVl7tqJ8,3191
|
687
689
|
tests/networks/blocks/test_convolutions.py,sha256=6XS_DcpcA03BXttjKUU-LqfHgmLnU9RL60YzGn4mDfE,7139
|
688
690
|
tests/networks/blocks/test_crf_cpu.py,sha256=UCw5vDdAgKA6_hpMtcHLBZpWBmdN245HelYMScZLp9E,18871
|
689
691
|
tests/networks/blocks/test_crf_cuda.py,sha256=L4d8M4EZNRA0VcACBkWchZWfaHX8efDYygafZAL6UvQ,19450
|
690
692
|
tests/networks/blocks/test_crossattention.py,sha256=mZ3q4nMhNN0Ntew5w8eO39qsucJmWNvqw6YYsshsZuI,7804
|
691
693
|
tests/networks/blocks/test_denseblock.py,sha256=Nfq4zOwO7MF0OUYASf1Duorw0VZXy1v46NoQWDpWWFc,4373
|
694
|
+
tests/networks/blocks/test_downsample_block.py,sha256=Q6nCD7AzDc1YLQcNcQhgO7fQqpqidj5Qkn64rephAAY,7687
|
692
695
|
tests/networks/blocks/test_dynunet_block.py,sha256=2juAJlzKjmPqm0WzfWUKvBPcWYZeFVoqu7-knD6maUU,4994
|
693
696
|
tests/networks/blocks/test_fpn_block.py,sha256=R9w3pnnl25XlR_V_HZoZe2Kq8BuCK-fww6edTBLcNlk,3467
|
694
697
|
tests/networks/blocks/test_localnet_block.py,sha256=PGeU2CesjHa9xs-Uj0G_aE_nfqKbzJVJHk5gQTZOPsg,4806
|
@@ -773,6 +776,7 @@ tests/networks/nets/test_network_consistency.py,sha256=OuEsjkCzQEIxQ9CNJxNXqI8Kr
|
|
773
776
|
tests/networks/nets/test_patch_gan_dicriminator.py,sha256=5qhzL55pid_9ShuALPzvW21eZtdlpKupw8hdu1N4sVE,5266
|
774
777
|
tests/networks/nets/test_quicknat.py,sha256=iuJRChBt6OoOvBGUe2bZ5wvcx0AfId4gZJ7K12SP7w8,2601
|
775
778
|
tests/networks/nets/test_resnet.py,sha256=nIx9ZrHWN36iiGP9KffiEdJ5kLctySh5_zdAddl9gTc,10475
|
779
|
+
tests/networks/nets/test_restormer.py,sha256=iP9oW-D0rmHnitjOAzEz2nBATR1XTNeEAukktuQrG2s,5332
|
776
780
|
tests/networks/nets/test_segresnet.py,sha256=1-TfOY4SSTAb3oyjFhOSoXLIO7TogksQQdvgKfXDg-o,4855
|
777
781
|
tests/networks/nets/test_segresnet_ds.py,sha256=K9_tjcxknCAy3UiVjVAQ-1nCIuXoIJ8QEBbYaERtQ7g,6576
|
778
782
|
tests/networks/nets/test_senet.py,sha256=V9HyDyYMR2r2F6FzZUl6INDipH5mk-IrExkkeZwHF4Q,6226
|
@@ -808,6 +812,7 @@ tests/networks/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZ
|
|
808
812
|
tests/networks/utils/test_copy_model_state.py,sha256=SI0dlUkA0rdZVFvi5acr0nve002oIiftWIPWkqLQH2Q,6764
|
809
813
|
tests/networks/utils/test_eval_mode.py,sha256=HQqgC4COr5fAsBo2Z-DCjnOfx6WLxXlPywlGMnQY7_0,1086
|
810
814
|
tests/networks/utils/test_freeze_layers.py,sha256=y1DdBo7AMowTyzEurOs2EfK6InDZrfq1_h3z7rpSoag,2006
|
815
|
+
tests/networks/utils/test_pixelunshuffle.py,sha256=p_PO1TqAJf5fNBh36OJyfhIEKtfx1rBrMlrmWfzJp7g,1911
|
811
816
|
tests/networks/utils/test_replace_module.py,sha256=_KBuO21-kvw7vckHh0KCiY8tPjMXSXpLv5kKH-ojeY4,4350
|
812
817
|
tests/networks/utils/test_train_mode.py,sha256=GYjtzEV_ynDIRl2jVi3-A3ihS33BvubAIKaEoMKaVKI,1051
|
813
818
|
tests/optimizers/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
@@ -1181,7 +1186,7 @@ tests/visualize/test_vis_gradcam.py,sha256=WpA-pvTB75eZs7JoIc5qyvOV9PwgkzWI8-Vow
|
|
1181
1186
|
tests/visualize/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
1182
1187
|
tests/visualize/utils/test_blend_images.py,sha256=RVs2p_8RWQDfhLHDNNtZaMig27v8o0km7XxNa-zWjKE,2274
|
1183
1188
|
tests/visualize/utils/test_matshow3d.py,sha256=wXYj77L5Jvnp0f6DvL1rsi_-YlCxS0HJ9hiPmrbpuP8,5021
|
1184
|
-
monai_weekly-1.5.
|
1185
|
-
monai_weekly-1.5.
|
1186
|
-
monai_weekly-1.5.
|
1187
|
-
monai_weekly-1.5.
|
1189
|
+
monai_weekly-1.5.dev2513.dist-info/METADATA,sha256=JtMbpoy1PATNUbzheFGRBPWYUP9Ka18xyiyNHGhad8k,12008
|
1190
|
+
monai_weekly-1.5.dev2513.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
1191
|
+
monai_weekly-1.5.dev2513.dist-info/top_level.txt,sha256=hn2Y6P9xBf2R8faMeVMHhPMvrdDKxMsIOwMDYI0yTjs,12
|
1192
|
+
monai_weekly-1.5.dev2513.dist-info/RECORD,,
|
@@ -267,15 +267,15 @@ class TestComputeMeanDice(unittest.TestCase):
|
|
267
267
|
@parameterized.expand([TEST_CASE_3])
|
268
268
|
def test_helper(self, input_data, _unused):
|
269
269
|
vals = {"y_pred": dict(input_data).pop("y_pred"), "y": dict(input_data).pop("y")}
|
270
|
-
result = DiceHelper(
|
270
|
+
result = DiceHelper(threshold=True)(**vals)
|
271
271
|
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4)
|
272
272
|
np.testing.assert_allclose(sorted(result[1].cpu().numpy()), [0.0, 1.0, 2.0], atol=1e-4)
|
273
|
-
result = DiceHelper(
|
273
|
+
result = DiceHelper(apply_argmax=True, get_not_nans=False)(**vals)
|
274
274
|
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0], atol=1e-4)
|
275
275
|
|
276
276
|
num_classes = vals["y_pred"].shape[1]
|
277
277
|
vals["y_pred"] = torch.argmax(vals["y_pred"], dim=1, keepdim=True)
|
278
|
-
result = DiceHelper(
|
278
|
+
result = DiceHelper(threshold=True, num_classes=num_classes)(**vals)
|
279
279
|
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4)
|
280
280
|
|
281
281
|
# DiceMetric class tests
|
@@ -0,0 +1,150 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
from unittest import skipUnless
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.networks import eval_mode
|
21
|
+
from monai.networks.blocks.cablock import CABlock, FeedForward
|
22
|
+
from monai.utils import optional_import
|
23
|
+
from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose
|
24
|
+
|
25
|
+
einops, has_einops = optional_import("einops")
|
26
|
+
|
27
|
+
|
28
|
+
TEST_CASES_CAB = []
|
29
|
+
for spatial_dims in [2, 3]:
|
30
|
+
for dim in [32, 64, 128]:
|
31
|
+
for num_heads in [2, 4, 8]:
|
32
|
+
for bias in [True, False]:
|
33
|
+
test_case = [
|
34
|
+
{
|
35
|
+
"spatial_dims": spatial_dims,
|
36
|
+
"dim": dim,
|
37
|
+
"num_heads": num_heads,
|
38
|
+
"bias": bias,
|
39
|
+
"flash_attention": False,
|
40
|
+
},
|
41
|
+
(2, dim, *([16] * spatial_dims)),
|
42
|
+
(2, dim, *([16] * spatial_dims)),
|
43
|
+
]
|
44
|
+
TEST_CASES_CAB.append(test_case)
|
45
|
+
|
46
|
+
|
47
|
+
TEST_CASES_FEEDFORWARD = [
|
48
|
+
# Test different spatial dims, dimensions and expansion factors
|
49
|
+
[{"spatial_dims": 2, "dim": 64, "ffn_expansion_factor": 2.0, "bias": True}, (2, 64, 32, 32)],
|
50
|
+
[{"spatial_dims": 3, "dim": 128, "ffn_expansion_factor": 1.5, "bias": False}, (2, 128, 16, 16, 16)],
|
51
|
+
[{"spatial_dims": 2, "dim": 256, "ffn_expansion_factor": 1.0, "bias": True}, (1, 256, 64, 64)],
|
52
|
+
]
|
53
|
+
|
54
|
+
|
55
|
+
class TestFeedForward(unittest.TestCase):
|
56
|
+
|
57
|
+
@parameterized.expand(TEST_CASES_FEEDFORWARD)
|
58
|
+
def test_shape(self, input_param, input_shape):
|
59
|
+
net = FeedForward(**input_param)
|
60
|
+
with eval_mode(net):
|
61
|
+
result = net(torch.randn(input_shape))
|
62
|
+
self.assertEqual(result.shape, input_shape)
|
63
|
+
|
64
|
+
def test_gating_mechanism(self):
|
65
|
+
net = FeedForward(spatial_dims=2, dim=32, ffn_expansion_factor=2.0, bias=True)
|
66
|
+
x = torch.ones(1, 32, 16, 16)
|
67
|
+
out = net(x)
|
68
|
+
self.assertNotEqual(torch.sum(out), torch.sum(x))
|
69
|
+
|
70
|
+
|
71
|
+
class TestCABlock(unittest.TestCase):
|
72
|
+
|
73
|
+
@parameterized.expand(TEST_CASES_CAB)
|
74
|
+
@skipUnless(has_einops, "Requires einops")
|
75
|
+
def test_shape(self, input_param, input_shape, expected_shape):
|
76
|
+
net = CABlock(**input_param)
|
77
|
+
with eval_mode(net):
|
78
|
+
result = net(torch.randn(input_shape))
|
79
|
+
self.assertEqual(result.shape, expected_shape)
|
80
|
+
|
81
|
+
@skipUnless(has_einops, "Requires einops")
|
82
|
+
def test_invalid_spatial_dims(self):
|
83
|
+
with self.assertRaises(ValueError):
|
84
|
+
CABlock(spatial_dims=4, dim=64, num_heads=4, bias=True)
|
85
|
+
|
86
|
+
@SkipIfBeforePyTorchVersion((2, 0))
|
87
|
+
@skipUnless(has_einops, "Requires einops")
|
88
|
+
def test_flash_attention(self):
|
89
|
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
90
|
+
block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device)
|
91
|
+
x = torch.randn(2, 64, 32, 32).to(device)
|
92
|
+
output = block(x)
|
93
|
+
self.assertEqual(output.shape, x.shape)
|
94
|
+
|
95
|
+
@skipUnless(has_einops, "Requires einops")
|
96
|
+
def test_temperature_parameter(self):
|
97
|
+
block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True)
|
98
|
+
self.assertTrue(isinstance(block.temperature, torch.nn.Parameter))
|
99
|
+
self.assertEqual(block.temperature.shape, (4, 1, 1))
|
100
|
+
|
101
|
+
@skipUnless(has_einops, "Requires einops")
|
102
|
+
def test_qkv_transformation_2d(self):
|
103
|
+
block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True)
|
104
|
+
x = torch.randn(2, 64, 32, 32)
|
105
|
+
qkv = block.qkv(x)
|
106
|
+
self.assertEqual(qkv.shape, (2, 192, 32, 32))
|
107
|
+
|
108
|
+
@skipUnless(has_einops, "Requires einops")
|
109
|
+
def test_qkv_transformation_3d(self):
|
110
|
+
block = CABlock(spatial_dims=3, dim=64, num_heads=4, bias=True)
|
111
|
+
x = torch.randn(2, 64, 16, 16, 16)
|
112
|
+
qkv = block.qkv(x)
|
113
|
+
self.assertEqual(qkv.shape, (2, 192, 16, 16, 16))
|
114
|
+
|
115
|
+
@SkipIfBeforePyTorchVersion((2, 0))
|
116
|
+
@skipUnless(has_einops, "Requires einops")
|
117
|
+
def test_flash_vs_normal_attention(self):
|
118
|
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
119
|
+
block_flash = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device)
|
120
|
+
block_normal = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=False).to(device)
|
121
|
+
|
122
|
+
block_normal.load_state_dict(block_flash.state_dict())
|
123
|
+
|
124
|
+
x = torch.randn(2, 64, 32, 32).to(device)
|
125
|
+
with torch.no_grad():
|
126
|
+
out_flash = block_flash(x)
|
127
|
+
out_normal = block_normal(x)
|
128
|
+
|
129
|
+
assert_allclose(out_flash, out_normal, atol=1e-4)
|
130
|
+
|
131
|
+
@skipUnless(has_einops, "Requires einops")
|
132
|
+
def test_deterministic_small_input(self):
|
133
|
+
block = CABlock(spatial_dims=2, dim=2, num_heads=1, bias=False)
|
134
|
+
with torch.no_grad():
|
135
|
+
block.qkv.conv.weight.data.fill_(1.0)
|
136
|
+
block.qkv_dwconv.conv.weight.data.fill_(1.0)
|
137
|
+
block.temperature.data.fill_(1.0)
|
138
|
+
block.project_out.conv.weight.data.fill_(1.0)
|
139
|
+
|
140
|
+
x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]], dtype=torch.float32)
|
141
|
+
|
142
|
+
output = block(x)
|
143
|
+
# Channel attention: sum([1..8]) * (qkv_conv=1) * (dwconv=1) * (attn_weights=1) * (proj=1) = 36 * 2 = 72
|
144
|
+
expected = torch.full_like(x, 72.0)
|
145
|
+
|
146
|
+
assert_allclose(output, expected, atol=1e-6)
|
147
|
+
|
148
|
+
|
149
|
+
if __name__ == "__main__":
|
150
|
+
unittest.main()
|
@@ -0,0 +1,184 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.networks import eval_mode
|
20
|
+
from monai.networks.blocks import DownSample, MaxAvgPool, SubpixelDownsample, SubpixelUpsample
|
21
|
+
from monai.utils import optional_import
|
22
|
+
|
23
|
+
einops, has_einops = optional_import("einops")
|
24
|
+
|
25
|
+
TEST_CASES = [
|
26
|
+
[{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)], # 4-channel 2D, batch 7
|
27
|
+
[{"spatial_dims": 1, "kernel_size": 4}, (16, 4, 63), (16, 8, 15)], # 4-channel 1D, batch 16
|
28
|
+
[{"spatial_dims": 1, "kernel_size": 4, "padding": 1}, (16, 4, 63), (16, 8, 16)], # 4-channel 1D, batch 16
|
29
|
+
[ # 4-channel 3D, batch 16
|
30
|
+
{"spatial_dims": 3, "kernel_size": 3, "ceil_mode": True},
|
31
|
+
(16, 4, 32, 24, 48),
|
32
|
+
(16, 8, 11, 8, 16),
|
33
|
+
],
|
34
|
+
[ # 1-channel 3D, batch 16
|
35
|
+
{"spatial_dims": 3, "kernel_size": 3, "ceil_mode": False},
|
36
|
+
(16, 1, 32, 24, 48),
|
37
|
+
(16, 2, 10, 8, 16),
|
38
|
+
],
|
39
|
+
]
|
40
|
+
|
41
|
+
TEST_CASES_SUBPIXEL = [
|
42
|
+
[{"spatial_dims": 2, "in_channels": 1, "scale_factor": 2}, (1, 1, 8, 8), (1, 4, 4, 4)],
|
43
|
+
[{"spatial_dims": 3, "in_channels": 2, "scale_factor": 2}, (1, 2, 8, 8, 8), (1, 16, 4, 4, 4)],
|
44
|
+
[{"spatial_dims": 1, "in_channels": 3, "scale_factor": 2}, (1, 3, 8), (1, 6, 4)],
|
45
|
+
]
|
46
|
+
|
47
|
+
TEST_CASES_DOWNSAMPLE = [
|
48
|
+
[{"spatial_dims": 2, "in_channels": 4, "mode": "conv"}, (1, 4, 16, 16), (1, 4, 8, 8)],
|
49
|
+
[{"spatial_dims": 2, "in_channels": 4, "out_channels": 8, "mode": "convgroup"}, (1, 4, 16, 16), (1, 8, 8, 8)],
|
50
|
+
[{"spatial_dims": 3, "in_channels": 2, "mode": "maxpool"}, (1, 2, 16, 16, 16), (1, 2, 8, 8, 8)],
|
51
|
+
[{"spatial_dims": 2, "in_channels": 4, "mode": "avgpool"}, (1, 4, 16, 16), (1, 4, 8, 8)],
|
52
|
+
[{"spatial_dims": 2, "in_channels": 1, "mode": "pixelunshuffle"}, (1, 1, 16, 16), (1, 4, 8, 8)],
|
53
|
+
]
|
54
|
+
|
55
|
+
|
56
|
+
class TestMaxAvgPool(unittest.TestCase):
|
57
|
+
|
58
|
+
@parameterized.expand(TEST_CASES)
|
59
|
+
def test_shape(self, input_param, input_shape, expected_shape):
|
60
|
+
net = MaxAvgPool(**input_param)
|
61
|
+
with eval_mode(net):
|
62
|
+
result = net(torch.randn(input_shape))
|
63
|
+
self.assertEqual(result.shape, expected_shape)
|
64
|
+
|
65
|
+
|
66
|
+
class TestSubpixelDownsample(unittest.TestCase):
|
67
|
+
|
68
|
+
@parameterized.expand(TEST_CASES_SUBPIXEL)
|
69
|
+
def test_shape(self, input_param, input_shape, expected_shape):
|
70
|
+
downsampler = SubpixelDownsample(**input_param)
|
71
|
+
with eval_mode(downsampler):
|
72
|
+
result = downsampler(torch.randn(input_shape))
|
73
|
+
self.assertEqual(result.shape, expected_shape)
|
74
|
+
|
75
|
+
def test_predefined_tensor(self):
|
76
|
+
test_tensor = torch.arange(4).view(4, 1, 1).repeat(1, 4, 4)
|
77
|
+
test_tensor = test_tensor.unsqueeze(0)
|
78
|
+
|
79
|
+
downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None)
|
80
|
+
with eval_mode(downsampler):
|
81
|
+
result = downsampler(test_tensor)
|
82
|
+
self.assertEqual(result.shape, (1, 16, 2, 2))
|
83
|
+
self.assertTrue(torch.all(result[0, 0:3] == 0))
|
84
|
+
self.assertTrue(torch.all(result[0, 4:7] == 1))
|
85
|
+
self.assertTrue(torch.all(result[0, 8:11] == 2))
|
86
|
+
self.assertTrue(torch.all(result[0, 12:15] == 3))
|
87
|
+
|
88
|
+
def test_reconstruction_2d(self):
|
89
|
+
input_tensor = torch.randn(1, 1, 4, 4)
|
90
|
+
down = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None)
|
91
|
+
up = SubpixelUpsample(spatial_dims=2, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False)
|
92
|
+
with eval_mode(down), eval_mode(up):
|
93
|
+
downsampled = down(input_tensor)
|
94
|
+
reconstructed = up(downsampled)
|
95
|
+
self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5))
|
96
|
+
|
97
|
+
def test_reconstruction_3d(self):
|
98
|
+
input_tensor = torch.randn(1, 1, 4, 4, 4)
|
99
|
+
down = SubpixelDownsample(spatial_dims=3, in_channels=1, scale_factor=2, conv_block=None)
|
100
|
+
up = SubpixelUpsample(spatial_dims=3, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False)
|
101
|
+
with eval_mode(down), eval_mode(up):
|
102
|
+
downsampled = down(input_tensor)
|
103
|
+
reconstructed = up(downsampled)
|
104
|
+
self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5))
|
105
|
+
|
106
|
+
def test_invalid_spatial_size(self):
|
107
|
+
downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2)
|
108
|
+
with self.assertRaises(ValueError):
|
109
|
+
downsampler(torch.randn(1, 1, 3, 4))
|
110
|
+
|
111
|
+
def test_custom_conv_block(self):
|
112
|
+
custom_conv = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1)
|
113
|
+
downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=custom_conv)
|
114
|
+
with eval_mode(downsampler):
|
115
|
+
result = downsampler(torch.randn(1, 1, 4, 4))
|
116
|
+
self.assertEqual(result.shape, (1, 8, 2, 2))
|
117
|
+
|
118
|
+
|
119
|
+
class TestDownSample(unittest.TestCase):
|
120
|
+
@parameterized.expand(TEST_CASES_DOWNSAMPLE)
|
121
|
+
def test_shape(self, input_param, input_shape, expected_shape):
|
122
|
+
net = DownSample(**input_param)
|
123
|
+
with eval_mode(net):
|
124
|
+
result = net(torch.randn(input_shape))
|
125
|
+
self.assertEqual(result.shape, expected_shape)
|
126
|
+
|
127
|
+
def test_pre_post_conv(self):
|
128
|
+
net = DownSample(
|
129
|
+
spatial_dims=2,
|
130
|
+
in_channels=4,
|
131
|
+
out_channels=8,
|
132
|
+
mode="maxpool",
|
133
|
+
pre_conv="default",
|
134
|
+
post_conv=torch.nn.Conv2d(8, 16, 1),
|
135
|
+
)
|
136
|
+
with eval_mode(net):
|
137
|
+
result = net(torch.randn(1, 4, 16, 16))
|
138
|
+
self.assertEqual(result.shape, (1, 16, 8, 8))
|
139
|
+
|
140
|
+
def test_pixelunshuffle_equivalence(self):
|
141
|
+
class DownSampleLocal(torch.nn.Module):
|
142
|
+
def __init__(self, n_feat: int):
|
143
|
+
super().__init__()
|
144
|
+
self.conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)
|
145
|
+
self.pixelunshuffle = torch.nn.PixelUnshuffle(2)
|
146
|
+
|
147
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
148
|
+
x = self.conv(x)
|
149
|
+
x = self.pixelunshuffle(x)
|
150
|
+
return x
|
151
|
+
|
152
|
+
n_feat = 2
|
153
|
+
x = torch.randn(1, n_feat, 64, 64)
|
154
|
+
|
155
|
+
fix_weight_conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)
|
156
|
+
|
157
|
+
monai_down = DownSample(
|
158
|
+
spatial_dims=2,
|
159
|
+
in_channels=n_feat,
|
160
|
+
out_channels=n_feat // 2,
|
161
|
+
mode="pixelunshuffle",
|
162
|
+
pre_conv=fix_weight_conv,
|
163
|
+
)
|
164
|
+
|
165
|
+
local_down = DownSampleLocal(n_feat)
|
166
|
+
local_down.conv.weight.data = fix_weight_conv.weight.data.clone()
|
167
|
+
|
168
|
+
with eval_mode(monai_down), eval_mode(local_down):
|
169
|
+
out_monai = monai_down(x)
|
170
|
+
out_local = local_down(x)
|
171
|
+
|
172
|
+
self.assertTrue(torch.allclose(out_monai, out_local, rtol=1e-5))
|
173
|
+
|
174
|
+
def test_invalid_mode(self):
|
175
|
+
with self.assertRaises(ValueError):
|
176
|
+
DownSample(spatial_dims=2, in_channels=4, mode="invalid")
|
177
|
+
|
178
|
+
def test_missing_channels(self):
|
179
|
+
with self.assertRaises(ValueError):
|
180
|
+
DownSample(spatial_dims=2, mode="conv")
|
181
|
+
|
182
|
+
|
183
|
+
if __name__ == "__main__":
|
184
|
+
unittest.main()
|
@@ -0,0 +1,147 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
from unittest import skipUnless
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.networks import eval_mode
|
21
|
+
from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer
|
22
|
+
from monai.utils import optional_import
|
23
|
+
|
24
|
+
einops, has_einops = optional_import("einops")
|
25
|
+
|
26
|
+
TEST_CASES_TRANSFORMER = [
|
27
|
+
# [spatial_dims, dim, num_heads, ffn_factor, bias, layer_norm_use_bias, flash_attn, input_shape]
|
28
|
+
[2, 48, 8, 2.66, True, True, False, (2, 48, 64, 64)],
|
29
|
+
[2, 96, 8, 2.66, False, False, False, (2, 96, 32, 32)],
|
30
|
+
[3, 48, 4, 2.66, True, True, False, (2, 48, 32, 32, 32)],
|
31
|
+
[3, 96, 8, 2.66, False, False, True, (2, 96, 16, 16, 16)],
|
32
|
+
]
|
33
|
+
|
34
|
+
TEST_CASES_PATCHEMBED = [
|
35
|
+
# spatial_dims, in_channels, embed_dim, input_shape, expected_shape
|
36
|
+
[2, 1, 48, (2, 1, 64, 64), (2, 48, 64, 64)],
|
37
|
+
[2, 3, 96, (2, 3, 32, 32), (2, 96, 32, 32)],
|
38
|
+
[3, 1, 48, (2, 1, 32, 32, 32), (2, 48, 32, 32, 32)],
|
39
|
+
[3, 4, 64, (2, 4, 16, 16, 16), (2, 64, 16, 16, 16)],
|
40
|
+
]
|
41
|
+
|
42
|
+
RESTORMER_CONFIGS = [
|
43
|
+
# 2-level architecture
|
44
|
+
{"num_blocks": [1, 1], "heads": [1, 1]},
|
45
|
+
{"num_blocks": [2, 1], "heads": [2, 1]},
|
46
|
+
# 3-level architecture
|
47
|
+
{"num_blocks": [1, 1, 1], "heads": [1, 1, 1]},
|
48
|
+
{"num_blocks": [2, 1, 1], "heads": [2, 1, 1]},
|
49
|
+
]
|
50
|
+
|
51
|
+
TEST_CASES_RESTORMER = []
|
52
|
+
for config in RESTORMER_CONFIGS:
|
53
|
+
# 2D cases
|
54
|
+
TEST_CASES_RESTORMER.extend(
|
55
|
+
[
|
56
|
+
[
|
57
|
+
{
|
58
|
+
"spatial_dims": 2,
|
59
|
+
"in_channels": 1,
|
60
|
+
"out_channels": 1,
|
61
|
+
"dim": 48,
|
62
|
+
"num_blocks": config["num_blocks"],
|
63
|
+
"heads": config["heads"],
|
64
|
+
"num_refinement_blocks": 2,
|
65
|
+
"ffn_expansion_factor": 1.5,
|
66
|
+
},
|
67
|
+
(2, 1, 64, 64),
|
68
|
+
(2, 1, 64, 64),
|
69
|
+
],
|
70
|
+
# 3D cases
|
71
|
+
[
|
72
|
+
{
|
73
|
+
"spatial_dims": 3,
|
74
|
+
"in_channels": 1,
|
75
|
+
"out_channels": 1,
|
76
|
+
"dim": 16,
|
77
|
+
"num_blocks": config["num_blocks"],
|
78
|
+
"heads": config["heads"],
|
79
|
+
"num_refinement_blocks": 2,
|
80
|
+
"ffn_expansion_factor": 1.5,
|
81
|
+
},
|
82
|
+
(2, 1, 32, 32, 32),
|
83
|
+
(2, 1, 32, 32, 32),
|
84
|
+
],
|
85
|
+
]
|
86
|
+
)
|
87
|
+
|
88
|
+
|
89
|
+
class TestMDTATransformerBlock(unittest.TestCase):
|
90
|
+
|
91
|
+
@parameterized.expand(TEST_CASES_TRANSFORMER)
|
92
|
+
@skipUnless(has_einops, "Requires einops")
|
93
|
+
def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape):
|
94
|
+
if flash and not torch.cuda.is_available():
|
95
|
+
self.skipTest("Flash attention requires CUDA")
|
96
|
+
block = MDTATransformerBlock(
|
97
|
+
spatial_dims=spatial_dims,
|
98
|
+
dim=dim,
|
99
|
+
num_heads=heads,
|
100
|
+
ffn_expansion_factor=ffn_factor,
|
101
|
+
bias=bias,
|
102
|
+
layer_norm_use_bias=layer_norm_use_bias,
|
103
|
+
flash_attention=flash,
|
104
|
+
)
|
105
|
+
with eval_mode(block):
|
106
|
+
x = torch.randn(shape)
|
107
|
+
output = block(x)
|
108
|
+
self.assertEqual(output.shape, x.shape)
|
109
|
+
|
110
|
+
|
111
|
+
class TestOverlapPatchEmbed(unittest.TestCase):
|
112
|
+
|
113
|
+
@parameterized.expand(TEST_CASES_PATCHEMBED)
|
114
|
+
def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape):
|
115
|
+
net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim)
|
116
|
+
with eval_mode(net):
|
117
|
+
result = net(torch.randn(input_shape))
|
118
|
+
self.assertEqual(result.shape, expected_shape)
|
119
|
+
|
120
|
+
|
121
|
+
class TestRestormer(unittest.TestCase):
|
122
|
+
|
123
|
+
@parameterized.expand(TEST_CASES_RESTORMER)
|
124
|
+
@skipUnless(has_einops, "Requires einops")
|
125
|
+
def test_shape(self, input_param, input_shape, expected_shape):
|
126
|
+
if input_param.get("flash_attention", False) and not torch.cuda.is_available():
|
127
|
+
self.skipTest("Flash attention requires CUDA")
|
128
|
+
net = Restormer(**input_param)
|
129
|
+
with eval_mode(net):
|
130
|
+
result = net(torch.randn(input_shape))
|
131
|
+
self.assertEqual(result.shape, expected_shape)
|
132
|
+
|
133
|
+
@skipUnless(has_einops, "Requires einops")
|
134
|
+
def test_small_input_error_2d(self):
|
135
|
+
net = Restormer(spatial_dims=2, in_channels=1, out_channels=1)
|
136
|
+
with self.assertRaises(AssertionError):
|
137
|
+
net(torch.randn(1, 1, 8, 8))
|
138
|
+
|
139
|
+
@skipUnless(has_einops, "Requires einops")
|
140
|
+
def test_small_input_error_3d(self):
|
141
|
+
net = Restormer(spatial_dims=3, in_channels=1, out_channels=1)
|
142
|
+
with self.assertRaises(AssertionError):
|
143
|
+
net(torch.randn(1, 1, 8, 8, 8))
|
144
|
+
|
145
|
+
|
146
|
+
if __name__ == "__main__":
|
147
|
+
unittest.main()
|