monai-weekly 1.5.dev2510__py3-none-any.whl → 1.5.dev2511__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -26,7 +26,7 @@ from monai.networks.nets import (
26
26
  SPADEAutoencoderKL,
27
27
  SPADEDiffusionModelUNet,
28
28
  )
29
- from monai.networks.schedulers import DDIMScheduler, DDPMScheduler
29
+ from monai.networks.schedulers import DDIMScheduler, DDPMScheduler, RFlowScheduler
30
30
  from monai.utils import optional_import
31
31
 
32
32
  _, has_scipy = optional_import("scipy")
@@ -545,6 +545,32 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
545
545
  )
546
546
  self.assertEqual(len(intermediates), 10)
547
547
 
548
@parameterized.expand(CNDM_TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_rflow_sampler(self, model_params, controlnet_params, input_shape):
    """Run ControlNet-guided sampling with an RFlowScheduler and count the saved intermediates."""
    # Construct both networks first (construction order is kept as-is).
    model = DiffusionModelUNet(**model_params)
    controlnet = ControlNet(**controlnet_params)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    for net in (model, controlnet):
        net.to(device)
        net.eval()

    # Random ControlNet conditioning map and starting noise, both shaped like the input.
    guidance_map = torch.randn(input_shape).to(device)
    start_noise = torch.randn(input_shape).to(device)

    scheduler = RFlowScheduler(num_train_timesteps=1000)
    inferer = ControlNetDiffusionInferer(scheduler=scheduler)
    scheduler.set_timesteps(num_inference_steps=10)

    _, intermediates = inferer.sample(
        input_noise=start_noise,
        diffusion_model=model,
        scheduler=scheduler,
        controlnet=controlnet,
        cn_cond=guidance_map,
        save_intermediates=True,
        intermediate_steps=1,
    )
    # With 10 inference steps and intermediate_steps=1 the test expects 10 saved intermediates.
    self.assertEqual(len(intermediates), 10)
573
+
548
574
  @parameterized.expand(CNDM_TEST_CASES)
549
575
  @skipUnless(has_einops, "Requires einops")
550
576
  def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
@@ -561,6 +587,8 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
561
587
  controlnet.eval()
562
588
  mask = torch.randn(input_shape).to(device)
563
589
  noise = torch.randn(input_shape).to(device)
590
+
591
+ # DDIM
564
592
  scheduler = DDIMScheduler(num_train_timesteps=1000)
565
593
  inferer = ControlNetDiffusionInferer(scheduler=scheduler)
566
594
  scheduler.set_timesteps(num_inference_steps=10)
@@ -577,6 +605,23 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
577
605
  )
578
606
  self.assertEqual(len(intermediates), 10)
579
607
 
608
+ # RFlow
609
+ scheduler = RFlowScheduler(num_train_timesteps=1000)
610
+ inferer = ControlNetDiffusionInferer(scheduler=scheduler)
611
+ scheduler.set_timesteps(num_inference_steps=10)
612
+ conditioning = torch.randn([input_shape[0], 1, 3]).to(device)
613
+ sample, intermediates = inferer.sample(
614
+ input_noise=noise,
615
+ diffusion_model=model,
616
+ controlnet=controlnet,
617
+ cn_cond=mask,
618
+ scheduler=scheduler,
619
+ save_intermediates=True,
620
+ intermediate_steps=1,
621
+ conditioning=conditioning,
622
+ )
623
+ self.assertEqual(len(intermediates), 10)
624
+
580
625
  @parameterized.expand(CNDM_TEST_CASES)
581
626
  @skipUnless(has_einops, "Requires einops")
582
627
  def test_get_likelihood(self, model_params, controlnet_params, input_shape):
@@ -638,6 +683,8 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
638
683
  conditioning_shape = list(input_shape)
639
684
  conditioning_shape[1] = n_concat_channel
640
685
  conditioning = torch.randn(conditioning_shape).to(device)
686
+
687
+ # DDIM
641
688
  scheduler = DDIMScheduler(num_train_timesteps=1000)
642
689
  inferer = ControlNetDiffusionInferer(scheduler=scheduler)
643
690
  scheduler.set_timesteps(num_inference_steps=10)
@@ -654,6 +701,23 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
654
701
  )
655
702
  self.assertEqual(len(intermediates), 10)
656
703
 
704
+ # RFlow
705
+ scheduler = RFlowScheduler(num_train_timesteps=1000)
706
+ inferer = ControlNetDiffusionInferer(scheduler=scheduler)
707
+ scheduler.set_timesteps(num_inference_steps=10)
708
+ sample, intermediates = inferer.sample(
709
+ input_noise=noise,
710
+ diffusion_model=model,
711
+ controlnet=controlnet,
712
+ cn_cond=mask,
713
+ scheduler=scheduler,
714
+ save_intermediates=True,
715
+ intermediate_steps=1,
716
+ conditioning=conditioning,
717
+ mode="concat",
718
+ )
719
+ self.assertEqual(len(intermediates), 10)
720
+
657
721
 
