monai-weekly 1.5.dev2512__py3-none-any.whl → 1.5.dev2514__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
monai/utils/misc.py CHANGED
@@ -546,7 +546,7 @@ class MONAIEnvVars:
546
546
 
547
547
  @staticmethod
548
548
  def algo_hash() -> str | None:
549
- return os.environ.get("MONAI_ALGO_HASH", "c108ea9")
549
+ return os.environ.get("MONAI_ALGO_HASH", "4c18daf")
550
550
 
551
551
  @staticmethod
552
552
  def trace_transform() -> str | None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: monai-weekly
3
- Version: 1.5.dev2512
3
+ Version: 1.5.dev2514
4
4
  Summary: AI Toolkit for Healthcare Imaging
5
5
  Home-page: https://monai.io/
6
6
  Author: MONAI Consortium
@@ -1,5 +1,5 @@
1
- monai/__init__.py,sha256=fuJ4RGZ8NPyFT0vu60YfAD4gF52sBh2fQOF2hbwWNPM,4095
2
- monai/_version.py,sha256=u22VnA0EVDmAkTekBp5qsbxC8k8mBO3MhxLp7DiVDYE,503
1
+ monai/__init__.py,sha256=CFh-TH1zyQNZpWxq-GwpxdrvsABNPAk--r4Ly93w9ys,4095
2
+ monai/_version.py,sha256=NFOTegOPD1Xv0in1fjm-WdipEPfARPmpiKGLaetgHeY,503
3
3
  monai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  monai/_extensions/__init__.py,sha256=NEBPreRhQ8H9gVvgrLr_y52_TmqB96u_u4VQmeNT93I,642
5
5
  monai/_extensions/loader.py,sha256=7SiKw36q-nOzH8CRbBurFrz7GM40GCu7rc93Tm8XpnI,3643
@@ -232,7 +232,7 @@ monai/metrics/froc.py,sha256=q7MAFsHHIp5EHBHwa5UbF5PRApjUonw-hUXax9k1WxQ,7981
232
232
  monai/metrics/generalized_dice.py,sha256=9ZiEmGfMZLxFAF6AmdrbKOc8A_QOUMUmIZ6ILm-h01A,8939
233
233
  monai/metrics/hausdorff_distance.py,sha256=4_ZJZ2gV1bPhOR5Mxz0PyN6Y_X1mTZ6U6T4gSRwjfDE,11844
234
234
  monai/metrics/loss_metric.py,sha256=m9jXobVHKLeDY_8yrA9m7FwfapSAb-kYIdUJOsbvBvY,4907
235
- monai/metrics/meandice.py,sha256=bFiDcK-af4cqV-JHAO2Qh2ixwj6fLjaBCaCO6jBAmxQ,13475
235
+ monai/metrics/meandice.py,sha256=Q2Fp_YfZrlsx4cxR_h40zpeeGoIkKWQN78qzCShnbro,16237
236
236
  monai/metrics/meaniou.py,sha256=cGoW1re7v4hxXJfjyEVEFNsuzEupgJaIe6ZK_qrbIjw,7004
237
237
  monai/metrics/metric.py,sha256=VtIMNudwFkEhGAX1n0aYMaj18yKtmENKpo0JuWoVFvQ,15203
238
238
  monai/metrics/mmd.py,sha256=a_O0WlUPrtegG16eBnEaf1HngPN4s4nAH1WtvGo-8BU,3299
@@ -245,19 +245,20 @@ monai/metrics/utils.py,sha256=eQ9QGGvuNmYFrgtVFNiA44pBhaHLCkmpyeK2FcK_2Pc,46941
245
245
  monai/metrics/wrapper.py,sha256=c1zg-xcypQyZ840TEuhhLgr4sClYMWTxlv1OieJTtvE,11781
246
246
  monai/networks/__init__.py,sha256=ZzU2Qo8gDXNiRBF0JapIo3xlecZHjXsJuarF0IKVKKY,1086
247
247
  monai/networks/trt_compiler.py,sha256=IFfsM1qFZvmCUBbEvbHnZe6_zmMcXghkpkzmP43dZbk,27535
248
- monai/networks/utils.py,sha256=8kxdwqV_nxGgwjF7lt_9tsJhesCjnE1eSCvQWzqr5RQ,56372
249
- monai/networks/blocks/__init__.py,sha256=xf-4SLQjL3bU7T_vCnAIbeBzz0Ys2rrtlegJM5bej-Q,2355
248
+ monai/networks/utils.py,sha256=eFvKh-GcUGhhhM3sAQ69vn5zJAXJxRfT-uBKN0YCMp8,58226
249
+ monai/networks/blocks/__init__.py,sha256=oZHzxMiOOpLbgqWUWRAUY64FLkyx6yrDv8pZFq5iCTE,2481
250
250
  monai/networks/blocks/acti_norm.py,sha256=bVGXbTZ_ssRvmED5R7LOQ7jj4V6WbVFl8JMO-4iZ2Dk,4275
251
251
  monai/networks/blocks/activation.py,sha256=S5k3zcP2PsHBkeIxgWgNg8ppW80tTResVP2j9ZsvTFw,5839
252
252
  monai/networks/blocks/aspp.py,sha256=GGGE7NfWj77RkaWHbcLuUP4Aff-WeiDrtgtFuSoekQk,4380
253
253
  monai/networks/blocks/attention_utils.py,sha256=UAlttLpn8vJCIiYyWXEUF-NzVTQBOK-aTieGtR5WrXk,4951
254
254
  monai/networks/blocks/backbone_fpn_utils.py,sha256=mdXFwtnRgwuaisTlY-c7OkY1ZZBY3I82dAjpXFAZFbg,7488
