monai-weekly 1.5.dev2510__py3-none-any.whl → 1.5.dev2512__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,7 +19,7 @@ from parameterized import parameterized
 
 from monai.inferers import LatentDiffusionInferer
 from monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet
-from monai.networks.schedulers import DDPMScheduler
+from monai.networks.schedulers import DDPMScheduler, RFlowScheduler
 from monai.utils import optional_import
 
 _, has_einops = optional_import("einops")
@@ -339,31 +339,32 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
 
         input = torch.randn(input_shape).to(device)
         noise = torch.randn(latent_shape).to(device)
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
-        scheduler.set_timesteps(num_inference_steps=10)
-        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
 
-        if dm_model_type == "SPADEDiffusionModelUNet":
-            input_shape_seg = list(input_shape)
-            if "label_nc" in stage_2_params.keys():
-                input_shape_seg[1] = stage_2_params["label_nc"]
+        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
+            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+            scheduler.set_timesteps(num_inference_steps=10)
+            timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+
+            if dm_model_type == "SPADEDiffusionModelUNet":
+                input_shape_seg = list(input_shape)
+                if "label_nc" in stage_2_params.keys():
+                    input_shape_seg[1] = stage_2_params["label_nc"]
+                else:
+                    input_shape_seg[1] = autoencoder_params["label_nc"]
+                input_seg = torch.randn(input_shape_seg).to(device)
+                prediction = inferer(
+                    inputs=input,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    seg=input_seg,
+                    noise=noise,
+                    timesteps=timesteps,
+                )
             else:
-                input_shape_seg[1] = autoencoder_params["label_nc"]
-            input_seg = torch.randn(input_shape_seg).to(device)
-            prediction = inferer(
-                inputs=input,
-                autoencoder_model=stage_1,
-                diffusion_model=stage_2,
-                seg=input_seg,
-                noise=noise,
-                timesteps=timesteps,
-            )
-        else:
-            prediction = inferer(
-                inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps
-            )
-        self.assertEqual(prediction.shape, latent_shape)
+                prediction = inferer(
+                    inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps
+                )
+            self.assertEqual(prediction.shape, latent_shape)
 
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
@@ -388,29 +389,30 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
         stage_2.eval()
 
         noise = torch.randn(latent_shape).to(device)
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
-        scheduler.set_timesteps(num_inference_steps=10)
 
-        if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
-            input_shape_seg = list(input_shape)
-            if "label_nc" in stage_2_params.keys():
-                input_shape_seg[1] = stage_2_params["label_nc"]
+        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
+            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+            scheduler.set_timesteps(num_inference_steps=10)
+
+            if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
+                input_shape_seg = list(input_shape)
+                if "label_nc" in stage_2_params.keys():
+                    input_shape_seg[1] = stage_2_params["label_nc"]
+                else:
+                    input_shape_seg[1] = autoencoder_params["label_nc"]
+                input_seg = torch.randn(input_shape_seg).to(device)
+                sample = inferer.sample(
+                    input_noise=noise,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    scheduler=scheduler,
+                    seg=input_seg,
+                )
             else:
-                input_shape_seg[1] = autoencoder_params["label_nc"]
-            input_seg = torch.randn(input_shape_seg).to(device)
-            sample = inferer.sample(
-                input_noise=noise,
-                autoencoder_model=stage_1,
-                diffusion_model=stage_2,
-                scheduler=scheduler,
-                seg=input_seg,
-            )
-        else:
-            sample = inferer.sample(
-                input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler
-            )
-        self.assertEqual(sample.shape, input_shape)
+                sample = inferer.sample(
+                    input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler
+                )
+            self.assertEqual(sample.shape, input_shape)
 
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
@@ -437,37 +439,38 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
         stage_2.eval()
 
         noise = torch.randn(latent_shape).to(device)
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
-        scheduler.set_timesteps(num_inference_steps=10)
 
-        if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
-            input_shape_seg = list(input_shape)
-            if "label_nc" in stage_2_params.keys():
-                input_shape_seg[1] = stage_2_params["label_nc"]
+        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
+            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+            scheduler.set_timesteps(num_inference_steps=10)
+
+            if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
+                input_shape_seg = list(input_shape)
+                if "label_nc" in stage_2_params.keys():
+                    input_shape_seg[1] = stage_2_params["label_nc"]
+                else:
+                    input_shape_seg[1] = autoencoder_params["label_nc"]
+                input_seg = torch.randn(input_shape_seg).to(device)
+                sample, intermediates = inferer.sample(
+                    input_noise=noise,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    scheduler=scheduler,
+                    seg=input_seg,
+                    save_intermediates=True,
+                    intermediate_steps=1,
+                )
             else:
-                input_shape_seg[1] = autoencoder_params["label_nc"]
-            input_seg = torch.randn(input_shape_seg).to(device)
-            sample, intermediates = inferer.sample(
-                input_noise=noise,
-                autoencoder_model=stage_1,
-                diffusion_model=stage_2,
-                scheduler=scheduler,
-                seg=input_seg,
-                save_intermediates=True,
-                intermediate_steps=1,
-            )
-        else:
-            sample, intermediates = inferer.sample(
-                input_noise=noise,
-                autoencoder_model=stage_1,
-                diffusion_model=stage_2,
-                scheduler=scheduler,
-                save_intermediates=True,
-                intermediate_steps=1,
-            )
-        self.assertEqual(len(intermediates), 10)
-        self.assertEqual(intermediates[0].shape, input_shape)
+                sample, intermediates = inferer.sample(
+                    input_noise=noise,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    scheduler=scheduler,
+                    save_intermediates=True,
+                    intermediate_steps=1,
+                )
+            self.assertEqual(len(intermediates), 10)
+            self.assertEqual(intermediates[0].shape, input_shape)
 
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
@@ -614,40 +617,40 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
         conditioning_shape[1] = n_concat_channel
         conditioning = torch.randn(conditioning_shape).to(device)
 
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
-        scheduler.set_timesteps(num_inference_steps=10)
-
-        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
-
-        if dm_model_type == "SPADEDiffusionModelUNet":
-            input_shape_seg = list(input_shape)
-            if "label_nc" in stage_2_params.keys():
-                input_shape_seg[1] = stage_2_params["label_nc"]
+        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
+            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+            scheduler.set_timesteps(num_inference_steps=10)
+
+            timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+
+            if dm_model_type == "SPADEDiffusionModelUNet":
+                input_shape_seg = list(input_shape)
+                if "label_nc" in stage_2_params.keys():
+                    input_shape_seg[1] = stage_2_params["label_nc"]
+                else:
+                    input_shape_seg[1] = autoencoder_params["label_nc"]
+                input_seg = torch.randn(input_shape_seg).to(device)
+                prediction = inferer(
+                    inputs=input,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    noise=noise,
+                    timesteps=timesteps,
+                    condition=conditioning,
+                    mode="concat",
+                    seg=input_seg,
+                )
             else:
-                input_shape_seg[1] = autoencoder_params["label_nc"]
-            input_seg = torch.randn(input_shape_seg).to(device)
-            prediction = inferer(
-                inputs=input,
-                autoencoder_model=stage_1,
-                diffusion_model=stage_2,
-                noise=noise,
-                timesteps=timesteps,
-                condition=conditioning,
-                mode="concat",
-                seg=input_seg,
-            )
-        else:
-            prediction = inferer(
-                inputs=input,
-                autoencoder_model=stage_1,
-                diffusion_model=stage_2,
-                noise=noise,
-                timesteps=timesteps,
-                condition=conditioning,
-                mode="concat",
-            )
-        self.assertEqual(prediction.shape, latent_shape)
+                prediction = inferer(
+                    inputs=input,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    noise=noise,
+                    timesteps=timesteps,
+                    condition=conditioning,
+                    mode="concat",
+                )
+            self.assertEqual(prediction.shape, latent_shape)
 
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
@@ -681,36 +684,36 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
         conditioning_shape[1] = n_concat_channel
         conditioning = torch.randn(conditioning_shape).to(device)
 
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
-        scheduler.set_timesteps(num_inference_steps=10)
-
-        if dm_model_type == "SPADEDiffusionModelUNet":
-            input_shape_seg = list(input_shape)
-            if "label_nc" in stage_2_params.keys():
-                input_shape_seg[1] = stage_2_params["label_nc"]
+        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
+            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+            scheduler.set_timesteps(num_inference_steps=10)
+
+            if dm_model_type == "SPADEDiffusionModelUNet":
+                input_shape_seg = list(input_shape)
+                if "label_nc" in stage_2_params.keys():
+                    input_shape_seg[1] = stage_2_params["label_nc"]
+                else:
+                    input_shape_seg[1] = autoencoder_params["label_nc"]
+                input_seg = torch.randn(input_shape_seg).to(device)
+                sample = inferer.sample(
+                    input_noise=noise,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    scheduler=scheduler,
+                    conditioning=conditioning,
+                    mode="concat",
+                    seg=input_seg,
+                )
             else:
-                input_shape_seg[1] = autoencoder_params["label_nc"]
-            input_seg = torch.randn(input_shape_seg).to(device)
-            sample = inferer.sample(
-                input_noise=noise,
-                autoencoder_model=stage_1,
-                diffusion_model=stage_2,
-                scheduler=scheduler,
-                conditioning=conditioning,
-                mode="concat",
-                seg=input_seg,
-            )
-        else:
-            sample = inferer.sample(
-                input_noise=noise,
-                autoencoder_model=stage_1,
-                diffusion_model=stage_2,
-                scheduler=scheduler,
-                conditioning=conditioning,
-                mode="concat",
-            )
-        self.assertEqual(sample.shape, input_shape)
+                sample = inferer.sample(
+                    input_noise=noise,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    scheduler=scheduler,
+                    conditioning=conditioning,
+                    mode="concat",
+                )
+            self.assertEqual(sample.shape, input_shape)
 
     @parameterized.expand(TEST_CASES_DIFF_SHAPES)
     @skipUnless(has_einops, "Requires einops")
@@ -738,39 +741,39 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
 
         input = torch.randn(input_shape).to(device)
         noise = torch.randn(latent_shape).to(device)
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        # We infer the VAE shape
-        autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
-        inferer = LatentDiffusionInferer(
-            scheduler=scheduler,
-            scale_factor=1.0,
-            ldm_latent_shape=list(latent_shape[2:]),
-            autoencoder_latent_shape=autoencoder_latent_shape,
-        )
-        scheduler.set_timesteps(num_inference_steps=10)
-
-        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
-
-        if dm_model_type == "SPADEDiffusionModelUNet":
-            input_shape_seg = list(input_shape)
-            if "label_nc" in stage_2_params.keys():
-                input_shape_seg[1] = stage_2_params["label_nc"]
-            else:
-                input_shape_seg[1] = autoencoder_params["label_nc"]
-            input_seg = torch.randn(input_shape_seg).to(device)
-            prediction = inferer(
-                inputs=input,
-                autoencoder_model=stage_1,
-                diffusion_model=stage_2,
-                noise=noise,
-                timesteps=timesteps,
-                seg=input_seg,
-            )
-        else:
-            prediction = inferer(
-                inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps
+        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
+            # We infer the VAE shape
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
+            inferer = LatentDiffusionInferer(
+                scheduler=scheduler,
+                scale_factor=1.0,
+                ldm_latent_shape=list(latent_shape[2:]),
+                autoencoder_latent_shape=autoencoder_latent_shape,
             )
-        self.assertEqual(prediction.shape, latent_shape)
+            scheduler.set_timesteps(num_inference_steps=10)
+
+            timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+
+            if dm_model_type == "SPADEDiffusionModelUNet":
+                input_shape_seg = list(input_shape)
+                if "label_nc" in stage_2_params.keys():
+                    input_shape_seg[1] = stage_2_params["label_nc"]
+                else:
+                    input_shape_seg[1] = autoencoder_params["label_nc"]
+                input_seg = torch.randn(input_shape_seg).to(device)
+                prediction = inferer(
+                    inputs=input,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    noise=noise,
+                    timesteps=timesteps,
+                    seg=input_seg,
+                )
+            else:
+                prediction = inferer(
+                    inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps
+                )
+            self.assertEqual(prediction.shape, latent_shape)
 
     @parameterized.expand(TEST_CASES_DIFF_SHAPES)
     @skipUnless(has_einops, "Requires einops")
@@ -797,40 +800,42 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
         stage_2.eval()
 
         noise = torch.randn(latent_shape).to(device)
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        # We infer the VAE shape
-        if ae_model_type == "VQVAE":
-            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
-        else:
-            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
-
-        inferer = LatentDiffusionInferer(
-            scheduler=scheduler,
-            scale_factor=1.0,
-            ldm_latent_shape=list(latent_shape[2:]),
-            autoencoder_latent_shape=autoencoder_latent_shape,
-        )
-        scheduler.set_timesteps(num_inference_steps=10)
-
-        if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
-            input_shape_seg = list(input_shape)
-            if "label_nc" in stage_2_params.keys():
-                input_shape_seg[1] = stage_2_params["label_nc"]
+        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
+            # We infer the VAE shape
+            if ae_model_type == "VQVAE":
+                autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
             else:
-                input_shape_seg[1] = autoencoder_params["label_nc"]
-            input_seg = torch.randn(input_shape_seg).to(device)
-            prediction, _ = inferer.sample(
-                autoencoder_model=stage_1,
-                diffusion_model=stage_2,
-                input_noise=noise,
-                save_intermediates=True,
-                seg=input_seg,
-            )
-        else:
-            prediction = inferer.sample(
-                autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False
+                autoencoder_latent_shape = [
+                    i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]
+                ]
+
+            inferer = LatentDiffusionInferer(
+                scheduler=scheduler,
+                scale_factor=1.0,
+                ldm_latent_shape=list(latent_shape[2:]),
+                autoencoder_latent_shape=autoencoder_latent_shape,
             )
-        self.assertEqual(prediction.shape, input_shape)
+            scheduler.set_timesteps(num_inference_steps=10)
+
+            if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
+                input_shape_seg = list(input_shape)
+                if "label_nc" in stage_2_params.keys():
+                    input_shape_seg[1] = stage_2_params["label_nc"]
+                else:
+                    input_shape_seg[1] = autoencoder_params["label_nc"]
+                input_seg = torch.randn(input_shape_seg).to(device)
+                prediction, _ = inferer.sample(
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    input_noise=noise,
+                    save_intermediates=True,
+                    seg=input_seg,
+                )
+            else:
+                prediction = inferer.sample(
+                    autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False
+                )
+            self.assertEqual(prediction.shape, input_shape)
 
     @skipUnless(has_einops, "Requires einops")
     def test_incompatible_spade_setup(self):
@@ -866,18 +871,19 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
         stage_2.eval()
         noise = torch.randn((1, 3, 4, 4)).to(device)
         input_seg = torch.randn((1, 3, 8, 8)).to(device)
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
-        scheduler.set_timesteps(num_inference_steps=10)
 
-        with self.assertRaises(ValueError):
-            _ = inferer.sample(
-                input_noise=noise,
-                autoencoder_model=stage_1,
-                diffusion_model=stage_2,
-                scheduler=scheduler,
-                seg=input_seg,
-            )
+        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
+            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+            scheduler.set_timesteps(num_inference_steps=10)
+
+            with self.assertRaises(ValueError):
+                _ = inferer.sample(
+                    input_noise=noise,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    scheduler=scheduler,
+                    seg=input_seg,
+                )
 
 
 if __name__ == "__main__":
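
Taken together, the edits above change each test from constructing a single DDPMScheduler to running the same assertions once per scheduler family. Below is a minimal sketch of that pattern outside the test harness; `autoencoder` and `unet` are placeholders for a compatible AutoencoderKL/DiffusionModelUNet pair whose constructor arguments come from the elided TEST_CASES parameters, so treat it as illustrative rather than tied to any specific configuration.

import torch

from monai.inferers import LatentDiffusionInferer
from monai.networks.schedulers import DDPMScheduler, RFlowScheduler


def sample_with_both_schedulers(autoencoder, unet, latent_shape, device="cpu"):
    # Same inferer code path, exercised once per scheduler, as in the tests above.
    noise = torch.randn(latent_shape).to(device)
    samples = []
    for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
        scheduler.set_timesteps(num_inference_steps=10)
        samples.append(
            inferer.sample(input_noise=noise, autoencoder_model=autoencoder, diffusion_model=unet, scheduler=scheduler)
        )
    return samples  # decoded image-space samples, one per scheduler
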
@@ -0,0 +1,105 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks.schedulers import RFlowScheduler
+from tests.test_utils import assert_allclose
+
+TEST_2D_CASE = []
+for sample_method in ["uniform", "logit-normal"]:
+    TEST_2D_CASE.append(
+        [{"sample_method": sample_method, "use_timestep_transform": False}, (2, 6, 16, 16), (2, 6, 16, 16)]
+    )
+
+for sample_method in ["uniform", "logit-normal"]:
+    TEST_2D_CASE.append(
+        [
+            {"sample_method": sample_method, "use_timestep_transform": True, "spatial_dim": 2},
+            (2, 6, 16, 16),
+            (2, 6, 16, 16),
+        ]
+    )
+
+
+TEST_3D_CASE = []
+for sample_method in ["uniform", "logit-normal"]:
+    TEST_3D_CASE.append(
+        [{"sample_method": sample_method, "use_timestep_transform": False}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]
+    )
+
+for sample_method in ["uniform", "logit-normal"]:
+    TEST_3D_CASE.append(
+        [
+            {"sample_method": sample_method, "use_timestep_transform": True, "spatial_dim": 3},
+            (2, 6, 16, 16, 16),
+            (2, 6, 16, 16, 16),
+        ]
+    )
+
+TEST_CASES = TEST_2D_CASE + TEST_3D_CASE
+
+TEST_FULl_LOOP = [
+    [{"sample_method": "uniform"}, (1, 1, 2, 2), torch.Tensor([[[[-0.786166, -0.057519], [2.442662, -0.407664]]]])]
+]
+
+
+class TestRFlowScheduler(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_add_noise(self, input_param, input_shape, expected_shape):
+        scheduler = RFlowScheduler(**input_param)
+        original_sample = torch.zeros(input_shape)
+        timesteps = scheduler.sample_timesteps(original_sample)
+        noise = torch.randn_like(original_sample)
+        timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long()
+        noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)
+        self.assertEqual(noisy.shape, expected_shape)
+
+    @parameterized.expand(TEST_CASES)
+    def test_step_shape(self, input_param, input_shape, expected_shape):
+        scheduler = RFlowScheduler(**input_param)
+        model_output = torch.randn(input_shape)
+        sample = torch.randn(input_shape)
+        scheduler.set_timesteps(num_inference_steps=100, input_img_size_numel=torch.numel(sample[0, 0, ...]))
+        output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
+        self.assertEqual(output_step[0].shape, expected_shape)
+        self.assertEqual(output_step[1].shape, expected_shape)
+
+    @parameterized.expand(TEST_FULl_LOOP)
+    def test_full_timestep_loop(self, input_param, input_shape, expected_output):
+        scheduler = RFlowScheduler(**input_param)
+        torch.manual_seed(42)
+        model_output = torch.randn(input_shape)
+        sample = torch.randn(input_shape)
+        scheduler.set_timesteps(50, input_img_size_numel=torch.numel(sample[0, 0, ...]))
+        for t in range(50):
+            sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)
+        assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3)
+
+    def test_set_timesteps(self):
+        scheduler = RFlowScheduler(num_train_timesteps=1000)
+        scheduler.set_timesteps(num_inference_steps=100, input_img_size_numel=16 * 16 * 16)
+        self.assertEqual(scheduler.num_inference_steps, 100)
+        self.assertEqual(len(scheduler.timesteps), 100)
+
+    def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
+        scheduler = RFlowScheduler(num_train_timesteps=1000)
+        with self.assertRaises(ValueError):
+            scheduler.set_timesteps(num_inference_steps=2000, input_img_size_numel=16 * 16 * 16)
+
+
+if __name__ == "__main__":
+    unittest.main()
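
The new test module pins down the RFlowScheduler surface used elsewhere in this diff. Below is a minimal sketch of that API using only the calls and values the tests above exercise; the two-step `timesteps` assignment in `test_add_noise` is simplified to the random-integer form its assertion actually uses, and `model_output` is a random stand-in for a denoising network's prediction.

import torch

from monai.networks.schedulers import RFlowScheduler

scheduler = RFlowScheduler(num_train_timesteps=1000)

# Training-side noising: add_noise keeps the input shape.
x0 = torch.zeros(2, 6, 16, 16)
noise = torch.randn_like(x0)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (x0.shape[0],)).long()
noisy = scheduler.add_noise(original_samples=x0, noise=noise, timesteps=timesteps)

# Inference-side loop, mirroring test_full_timestep_loop: set_timesteps takes the
# spatial numel of one channel alongside the step count, and step returns a pair.
sample = torch.randn(1, 1, 2, 2)
model_output = torch.randn_like(sample)
scheduler.set_timesteps(50, input_img_size_numel=torch.numel(sample[0, 0, ...]))
for t in range(50):
    sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)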