658
722
  class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):
659
723
  @parameterized.expand(LATENT_CNDM_TEST_CASES)
@@ -691,39 +755,39 @@ class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):
691
755
  input = torch.randn(input_shape).to(device)
692
756
  mask = torch.randn(input_shape).to(device)
693
757
  noise = torch.randn(latent_shape).to(device)
694
- scheduler = DDPMScheduler(num_train_timesteps=10)
695
- inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
696
- scheduler.set_timesteps(num_inference_steps=10)
697
- timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
698
758
 
699
- if dm_model_type == "SPADEDiffusionModelUNet":
700
- input_shape_seg = list(input_shape)
701
- if "label_nc" in stage_2_params.keys():
702
- input_shape_seg[1] = stage_2_params["label_nc"]
759
+ for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
760
+ inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
761
+ scheduler.set_timesteps(num_inference_steps=10)
762
+ timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
763
+ if dm_model_type == "SPADEDiffusionModelUNet":
764
+ input_shape_seg = list(input_shape)
765
+ if "label_nc" in stage_2_params.keys():
766
+ input_shape_seg[1] = stage_2_params["label_nc"]
767
+ else:
768
+ input_shape_seg[1] = autoencoder_params["label_nc"]
769
+ input_seg = torch.randn(input_shape_seg).to(device)
770
+ prediction = inferer(
771
+ inputs=input,
772
+ autoencoder_model=stage_1,
773
+ diffusion_model=stage_2,
774
+ controlnet=controlnet,
775
+ cn_cond=mask,
776
+ seg=input_seg,
777
+ noise=noise,
778
+ timesteps=timesteps,
779
+ )
703
780
  else:
704
- input_shape_seg[1] = autoencoder_params["label_nc"]
705
- input_seg = torch.randn(input_shape_seg).to(device)
706
- prediction = inferer(
707
- inputs=input,
708
- autoencoder_model=stage_1,
709
- diffusion_model=stage_2,
710
- controlnet=controlnet,
711
- cn_cond=mask,
712
- seg=input_seg,
713
- noise=noise,
714
- timesteps=timesteps,
715
- )
716
- else:
717
- prediction = inferer(
718
- inputs=input,
719
- autoencoder_model=stage_1,
720
- diffusion_model=stage_2,
721
- noise=noise,
722
- timesteps=timesteps,
723
- controlnet=controlnet,
724
- cn_cond=mask,
725
- )
726
- self.assertEqual(prediction.shape, latent_shape)
781
+ prediction = inferer(
782
+ inputs=input,
783
+ autoencoder_model=stage_1,
784
+ diffusion_model=stage_2,
785
+ noise=noise,
786
+ timesteps=timesteps,
787
+ controlnet=controlnet,
788
+ cn_cond=mask,
789
+ )
790
+ self.assertEqual(prediction.shape, latent_shape)
727
791
 
728
792
  @parameterized.expand(LATENT_CNDM_TEST_CASES)
729
793
  @skipUnless(has_einops, "Requires einops")
@@ -19,7 +19,7 @@ from parameterized import parameterized
19
19
 
20
20
  from monai.inferers import DiffusionInferer
21
21
  from monai.networks.nets import DiffusionModelUNet
22
- from monai.networks.schedulers import DDIMScheduler, DDPMScheduler
22
+ from monai.networks.schedulers import DDIMScheduler, DDPMScheduler, RFlowScheduler
23
23
  from monai.utils import optional_import
24
24
 
25
25
  _, has_scipy = optional_import("scipy")
@@ -120,6 +120,22 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
120
120
  )
121
121
  self.assertEqual(len(intermediates), 10)
122
122
 
123
@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_rflow_sampler(self, model_params, input_shape):
    """Sample unconditionally with an RFlowScheduler and verify one intermediate is kept per step."""
    unet = DiffusionModelUNet(**model_params)
    target_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    unet.to(target_device)
    unet.eval()

    start_noise = torch.randn(input_shape).to(target_device)

    rflow = RFlowScheduler(num_train_timesteps=1000)
    inferer = DiffusionInferer(scheduler=rflow)
    rflow.set_timesteps(num_inference_steps=10)

    outputs = inferer.sample(
        input_noise=start_noise,
        diffusion_model=unet,
        scheduler=rflow,
        save_intermediates=True,
        intermediate_steps=1,
    )
    saved_steps = outputs[1]
    # 10 inference steps with intermediate_steps=1 -> the test expects 10 intermediates.
    self.assertEqual(len(saved_steps), 10)
138
+
123
139
  @parameterized.expand(TEST_CASES)
124
140
  @skipUnless(has_einops, "Requires einops")
125
141
  def test_sampler_conditioned(self, model_params, input_shape):