255
+ monai/networks/blocks/cablock.py,sha256=q-wBpW10Qm1dEhUW-SNvCC73YQwlQdbtscaAEeUEgN4,6914
255
256
  monai/networks/blocks/convolutions.py,sha256=gRmbYfy3IR4taiXuxeH5KGOFjP55FoVWfP4e1L6ai0s,11686
256
257
  monai/networks/blocks/crf.py,sha256=gHyRgBWD9DmmbCJnXwsMa6WN7N9fDLuT_SwH8MnHhXE,5009
257
258
  monai/networks/blocks/crossattention.py,sha256=8rb1n41NRGjMHDegWXm9jlBHTaXFxEqgNLN8xsxXQzI,8348
258
259
  monai/networks/blocks/denseblock.py,sha256=hs1rcBp95euZT5ULjgefPApZH75-hqSaVKKNtHdGt10,4747
259
260
  monai/networks/blocks/dints_block.py,sha256=-JWz4-nnAjrOxU2oJ86-qN8Krb8FayKS8Zpbp1wLXzc,9255
260
- monai/networks/blocks/downsample.py,sha256=18cwYXL5H3DC5Yq12cdqTIijDJfMCE2YNHlPetFB6UY,2413
261
+ monai/networks/blocks/downsample.py,sha256=VX8LRINwyrPXUckA_HrjBnP8eyDW_G9WRMa88XAoCYQ,13495
261
262
  monai/networks/blocks/dynunet_block.py,sha256=kg8NNTL4nBqsy6gBcxmS5ZCPxlhWM_iB0ByyTQ4AfPs,11063
262
263
  monai/networks/blocks/encoder.py,sha256=NwH5VSQLwamJqrSbZSdQqMwZCk80CPbSpMGEE0r0Cwo,3669
263
264
  monai/networks/blocks/fcn.py,sha256=mnCMrxhUdj2yZ0DPIj0Xf9OKVdv-qhG1BpnAg5j7q6c,9024
@@ -322,6 +323,7 @@ monai/networks/nets/quicknat.py,sha256=ko1BO9l4i4BVYG5V4ohkwUEyoRrPPPzmqNqnFhLTZ
322
323
  monai/networks/nets/regressor.py,sha256=6Nz5yJuQDJJOr5R0rhot_mHu7_MDCA4ybV48wS1HS1M,6482
323
324
  monai/networks/nets/regunet.py,sha256=-A6ygR7lVyAflFyqWkVVOsY94uMXWol1f2xr_HmsU1c,18664
324
325
  monai/networks/nets/resnet.py,sha256=owsWu9lK26ijhRHDCLEBLf03t681TyehVCflcPqGIec,28179
326
+ monai/networks/nets/restormer.py,sha256=53VFC15JBZzhVmYEKY5F2ddXcL6NUQ0pDFtpvRVZ8hw,12897
325
327
  monai/networks/nets/segresnet.py,sha256=xNkSIvdk7kAyc3eVn-U_gGj8MoGVc5nklFKc_fkgOUs,13994
326
328
  monai/networks/nets/segresnet_ds.py,sha256=XFF7HKMt9Wbfc9fZSgfjVdfYQcP0d19ygp3VT7OHzJg,20644
327
329
  monai/networks/nets/senet.py,sha256=yLhP9gDPoa-h9UwJZJm5qxPdPvF9calY95lButXJESs,19308
@@ -401,14 +403,14 @@ monai/transforms/spatial/functional.py,sha256=IwS0witCqbGkyuxzu_R4Ztp90S0pg9hY1i
401
403
  monai/transforms/utility/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
402
404
  monai/transforms/utility/array.py,sha256=Du3QA6m0io7mR51gUgaMwHBFNStdFmRxhaYmBCVy7BY,81215
403
405
  monai/transforms/utility/dictionary.py,sha256=iOFdTSekvkAsBbbfHeffcRsOKRtNcnt3N1cVuUarZ1s,80549
404
- monai/utils/__init__.py,sha256=2_AIpb1wqGMkmgoZ3r43muFTEsnMTCkPu3LtckipYHg,3793
406
+ monai/utils/__init__.py,sha256=LAOUinb7tHhBy_hygIoxw9WVu7W0C-oJ1lnfPAx1GI0,3813
405
407
  monai/utils/component_store.py,sha256=Fe9jbHgwwBBAeJAw0nI02Ae13v17wlwF6N9uUue8tJg,4525
406
408
  monai/utils/decorators.py,sha256=qhhdmJMjMfZIUM6x_VGUGF7kaq2cBUAam8WymAU_mhw,3156
407
409
  monai/utils/deprecate_utils.py,sha256=gKeEV4MsI51qeQ5gci2me_C-0e-tDwa3VZzd3XPQqLk,14759
408
410
  monai/utils/dist.py,sha256=7brB42CvdS8Jvr8Y7hfqov1uk6NNnYea9dYfgMYy0BY,8578
409
- monai/utils/enums.py,sha256=jXtLaNDxG3BRBgLG2t13_S_G4iVWYHZO_GztykAtmXg,19594
411
+ monai/utils/enums.py,sha256=aupxnORUHqVPF2Ac5nxstsP5aIyewMoqgGb88D62yxg,19931
410
412
  monai/utils/jupyter_utils.py,sha256=BYtj80LWQAYg5RWPj5g4j2AMCzLECvAcnZdXns0Ruw8,15651
411
- monai/utils/misc.py,sha256=9-5zBIDSUYewzoQBkiBm0G_HR8hmwQCT-I15RYOQqEQ,31759
413
+ monai/utils/misc.py,sha256=M0oCfj55pZTrcYF0QgyS91JflqBwxSuNnOifl2HRSZk,31759
412
414
  monai/utils/module.py,sha256=R37PpCNCcHQvjjZFbNjNyzWb3FURaKLxQucjhzQk0eU,26087
413
415
  monai/utils/nvtx.py,sha256=i9JBxR1uhW1ZCgLPLlTx8b907QlXkFzJyTBLMlFjhtU,6876
414
416
  monai/utils/ordering.py,sha256=0nlA5b5QpVCHbtiCbTC-YsqjTmjm0bub0IeJhGFBOes,8270
@@ -423,7 +425,7 @@ monai/visualize/img2tensorboard.py,sha256=n4ztSa5BQAUxSTGvi2tp45v-F7-RNgSlbsdy-9
423
425
  monai/visualize/occlusion_sensitivity.py,sha256=OQHEJLyIhB8zWqQsfKaX-1kvCjWFVYtLfS4dFC0nKFI,18160
424
426
  monai/visualize/utils.py,sha256=B-MhTVs7sQbIqYS3yPnpBwPw2K82rE2PBtGIfpwZtWM,9894
425
427
  monai/visualize/visualizer.py,sha256=qckyaMZCbezYUwE20k5yc-Pb7UozVavMDbrmyQwfYHY,1377
426
- monai_weekly-1.5.dev2512.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
428
+ monai_weekly-1.5.dev2514.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
427
429
  tests/apps/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
428
430
  tests/apps/test_auto3dseg_bundlegen.py,sha256=FpTJo9Lfe8vdhGuWeZ9y1BQmqYwTt-s8mDVtoLGAz_I,5594
429
431
  tests/apps/test_check_hash.py,sha256=MuZslW2DDCxHKEo6-PiL7hnbxGuZRRYf6HOh3ZQv1qQ,1761
@@ -586,7 +588,6 @@ tests/integration/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0
586
588
  tests/integration/test_auto3dseg_ensemble.py,sha256=CIdgRSNX1VBfSmSvx_8HgsIfFsJpoLAC5sg8_HrRP_c,7570
587
589
  tests/integration/test_auto3dseg_hpo.py,sha256=J5Us-7iVE5UFbbsZdzNVcvPPQENTlPFarVsSt7KR3Yo,7286
588
590
  tests/integration/test_deepedit_interaction.py,sha256=tmryp1cP_QlI_tgguZybRZc7-FIEADTUGlUz3tmhEeY,4580
589
- tests/integration/test_downsample_block.py,sha256=qvqSeTwFQHwiJ0y8uwWE8U_9ffhltJ_4U5Zg5rBnQ6M,1794
590
591
  tests/integration/test_hovernet_nuclear_type_post_processingd.py,sha256=yTRmYdQBXEMMmXJjPDBPMxPSkLWj2U3bdRhaAfDXrpE,2661
591
592
  tests/integration/test_integration_autorunner.py,sha256=tDK1XkMZp4hehfuzMr2LQIgavP36L_vkFcOcI1Z68Lk,7571
592
593
  tests/integration/test_integration_bundle_run.py,sha256=uO87WnnG3EYnAxhudpfHy7fyxHNNzifFTw2rrMm_6XU,10734
@@ -656,7 +657,7 @@ tests/metrics/test_compute_f_beta.py,sha256=xbCipeICoAXWZLgDFeDAa1KjDQxDTMVArNbt
656
657
  tests/metrics/test_compute_fid_metric.py,sha256=B9OZECl3CT1JKzG-2C_YaPFjgfvlFoS9vI1j8vBzWZg,1328
657
658
  tests/metrics/test_compute_froc.py,sha256=IdP8JXI7SCXJQ4oxqpiGInIjNgApIbZ7JtdMTX1lc-U,4522
658
659
  tests/metrics/test_compute_generalized_dice.py,sha256=m5468hRvCYdfEF4B459e2LW3gDXH1PZSrBM3FOHHOsk,9614
659
- tests/metrics/test_compute_meandice.py,sha256=kC7JEqHUe54GrPxypoEjlmUZtxVZxjbhfRWEsZPP7CY,11381
660
+ tests/metrics/test_compute_meandice.py,sha256=gRqSuys0MDktMQkCR_McceY4EiyYOkxYCuZKjz3VaHs,11390
660
661
  tests/metrics/test_compute_meaniou.py,sha256=hphLbY6S-DA3CQiKOug-DblzqwPK0F7aF3Pujz6H0vk,8020
661
662
  tests/metrics/test_compute_mmd_metric.py,sha256=9rwvmZaj4wQKLY3xfuF85gFvZrnyWSXXDd6m7zy63sg,2025
662
663
  tests/metrics/test_compute_multiscalessim_metric.py,sha256=bLL6eNE_bhL4tL4EJO5XcaGurbE5utemc4b6PmJ766k,3080
@@ -683,12 +684,14 @@ tests/networks/test_save_state.py,sha256=OnUJEX6vqWoIAIEvVXHbAL4Yrv1GeY0YHw2Dpos
683
684
  tests/networks/test_to_onehot.py,sha256=QlT6RkkG7CJeh0gppSohl4kb0bmhISdx_19IybYES0Q,2224
684
685
  tests/networks/test_varnet.py,sha256=-9Ew5epHVvRLc34VCFwKNpsKKoAdudpBRlqDAShpIio,2800
685
686
  tests/networks/blocks/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
687
+ tests/networks/blocks/test_CABlock.py,sha256=PLKDIIn185l6xD_c9aYr-uuOjENtIRdlIDq3xmijb5c,5938
686
688
  tests/networks/blocks/test_adn.py,sha256=lTQWChKSNkQhIN3qZIIytOcox8GAoCFGpQUEVl7tqJ8,3191
687
689
  tests/networks/blocks/test_convolutions.py,sha256=6XS_DcpcA03BXttjKUU-LqfHgmLnU9RL60YzGn4mDfE,7139
688
690
  tests/networks/blocks/test_crf_cpu.py,sha256=UCw5vDdAgKA6_hpMtcHLBZpWBmdN245HelYMScZLp9E,18871
689
691
  tests/networks/blocks/test_crf_cuda.py,sha256=L4d8M4EZNRA0VcACBkWchZWfaHX8efDYygafZAL6UvQ,19450