@@ -144,6 +160,30 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
144
160
  )
145
161
  self.assertEqual(len(intermediates), 10)
146
162
 
163
@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_sampler_conditioned_rflow(self, model_params, input_shape):
    """Sample with cross-attention conditioning under an RFlowScheduler.

    Enables cross-attention on a private copy of the parameterized model
    config, samples for 10 inference steps while saving every intermediate,
    and checks that 10 intermediates are returned.
    """
    # Work on a copy: the dict comes from the shared TEST_CASES list, so
    # setting keys on it in place would leak `with_conditioning` and
    # `cross_attention_dim` into every other test parameterized with it.
    model_params = model_params.copy()
    model_params["with_conditioning"] = True
    model_params["cross_attention_dim"] = 3
    model = DiffusionModelUNet(**model_params)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()
    noise = torch.randn(input_shape).to(device)
    scheduler = RFlowScheduler(num_train_timesteps=1000)
    inferer = DiffusionInferer(scheduler=scheduler)
    scheduler.set_timesteps(num_inference_steps=10)
    # Cross-attention context: (batch, 1 token, cross_attention_dim=3).
    conditioning = torch.randn([input_shape[0], 1, 3]).to(device)
    _, intermediates = inferer.sample(
        input_noise=noise,
        diffusion_model=model,
        scheduler=scheduler,
        save_intermediates=True,
        intermediate_steps=1,
        conditioning=conditioning,
    )
    self.assertEqual(len(intermediates), 10)
186
+
147
187
  @parameterized.expand(TEST_CASES)
148
188
  @skipUnless(has_einops, "Requires einops")
149
189
  def test_get_likelihood(self, model_params, input_shape):
@@ -204,6 +244,37 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
204
244
  )
205
245
  self.assertEqual(len(intermediates), 10)
206
246
 
247
@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_sampler_conditioned_concat_rflow(self, model_params, input_shape):
    """Sample with channel-concatenated conditioning under an RFlowScheduler."""
    extra_channels = 2
    # Build a new dict rather than editing the shared test-case dict in place;
    # the UNet sees noise plus the concatenated conditioning channels.
    concat_params = {
        **model_params,
        "in_channels": model_params["in_channels"] + extra_channels,
        "cross_attention_dim": None,
        "with_conditioning": False,
    }
    model = DiffusionModelUNet(**concat_params)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    noise = torch.randn(input_shape).to(device)
    # Same batch and spatial size as the input, but `extra_channels` channels.
    cond_shape = [input_shape[0], extra_channels, *input_shape[2:]]
    concat_cond = torch.randn(cond_shape).to(device)

    scheduler = RFlowScheduler(num_train_timesteps=1000)
    inferer = DiffusionInferer(scheduler=scheduler)
    scheduler.set_timesteps(num_inference_steps=10)

    _, intermediates = inferer.sample(
        input_noise=noise,
        diffusion_model=model,
        scheduler=scheduler,
        save_intermediates=True,
        intermediate_steps=1,
        conditioning=concat_cond,
        mode="concat",
    )
    self.assertEqual(len(intermediates), 10)
277
+
207
278
  @parameterized.expand(TEST_CASES)
208
279
  @skipUnless(has_einops, "Requires einops")
209
280
  def test_call_conditioned_concat(self, model_params, input_shape):
@@ -231,6 +302,33 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
231
302
  )
232
303
  self.assertEqual(sample.shape, input_shape)
233
304
 
305
@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_call_conditioned_concat_rflow(self, model_params, input_shape):
    """Run a single concat-conditioned forward pass with an RFlowScheduler and check the output shape."""
    extra_channels = 2
    # Fresh dict so the shared parameterized test case is left untouched.
    concat_params = {
        **model_params,
        "in_channels": model_params["in_channels"] + extra_channels,
        "cross_attention_dim": None,
        "with_conditioning": False,
    }
    model = DiffusionModelUNet(**concat_params)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    clean_input = torch.randn(input_shape).to(device)
    noise = torch.randn(input_shape).to(device)
    cond_shape = [input_shape[0], extra_channels, *input_shape[2:]]
    concat_cond = torch.randn(cond_shape).to(device)

    scheduler = RFlowScheduler(num_train_timesteps=1000)
    inferer = DiffusionInferer(scheduler=scheduler)
    scheduler.set_timesteps(num_inference_steps=10)

    # One random training timestep per batch element.
    batch_size = input_shape[0]
    timesteps = torch.randint(0, scheduler.num_train_timesteps, (batch_size,), device=clean_input.device).long()

    prediction = inferer(
        inputs=clean_input,
        noise=noise,
        diffusion_model=model,
        timesteps=timesteps,
        condition=concat_cond,
        mode="concat",
    )
    self.assertEqual(prediction.shape, input_shape)
331
+
234
332
 
235
333
  if __name__ == "__main__":
236
334
  unittest.main()