690
692
  tests/networks/blocks/test_crossattention.py,sha256=mZ3q4nMhNN0Ntew5w8eO39qsucJmWNvqw6YYsshsZuI,7804
691
693
  tests/networks/blocks/test_denseblock.py,sha256=Nfq4zOwO7MF0OUYASf1Duorw0VZXy1v46NoQWDpWWFc,4373
694
+ tests/networks/blocks/test_downsample_block.py,sha256=Q6nCD7AzDc1YLQcNcQhgO7fQqpqidj5Qkn64rephAAY,7687
692
695
  tests/networks/blocks/test_dynunet_block.py,sha256=2juAJlzKjmPqm0WzfWUKvBPcWYZeFVoqu7-knD6maUU,4994
693
696
  tests/networks/blocks/test_fpn_block.py,sha256=R9w3pnnl25XlR_V_HZoZe2Kq8BuCK-fww6edTBLcNlk,3467
694
697
  tests/networks/blocks/test_localnet_block.py,sha256=PGeU2CesjHa9xs-Uj0G_aE_nfqKbzJVJHk5gQTZOPsg,4806
@@ -773,6 +776,7 @@ tests/networks/nets/test_network_consistency.py,sha256=OuEsjkCzQEIxQ9CNJxNXqI8Kr
773
776
  tests/networks/nets/test_patch_gan_dicriminator.py,sha256=5qhzL55pid_9ShuALPzvW21eZtdlpKupw8hdu1N4sVE,5266
774
777
  tests/networks/nets/test_quicknat.py,sha256=iuJRChBt6OoOvBGUe2bZ5wvcx0AfId4gZJ7K12SP7w8,2601
775
778
  tests/networks/nets/test_resnet.py,sha256=nIx9ZrHWN36iiGP9KffiEdJ5kLctySh5_zdAddl9gTc,10475
779
+ tests/networks/nets/test_restormer.py,sha256=iP9oW-D0rmHnitjOAzEz2nBATR1XTNeEAukktuQrG2s,5332
776
780
  tests/networks/nets/test_segresnet.py,sha256=1-TfOY4SSTAb3oyjFhOSoXLIO7TogksQQdvgKfXDg-o,4855
777
781
  tests/networks/nets/test_segresnet_ds.py,sha256=K9_tjcxknCAy3UiVjVAQ-1nCIuXoIJ8QEBbYaERtQ7g,6576
778
782
  tests/networks/nets/test_senet.py,sha256=V9HyDyYMR2r2F6FzZUl6INDipH5mk-IrExkkeZwHF4Q,6226
@@ -808,6 +812,7 @@ tests/networks/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZ
808
812
  tests/networks/utils/test_copy_model_state.py,sha256=SI0dlUkA0rdZVFvi5acr0nve002oIiftWIPWkqLQH2Q,6764
809
813
  tests/networks/utils/test_eval_mode.py,sha256=HQqgC4COr5fAsBo2Z-DCjnOfx6WLxXlPywlGMnQY7_0,1086
810
814
  tests/networks/utils/test_freeze_layers.py,sha256=y1DdBo7AMowTyzEurOs2EfK6InDZrfq1_h3z7rpSoag,2006
815
+ tests/networks/utils/test_pixelunshuffle.py,sha256=p_PO1TqAJf5fNBh36OJyfhIEKtfx1rBrMlrmWfzJp7g,1911
811
816
  tests/networks/utils/test_replace_module.py,sha256=_KBuO21-kvw7vckHh0KCiY8tPjMXSXpLv5kKH-ojeY4,4350
812
817
  tests/networks/utils/test_train_mode.py,sha256=GYjtzEV_ynDIRl2jVi3-A3ihS33BvubAIKaEoMKaVKI,1051
813
818
  tests/optimizers/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
@@ -1181,7 +1186,7 @@ tests/visualize/test_vis_gradcam.py,sha256=WpA-pvTB75eZs7JoIc5qyvOV9PwgkzWI8-Vow
1181
1186
  tests/visualize/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
1182
1187
  tests/visualize/utils/test_blend_images.py,sha256=RVs2p_8RWQDfhLHDNNtZaMig27v8o0km7XxNa-zWjKE,2274
1183
1188
  tests/visualize/utils/test_matshow3d.py,sha256=wXYj77L5Jvnp0f6DvL1rsi_-YlCxS0HJ9hiPmrbpuP8,5021
1184
- monai_weekly-1.5.dev2512.dist-info/METADATA,sha256=m8Cze_lpf0KT5GBA-1frWxwCOrRuIxDISVduzhJJLXM,12008
1185
- monai_weekly-1.5.dev2512.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
1186
- monai_weekly-1.5.dev2512.dist-info/top_level.txt,sha256=hn2Y6P9xBf2R8faMeVMHhPMvrdDKxMsIOwMDYI0yTjs,12
1187
- monai_weekly-1.5.dev2512.dist-info/RECORD,,
1189
+ monai_weekly-1.5.dev2514.dist-info/METADATA,sha256=D_HyUHJvaNgwcKd2KZ6XK0fpADz8jFTu5fLQl1p2Nwk,12008
1190
+ monai_weekly-1.5.dev2514.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
1191
+ monai_weekly-1.5.dev2514.dist-info/top_level.txt,sha256=hn2Y6P9xBf2R8faMeVMHhPMvrdDKxMsIOwMDYI0yTjs,12
1192
+ monai_weekly-1.5.dev2514.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (77.0.3)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -267,15 +267,15 @@ class TestComputeMeanDice(unittest.TestCase):
267
267
  @parameterized.expand([TEST_CASE_3])
268
268
  def test_helper(self, input_data, _unused):
269
269
  vals = {"y_pred": dict(input_data).pop("y_pred"), "y": dict(input_data).pop("y")}
270
- result = DiceHelper(sigmoid=True)(**vals)
270
+ result = DiceHelper(threshold=True)(**vals)
271
271
  np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4)
272
272
  np.testing.assert_allclose(sorted(result[1].cpu().numpy()), [0.0, 1.0, 2.0], atol=1e-4)
273
- result = DiceHelper(softmax=True, get_not_nans=False)(**vals)
273
+ result = DiceHelper(apply_argmax=True, get_not_nans=False)(**vals)
274
274
  np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0], atol=1e-4)
275
275
 
276
276
  num_classes = vals["y_pred"].shape[1]
277
277
  vals["y_pred"] = torch.argmax(vals["y_pred"], dim=1, keepdim=True)
278
- result = DiceHelper(sigmoid=True, num_classes=num_classes)(**vals)
278
+ result = DiceHelper(threshold=True, num_classes=num_classes)(**vals)
279
279
  np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4)
280
280
 
281
281
  # DiceMetric class tests
@@ -0,0 +1,150 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+ from unittest import skipUnless
16
+
17
+ import torch
18
+ from parameterized import parameterized
19
+
20
+ from monai.networks import eval_mode
21
+ from monai.networks.blocks.cablock import CABlock, FeedForward
22
+ from monai.utils import optional_import
23
+ from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose
24
+
25
+ einops, has_einops = optional_import("einops")
26
+
27
+
28
# CABlock cases: [constructor kwargs, input shape, expected output shape] for every
# combination of spatial dims, channel count, head count and bias flag. The block is
# expected to preserve the input shape (batch 2, 16 voxels per spatial axis).
# Built with a comprehension so the loop variables do not leak into the module namespace.
TEST_CASES_CAB = [
    [
        {
            "spatial_dims": spatial_dims,
            "dim": dim,
            "num_heads": num_heads,
            "bias": bias,
            "flash_attention": False,
        },
        (2, dim, *([16] * spatial_dims)),
        (2, dim, *([16] * spatial_dims)),
    ]
    for spatial_dims in [2, 3]
    for dim in [32, 64, 128]
    for num_heads in [2, 4, 8]
    for bias in [True, False]
]

# FeedForward cases: [constructor kwargs, input shape]; the test asserts that the
# output shape equals the input shape.
TEST_CASES_FEEDFORWARD = [
    # Test different spatial dims, dimensions and expansion factors
    [{"spatial_dims": 2, "dim": 64, "ffn_expansion_factor": 2.0, "bias": True}, (2, 64, 32, 32)],
    [{"spatial_dims": 3, "dim": 128, "ffn_expansion_factor": 1.5, "bias": False}, (2, 128, 16, 16, 16)],
    [{"spatial_dims": 2, "dim": 256, "ffn_expansion_factor": 1.0, "bias": True}, (1, 256, 64, 64)],
]
53
+
54
+
55
class TestFeedForward(unittest.TestCase):
    """Checks for ``FeedForward`` from ``monai.networks.blocks.cablock``."""

    @parameterized.expand(TEST_CASES_FEEDFORWARD)
    def test_shape(self, input_param, input_shape):
        """The output must have exactly the input's shape for every configuration."""
        ffn = FeedForward(**input_param)
        sample = torch.randn(input_shape)
        with eval_mode(ffn):
            output = ffn(sample)
            self.assertEqual(output.shape, input_shape)

    def test_gating_mechanism(self):
        """On an all-ones input, the output sum must differ from the input sum (block is not an identity)."""
        ffn = FeedForward(spatial_dims=2, dim=32, ffn_expansion_factor=2.0, bias=True)
        ones = torch.ones(1, 32, 16, 16)
        output = ffn(ones)
        self.assertNotEqual(torch.sum(output), torch.sum(ones))
69
+
70
+
71
class TestCABlock(unittest.TestCase):
    """Tests for ``CABlock``, the channel-attention block in ``monai.networks.blocks.cablock``."""

    @parameterized.expand(TEST_CASES_CAB)
    @skipUnless(has_einops, "Requires einops")
    def test_shape(self, input_param, input_shape, expected_shape):
        net = CABlock(**input_param)
        with eval_mode(net):
            result = net(torch.randn(input_shape))
            self.assertEqual(result.shape, expected_shape)

    @skipUnless(has_einops, "Requires einops")
    def test_invalid_spatial_dims(self):
        # Unsupported dimensionality must be rejected at construction time.
        with self.assertRaises(ValueError):
            CABlock(spatial_dims=4, dim=64, num_heads=4, bias=True)

    @SkipIfBeforePyTorchVersion((2, 0))
    @skipUnless(has_einops, "Requires einops")
    def test_flash_attention(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device)
        x = torch.randn(2, 64, 32, 32).to(device)
        output = block(x)
        self.assertEqual(output.shape, x.shape)

    @skipUnless(has_einops, "Requires einops")
    def test_temperature_parameter(self):
        # The temperature is a learnable parameter of shape (num_heads, 1, 1).
        block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True)
        self.assertIsInstance(block.temperature, torch.nn.Parameter)
        self.assertEqual(block.temperature.shape, (4, 1, 1))

    @skipUnless(has_einops, "Requires einops")
    def test_qkv_transformation_2d(self):
        # The qkv projection maps dim -> 3 * dim channels (64 -> 192) and keeps the spatial size.
        block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True)
        x = torch.randn(2, 64, 32, 32)
        qkv = block.qkv(x)
        self.assertEqual(qkv.shape, (2, 192, 32, 32))

    @skipUnless(has_einops, "Requires einops")
    def test_qkv_transformation_3d(self):
        # Same channel expansion (64 -> 192) for volumetric input.
        block = CABlock(spatial_dims=3, dim=64, num_heads=4, bias=True)
        x = torch.randn(2, 64, 16, 16, 16)
        qkv = block.qkv(x)
        self.assertEqual(qkv.shape, (2, 192, 16, 16, 16))

    @SkipIfBeforePyTorchVersion((2, 0))
    @skipUnless(has_einops, "Requires einops")
    def test_flash_vs_normal_attention(self):
        # With identical weights, the flash and explicit attention paths must agree within 1e-4.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        block_flash = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device)
        block_normal = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=False).to(device)

        block_normal.load_state_dict(block_flash.state_dict())

        x = torch.randn(2, 64, 32, 32).to(device)
        with torch.no_grad():
            out_flash = block_flash(x)
            out_normal = block_normal(x)

        assert_allclose(out_flash, out_normal, atol=1e-4)

    @skipUnless(has_einops, "Requires einops")
    def test_deterministic_small_input(self):
        block = CABlock(spatial_dims=2, dim=2, num_heads=1, bias=False)
        with torch.no_grad():
            block.qkv.conv.weight.data.fill_(1.0)
            block.qkv_dwconv.conv.weight.data.fill_(1.0)
            block.temperature.data.fill_(1.0)
            block.project_out.conv.weight.data.fill_(1.0)

        # 2 channels on a 2x2 grid holding the values 1..8 (total 36).
        x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]], dtype=torch.float32)

        output = block(x)
        # Derivation of the expected value. It assumes a 1x1 qkv conv, a padded 3x3 depthwise conv
        # and softmax attention (TODO confirm against cablock.py):
        # - the all-ones qkv conv sums the 2 input channels per pixel;
        # - the all-ones depthwise conv then sums over the whole 2x2 map,
        #   so every q/k/v entry equals 1 + 2 + ... + 8 = 36;
        # - v is constant, so attention rows summing to 1 return 36 per channel;
        # - the all-ones projection sums the 2 channels: 36 * 2 = 72 at every position.
        expected = torch.full_like(x, 72.0)

        assert_allclose(output, expected, atol=1e-6)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,184 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+
16
+ import torch
17
+ from parameterized import parameterized
18
+
19
+ from monai.networks import eval_mode
20
+ from monai.networks.blocks import DownSample, MaxAvgPool, SubpixelDownsample, SubpixelUpsample
21
+ from monai.utils import optional_import
22
+
23
+ einops, has_einops = optional_import("einops")
24
+
25
# MaxAvgPool cases: [constructor kwargs, input shape, expected output shape].
# In every case the expected channel count is twice the input channel count
# (presumably max- and average-pooled maps are concatenated — confirm in downsample.py).
TEST_CASES = [
    # 4-channel 2D, batch 7
    [
        {"spatial_dims": 2, "kernel_size": 2},
        (7, 4, 64, 48),
        (7, 8, 32, 24),
    ],
    # 4-channel 1D, batch 16
    [
        {"spatial_dims": 1, "kernel_size": 4},
        (16, 4, 63),
        (16, 8, 15),
    ],
    # 4-channel 1D, batch 16, padded
    [
        {"spatial_dims": 1, "kernel_size": 4, "padding": 1},
        (16, 4, 63),
        (16, 8, 16),
    ],
    # 4-channel 3D, batch 16, ceil_mode on
    [
        {"spatial_dims": 3, "kernel_size": 3, "ceil_mode": True},
        (16, 4, 32, 24, 48),
        (16, 8, 11, 8, 16),
    ],
    # 1-channel 3D, batch 16, ceil_mode off
    [
        {"spatial_dims": 3, "kernel_size": 3, "ceil_mode": False},
        (16, 1, 32, 24, 48),
        (16, 2, 10, 8, 16),
    ],
]

# SubpixelDownsample cases: spatial size shrinks by scale_factor per axis while the channel
# count grows by scale_factor ** spatial_dims (1 * 4 = 4, 2 * 8 = 16, 3 * 2 = 6).
TEST_CASES_SUBPIXEL = [
    [
        {"spatial_dims": 2, "in_channels": 1, "scale_factor": 2},
        (1, 1, 8, 8),
        (1, 4, 4, 4),
    ],
    [
        {"spatial_dims": 3, "in_channels": 2, "scale_factor": 2},
        (1, 2, 8, 8, 8),
        (1, 16, 4, 4, 4),
    ],
    [
        {"spatial_dims": 1, "in_channels": 3, "scale_factor": 2},
        (1, 3, 8),
        (1, 6, 4),
    ],
]

# DownSample cases, one per mode; every case halves each spatial axis.
TEST_CASES_DOWNSAMPLE = [
    # conv: 4 -> 4 channels
    [
        {"spatial_dims": 2, "in_channels": 4, "mode": "conv"},
        (1, 4, 16, 16),
        (1, 4, 8, 8),
    ],
    # convgroup: 4 -> 8 channels via explicit out_channels
    [
        {"spatial_dims": 2, "in_channels": 4, "out_channels": 8, "mode": "convgroup"},
        (1, 4, 16, 16),
        (1, 8, 8, 8),
    ],
    # maxpool, 3D: channels unchanged
    [
        {"spatial_dims": 3, "in_channels": 2, "mode": "maxpool"},
        (1, 2, 16, 16, 16),
        (1, 2, 8, 8, 8),
    ],
    # avgpool: channels unchanged
    [
        {"spatial_dims": 2, "in_channels": 4, "mode": "avgpool"},
        (1, 4, 16, 16),
        (1, 4, 8, 8),
    ],
    # pixelunshuffle: 1 -> 4 channels (factor 2 in 2D)
    [
        {"spatial_dims": 2, "in_channels": 1, "mode": "pixelunshuffle"},
        (1, 1, 16, 16),
        (1, 4, 8, 8),
    ],
]
54
+
55
+
56
class TestMaxAvgPool(unittest.TestCase):
    """Output-shape checks for ``MaxAvgPool`` on 1D, 2D and 3D inputs."""

    @parameterized.expand(TEST_CASES)
    def test_shape(self, input_param, input_shape, expected_shape):
        pool = MaxAvgPool(**input_param)
        features = torch.randn(input_shape)
        with eval_mode(pool):
            pooled = pool(features)
            self.assertEqual(pooled.shape, expected_shape)
64
+
65
+
66
class TestSubpixelDownsample(unittest.TestCase):
    """Tests for ``SubpixelDownsample`` (space-to-depth), with and without a ``conv_block``."""

    @parameterized.expand(TEST_CASES_SUBPIXEL)
    def test_shape(self, input_param, input_shape, expected_shape):
        downsampler = SubpixelDownsample(**input_param)
        with eval_mode(downsampler):
            result = downsampler(torch.randn(input_shape))
            self.assertEqual(result.shape, expected_shape)

    def test_predefined_tensor(self):
        # 4-channel 4x4 input in which every element of channel c holds the value c.
        test_tensor = torch.arange(4).view(4, 1, 1).repeat(1, 4, 4)
        test_tensor = test_tensor.unsqueeze(0)  # (1, 4, 4, 4)

        # NOTE(review): in_channels=1 does not match the 4-channel input; with conv_block=None it is
        # presumably unused — confirm in downsample.py.
        downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None)
        with eval_mode(downsampler):
            result = downsampler(test_tensor)
            self.assertEqual(result.shape, (1, 16, 2, 2))
            # Each input channel c becomes scale_factor**2 = 4 consecutive output channels [4c, 4c + 4),
            # all holding the value c. The whole half-open range is checked so the last channel of
            # each group (3, 7, 11, 15) is verified too.
            group_size = 2**2
            for value in range(4):
                group = result[0, value * group_size : (value + 1) * group_size]
                self.assertTrue(torch.all(group == value), f"output channels of input channel {value} differ")

    def test_reconstruction_2d(self):
        # Downsampling followed by the matching upsampling must be lossless.
        input_tensor = torch.randn(1, 1, 4, 4)
        down = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None)
        up = SubpixelUpsample(spatial_dims=2, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False)
        with eval_mode(down), eval_mode(up):
            downsampled = down(input_tensor)
            reconstructed = up(downsampled)
            self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5))

    def test_reconstruction_3d(self):
        input_tensor = torch.randn(1, 1, 4, 4, 4)
        down = SubpixelDownsample(spatial_dims=3, in_channels=1, scale_factor=2, conv_block=None)
        up = SubpixelUpsample(spatial_dims=3, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False)
        with eval_mode(down), eval_mode(up):
            downsampled = down(input_tensor)
            reconstructed = up(downsampled)
            self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5))

    def test_invalid_spatial_size(self):
        # A spatial size not divisible by scale_factor (3) must be rejected.
        downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2)
        with self.assertRaises(ValueError):
            downsampler(torch.randn(1, 1, 3, 4))

    def test_custom_conv_block(self):
        # The custom conv maps 1 -> 2 channels; space-to-depth then multiplies by 4 -> 8 channels.
        custom_conv = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1)
        downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=custom_conv)
        with eval_mode(downsampler):
            result = downsampler(torch.randn(1, 1, 4, 4))
            self.assertEqual(result.shape, (1, 8, 2, 2))
117
+
118
+
119
class TestDownSample(unittest.TestCase):
    """Tests for the mode-selectable ``DownSample`` block."""

    @parameterized.expand(TEST_CASES_DOWNSAMPLE)
    def test_shape(self, input_param, input_shape, expected_shape):
        module = DownSample(**input_param)
        with eval_mode(module):
            out = module(torch.randn(input_shape))
            self.assertEqual(out.shape, expected_shape)

    def test_pre_post_conv(self):
        # The user-supplied post_conv expects 8 channels and produces 16; maxpool halves the grid.
        extra_conv = torch.nn.Conv2d(8, 16, 1)
        module = DownSample(
            spatial_dims=2,
            in_channels=4,
            out_channels=8,
            mode="maxpool",
            pre_conv="default",
            post_conv=extra_conv,
        )
        with eval_mode(module):
            out = module(torch.randn(1, 4, 16, 16))
            self.assertEqual(out.shape, (1, 16, 8, 8))

    def test_pixelunshuffle_equivalence(self):
        """A given pre_conv followed by MONAI pixel-unshuffle must match conv + ``torch.nn.PixelUnshuffle``."""

        class _ReferenceDownsample(torch.nn.Module):
            # Hand-written reference: 3x3 conv halving the channels, then torch's PixelUnshuffle(2).
            def __init__(self, channels: int):
                super().__init__()
                self.conv = torch.nn.Conv2d(channels, channels // 2, kernel_size=3, stride=1, padding=1, bias=False)
                self.pixelunshuffle = torch.nn.PixelUnshuffle(2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.pixelunshuffle(self.conv(x))

        channels = 2
        sample = torch.randn(1, channels, 64, 64)

        shared_conv = torch.nn.Conv2d(channels, channels // 2, kernel_size=3, stride=1, padding=1, bias=False)

        monai_down = DownSample(
            spatial_dims=2,
            in_channels=channels,
            out_channels=channels // 2,
            mode="pixelunshuffle",
            pre_conv=shared_conv,
        )

        reference = _ReferenceDownsample(channels)
        # Give the reference exactly the same conv weights so the outputs are comparable.
        reference.conv.weight.data = shared_conv.weight.data.clone()

        with eval_mode(monai_down), eval_mode(reference):
            actual = monai_down(sample)
            expected = reference(sample)

            self.assertTrue(torch.allclose(actual, expected, rtol=1e-5))

    def test_invalid_mode(self):
        # Unknown mode strings must be rejected.
        with self.assertRaises(ValueError):
            DownSample(spatial_dims=2, in_channels=4, mode="invalid")

    def test_missing_channels(self):
        # The "conv" mode cannot be built without in_channels.
        with self.assertRaises(ValueError):
            DownSample(spatial_dims=2, mode="conv")


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,147 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import unittest
15
+ from unittest import skipUnless
16
+
17
+ import torch
18
+ from parameterized import parameterized
19
+
20
+ from monai.networks import eval_mode
21
+ from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer
22
+ from monai.utils import optional_import
23
+
24
+ einops, has_einops = optional_import("einops")
25
+
26
TEST_CASES_TRANSFORMER = [
    # [spatial_dims, dim, num_heads, ffn_factor, bias, layer_norm_use_bias, flash_attn, input_shape]
    [2, 48, 8, 2.66, True, True, False, (2, 48, 64, 64)],
    [2, 96, 8, 2.66, False, False, False, (2, 96, 32, 32)],
    [3, 48, 4, 2.66, True, True, False, (2, 48, 32, 32, 32)],
    [3, 96, 8, 2.66, False, False, True, (2, 96, 16, 16, 16)],
]

TEST_CASES_PATCHEMBED = [
    # spatial_dims, in_channels, embed_dim, input_shape, expected_shape
    [2, 1, 48, (2, 1, 64, 64), (2, 48, 64, 64)],
    [2, 3, 96, (2, 3, 32, 32), (2, 96, 32, 32)],
    [3, 1, 48, (2, 1, 32, 32, 32), (2, 48, 32, 32, 32)],
    [3, 4, 64, (2, 4, 16, 16, 16), (2, 64, 16, 16, 16)],
]

RESTORMER_CONFIGS = [
    # 2-level architecture
    {"num_blocks": [1, 1], "heads": [1, 1]},
    {"num_blocks": [2, 1], "heads": [2, 1]},
    # 3-level architecture
    {"num_blocks": [1, 1, 1], "heads": [1, 1, 1]},
    {"num_blocks": [2, 1, 1], "heads": [2, 1, 1]},
]

# For every config, one 2D case (dim 48, 64x64) followed by one 3D case (dim 16, 32^3).
# Each case is [constructor kwargs, input shape, expected output shape]; with
# in_channels == out_channels == 1 the output shape equals the input shape.
# Built with a comprehension so the loop variable does not leak into the module namespace.
TEST_CASES_RESTORMER = [
    [
        {
            "spatial_dims": spatial_dims,
            "in_channels": 1,
            "out_channels": 1,
            "dim": dim,
            "num_blocks": config["num_blocks"],
            "heads": config["heads"],
            "num_refinement_blocks": 2,
            "ffn_expansion_factor": 1.5,
        },
        shape,
        shape,
    ]
    for config in RESTORMER_CONFIGS
    for spatial_dims, dim, shape in ((2, 48, (2, 1, 64, 64)), (3, 16, (2, 1, 32, 32, 32)))
]
87
+
88
+
89
class TestMDTATransformerBlock(unittest.TestCase):
    """Shape-preservation tests for ``MDTATransformerBlock``."""

    @parameterized.expand(TEST_CASES_TRANSFORMER)
    @skipUnless(has_einops, "Requires einops")
    def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape):
        """The block must return a tensor with the same shape as its input.

        Flash-attention cases are skipped without CUDA. When CUDA is available, the block and input
        are placed on the GPU, so the CUDA path that the skip guards is actually exercised.
        """
        if flash and not torch.cuda.is_available():
            self.skipTest("Flash attention requires CUDA")
        device = torch.device("cuda" if flash else "cpu")
        block = MDTATransformerBlock(
            spatial_dims=spatial_dims,
            dim=dim,
            num_heads=heads,
            ffn_expansion_factor=ffn_factor,
            bias=bias,
            layer_norm_use_bias=layer_norm_use_bias,
            flash_attention=flash,
        ).to(device)
        with eval_mode(block):
            x = torch.randn(shape, device=device)
            output = block(x)
            self.assertEqual(output.shape, x.shape)
109
+
110
+
111
class TestOverlapPatchEmbed(unittest.TestCase):
    """``OverlapPatchEmbed`` must change only the channel count (in_channels -> embed_dim)."""

    @parameterized.expand(TEST_CASES_PATCHEMBED)
    def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape):
        embed = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim)
        inputs = torch.randn(input_shape)
        with eval_mode(embed):
            embedded = embed(inputs)
            self.assertEqual(embedded.shape, expected_shape)
119
+
120
+
121
class TestRestormer(unittest.TestCase):
    """End-to-end shape and input-size validation tests for ``Restormer``."""

    @parameterized.expand(TEST_CASES_RESTORMER)
    @skipUnless(has_einops, "Requires einops")
    def test_shape(self, input_param, input_shape, expected_shape):
        # Skip configurations that request flash attention on a machine without CUDA.
        wants_flash = input_param.get("flash_attention", False)
        if wants_flash and not torch.cuda.is_available():
            self.skipTest("Flash attention requires CUDA")
        model = Restormer(**input_param)
        with eval_mode(model):
            prediction = model(torch.randn(input_shape))
            self.assertEqual(prediction.shape, expected_shape)

    @skipUnless(has_einops, "Requires einops")
    def test_small_input_error_2d(self):
        # An 8x8 image is expected to be rejected by the default Restormer configuration.
        # NOTE(review): AssertionError suggests an `assert`-based check, which `python -O` strips — confirm.
        model = Restormer(spatial_dims=2, in_channels=1, out_channels=1)
        with self.assertRaises(AssertionError):
            model(torch.randn(1, 1, 8, 8))

    @skipUnless(has_einops, "Requires einops")
    def test_small_input_error_3d(self):
        # Same check for an 8x8x8 volume.
        model = Restormer(spatial_dims=3, in_channels=1, out_channels=1)
        with self.assertRaises(AssertionError):
            model(torch.randn(1, 1, 8, 8, 8))


if __name__ == "__main__":
    unittest.main()