monai-weekly 1.5.dev2510__py3-none-any.whl → 1.5.dev2511__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +4 -0
- monai/inferers/inferer.py +29 -9
- monai/networks/schedulers/__init__.py +1 -0
- monai/networks/schedulers/rectified_flow.py +322 -0
- {monai_weekly-1.5.dev2510.dist-info → monai_weekly-1.5.dev2511.dist-info}/METADATA +1 -1
- {monai_weekly-1.5.dev2510.dist-info → monai_weekly-1.5.dev2511.dist-info}/RECORD +15 -13
- {monai_weekly-1.5.dev2510.dist-info → monai_weekly-1.5.dev2511.dist-info}/WHEEL +1 -1
- tests/inferers/test_controlnet_inferers.py +96 -32
- tests/inferers/test_diffusion_inferer.py +99 -1
- tests/inferers/test_latent_diffusion_inferer.py +217 -211
- tests/networks/schedulers/test_scheduler_rflow.py +105 -0
- {monai_weekly-1.5.dev2510.dist-info → monai_weekly-1.5.dev2511.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2510.dist-info → monai_weekly-1.5.dev2511.dist-info}/top_level.txt +0 -0
@@ -26,7 +26,7 @@ from monai.networks.nets import (
|
|
26
26
|
SPADEAutoencoderKL,
|
27
27
|
SPADEDiffusionModelUNet,
|
28
28
|
)
|
29
|
-
from monai.networks.schedulers import DDIMScheduler, DDPMScheduler
|
29
|
+
from monai.networks.schedulers import DDIMScheduler, DDPMScheduler, RFlowScheduler
|
30
30
|
from monai.utils import optional_import
|
31
31
|
|
32
32
|
_, has_scipy = optional_import("scipy")
|
@@ -545,6 +545,32 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
|
|
545
545
|
)
|
546
546
|
self.assertEqual(len(intermediates), 10)
|
547
547
|
|
548
|
+
@parameterized.expand(CNDM_TEST_CASES)
|
549
|
+
@skipUnless(has_einops, "Requires einops")
|
550
|
+
def test_rflow_sampler(self, model_params, controlnet_params, input_shape):
|
551
|
+
model = DiffusionModelUNet(**model_params)
|
552
|
+
controlnet = ControlNet(**controlnet_params)
|
553
|
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
554
|
+
model.to(device)
|
555
|
+
model.eval()
|
556
|
+
controlnet.to(device)
|
557
|
+
controlnet.eval()
|
558
|
+
mask = torch.randn(input_shape).to(device)
|
559
|
+
noise = torch.randn(input_shape).to(device)
|
560
|
+
scheduler = RFlowScheduler(num_train_timesteps=1000)
|
561
|
+
inferer = ControlNetDiffusionInferer(scheduler=scheduler)
|
562
|
+
scheduler.set_timesteps(num_inference_steps=10)
|
563
|
+
sample, intermediates = inferer.sample(
|
564
|
+
input_noise=noise,
|
565
|
+
diffusion_model=model,
|
566
|
+
scheduler=scheduler,
|
567
|
+
controlnet=controlnet,
|
568
|
+
cn_cond=mask,
|
569
|
+
save_intermediates=True,
|
570
|
+
intermediate_steps=1,
|
571
|
+
)
|
572
|
+
self.assertEqual(len(intermediates), 10)
|
573
|
+
|
548
574
|
@parameterized.expand(CNDM_TEST_CASES)
|
549
575
|
@skipUnless(has_einops, "Requires einops")
|
550
576
|
def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
|
@@ -561,6 +587,8 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
|
|
561
587
|
controlnet.eval()
|
562
588
|
mask = torch.randn(input_shape).to(device)
|
563
589
|
noise = torch.randn(input_shape).to(device)
|
590
|
+
|
591
|
+
# DDIM
|
564
592
|
scheduler = DDIMScheduler(num_train_timesteps=1000)
|
565
593
|
inferer = ControlNetDiffusionInferer(scheduler=scheduler)
|
566
594
|
scheduler.set_timesteps(num_inference_steps=10)
|
@@ -577,6 +605,23 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
|
|
577
605
|
)
|
578
606
|
self.assertEqual(len(intermediates), 10)
|
579
607
|
|
608
|
+
# RFlow
|
609
|
+
scheduler = RFlowScheduler(num_train_timesteps=1000)
|
610
|
+
inferer = ControlNetDiffusionInferer(scheduler=scheduler)
|
611
|
+
scheduler.set_timesteps(num_inference_steps=10)
|
612
|
+
conditioning = torch.randn([input_shape[0], 1, 3]).to(device)
|
613
|
+
sample, intermediates = inferer.sample(
|
614
|
+
input_noise=noise,
|
615
|
+
diffusion_model=model,
|
616
|
+
controlnet=controlnet,
|
617
|
+
cn_cond=mask,
|
618
|
+
scheduler=scheduler,
|
619
|
+
save_intermediates=True,
|
620
|
+
intermediate_steps=1,
|
621
|
+
conditioning=conditioning,
|
622
|
+
)
|
623
|
+
self.assertEqual(len(intermediates), 10)
|
624
|
+
|
580
625
|
@parameterized.expand(CNDM_TEST_CASES)
|
581
626
|
@skipUnless(has_einops, "Requires einops")
|
582
627
|
def test_get_likelihood(self, model_params, controlnet_params, input_shape):
|
@@ -638,6 +683,8 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
|
|
638
683
|
conditioning_shape = list(input_shape)
|
639
684
|
conditioning_shape[1] = n_concat_channel
|
640
685
|
conditioning = torch.randn(conditioning_shape).to(device)
|
686
|
+
|
687
|
+
# DDIM
|
641
688
|
scheduler = DDIMScheduler(num_train_timesteps=1000)
|
642
689
|
inferer = ControlNetDiffusionInferer(scheduler=scheduler)
|
643
690
|
scheduler.set_timesteps(num_inference_steps=10)
|
@@ -654,6 +701,23 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
|
|
654
701
|
)
|
655
702
|
self.assertEqual(len(intermediates), 10)
|
656
703
|
|
704
|
+
# RFlow
|
705
|
+
scheduler = RFlowScheduler(num_train_timesteps=1000)
|
706
|
+
inferer = ControlNetDiffusionInferer(scheduler=scheduler)
|
707
|
+
scheduler.set_timesteps(num_inference_steps=10)
|
708
|
+
sample, intermediates = inferer.sample(
|
709
|
+
input_noise=noise,
|
710
|
+
diffusion_model=model,
|
711
|
+
controlnet=controlnet,
|
712
|
+
cn_cond=mask,
|
713
|
+
scheduler=scheduler,
|
714
|
+
save_intermediates=True,
|
715
|
+
intermediate_steps=1,
|
716
|
+
conditioning=conditioning,
|
717
|
+
mode="concat",
|
718
|
+
)
|
719
|
+
self.assertEqual(len(intermediates), 10)
|
720
|
+
|
657
721
|
|
658
722
|
class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):
|
659
723
|
@parameterized.expand(LATENT_CNDM_TEST_CASES)
|
@@ -691,39 +755,39 @@ class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):
|
|
691
755
|
input = torch.randn(input_shape).to(device)
|
692
756
|
mask = torch.randn(input_shape).to(device)
|
693
757
|
noise = torch.randn(latent_shape).to(device)
|
694
|
-
scheduler = DDPMScheduler(num_train_timesteps=10)
|
695
|
-
inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
|
696
|
-
scheduler.set_timesteps(num_inference_steps=10)
|
697
|
-
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
|
698
758
|
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
759
|
+
for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
|
760
|
+
inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
|
761
|
+
scheduler.set_timesteps(num_inference_steps=10)
|
762
|
+
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
|
763
|
+
if dm_model_type == "SPADEDiffusionModelUNet":
|
764
|
+
input_shape_seg = list(input_shape)
|
765
|
+
if "label_nc" in stage_2_params.keys():
|
766
|
+
input_shape_seg[1] = stage_2_params["label_nc"]
|
767
|
+
else:
|
768
|
+
input_shape_seg[1] = autoencoder_params["label_nc"]
|
769
|
+
input_seg = torch.randn(input_shape_seg).to(device)
|
770
|
+
prediction = inferer(
|
771
|
+
inputs=input,
|
772
|
+
autoencoder_model=stage_1,
|
773
|
+
diffusion_model=stage_2,
|
774
|
+
controlnet=controlnet,
|
775
|
+
cn_cond=mask,
|
776
|
+
seg=input_seg,
|
777
|
+
noise=noise,
|
778
|
+
timesteps=timesteps,
|
779
|
+
)
|
703
780
|
else:
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
timesteps=timesteps,
|
715
|
-
)
|
716
|
-
else:
|
717
|
-
prediction = inferer(
|
718
|
-
inputs=input,
|
719
|
-
autoencoder_model=stage_1,
|
720
|
-
diffusion_model=stage_2,
|
721
|
-
noise=noise,
|
722
|
-
timesteps=timesteps,
|
723
|
-
controlnet=controlnet,
|
724
|
-
cn_cond=mask,
|
725
|
-
)
|
726
|
-
self.assertEqual(prediction.shape, latent_shape)
|
781
|
+
prediction = inferer(
|
782
|
+
inputs=input,
|
783
|
+
autoencoder_model=stage_1,
|
784
|
+
diffusion_model=stage_2,
|
785
|
+
noise=noise,
|
786
|
+
timesteps=timesteps,
|
787
|
+
controlnet=controlnet,
|
788
|
+
cn_cond=mask,
|
789
|
+
)
|
790
|
+
self.assertEqual(prediction.shape, latent_shape)
|
727
791
|
|
728
792
|
@parameterized.expand(LATENT_CNDM_TEST_CASES)
|
729
793
|
@skipUnless(has_einops, "Requires einops")
|
@@ -19,7 +19,7 @@ from parameterized import parameterized
|
|
19
19
|
|
20
20
|
from monai.inferers import DiffusionInferer
|
21
21
|
from monai.networks.nets import DiffusionModelUNet
|
22
|
-
from monai.networks.schedulers import DDIMScheduler, DDPMScheduler
|
22
|
+
from monai.networks.schedulers import DDIMScheduler, DDPMScheduler, RFlowScheduler
|
23
23
|
from monai.utils import optional_import
|
24
24
|
|
25
25
|
_, has_scipy = optional_import("scipy")
|
@@ -120,6 +120,22 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
|
|
120
120
|
)
|
121
121
|
self.assertEqual(len(intermediates), 10)
|
122
122
|
|
123
|
+
@parameterized.expand(TEST_CASES)
|
124
|
+
@skipUnless(has_einops, "Requires einops")
|
125
|
+
def test_rflow_sampler(self, model_params, input_shape):
|
126
|
+
model = DiffusionModelUNet(**model_params)
|
127
|
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
128
|
+
model.to(device)
|
129
|
+
model.eval()
|
130
|
+
noise = torch.randn(input_shape).to(device)
|
131
|
+
scheduler = RFlowScheduler(num_train_timesteps=1000)
|
132
|
+
inferer = DiffusionInferer(scheduler=scheduler)
|
133
|
+
scheduler.set_timesteps(num_inference_steps=10)
|
134
|
+
sample, intermediates = inferer.sample(
|
135
|
+
input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
|
136
|
+
)
|
137
|
+
self.assertEqual(len(intermediates), 10)
|
138
|
+
|
123
139
|
@parameterized.expand(TEST_CASES)
|
124
140
|
@skipUnless(has_einops, "Requires einops")
|
125
141
|
def test_sampler_conditioned(self, model_params, input_shape):
|
@@ -144,6 +160,30 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
|
|
144
160
|
)
|
145
161
|
self.assertEqual(len(intermediates), 10)
|
146
162
|
|
163
|
+
@parameterized.expand(TEST_CASES)
|
164
|
+
@skipUnless(has_einops, "Requires einops")
|
165
|
+
def test_sampler_conditioned_rflow(self, model_params, input_shape):
|
166
|
+
model_params["with_conditioning"] = True
|
167
|
+
model_params["cross_attention_dim"] = 3
|
168
|
+
model = DiffusionModelUNet(**model_params)
|
169
|
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
170
|
+
model.to(device)
|
171
|
+
model.eval()
|
172
|
+
noise = torch.randn(input_shape).to(device)
|
173
|
+
scheduler = RFlowScheduler(num_train_timesteps=1000)
|
174
|
+
inferer = DiffusionInferer(scheduler=scheduler)
|
175
|
+
scheduler.set_timesteps(num_inference_steps=10)
|
176
|
+
conditioning = torch.randn([input_shape[0], 1, 3]).to(device)
|
177
|
+
sample, intermediates = inferer.sample(
|
178
|
+
input_noise=noise,
|
179
|
+
diffusion_model=model,
|
180
|
+
scheduler=scheduler,
|
181
|
+
save_intermediates=True,
|
182
|
+
intermediate_steps=1,
|
183
|
+
conditioning=conditioning,
|
184
|
+
)
|
185
|
+
self.assertEqual(len(intermediates), 10)
|
186
|
+
|
147
187
|
@parameterized.expand(TEST_CASES)
|
148
188
|
@skipUnless(has_einops, "Requires einops")
|
149
189
|
def test_get_likelihood(self, model_params, input_shape):
|
@@ -204,6 +244,37 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
|
|
204
244
|
)
|
205
245
|
self.assertEqual(len(intermediates), 10)
|
206
246
|
|
247
|
+
@parameterized.expand(TEST_CASES)
|
248
|
+
@skipUnless(has_einops, "Requires einops")
|
249
|
+
def test_sampler_conditioned_concat_rflow(self, model_params, input_shape):
|
250
|
+
# copy the model_params dict to prevent from modifying test cases
|
251
|
+
model_params = model_params.copy()
|
252
|
+
n_concat_channel = 2
|
253
|
+
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
|
254
|
+
model_params["cross_attention_dim"] = None
|
255
|
+
model_params["with_conditioning"] = False
|
256
|
+
model = DiffusionModelUNet(**model_params)
|
257
|
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
258
|
+
model.to(device)
|
259
|
+
model.eval()
|
260
|
+
noise = torch.randn(input_shape).to(device)
|
261
|
+
conditioning_shape = list(input_shape)
|
262
|
+
conditioning_shape[1] = n_concat_channel
|
263
|
+
conditioning = torch.randn(conditioning_shape).to(device)
|
264
|
+
scheduler = RFlowScheduler(num_train_timesteps=1000)
|
265
|
+
inferer = DiffusionInferer(scheduler=scheduler)
|
266
|
+
scheduler.set_timesteps(num_inference_steps=10)
|
267
|
+
sample, intermediates = inferer.sample(
|
268
|
+
input_noise=noise,
|
269
|
+
diffusion_model=model,
|
270
|
+
scheduler=scheduler,
|
271
|
+
save_intermediates=True,
|
272
|
+
intermediate_steps=1,
|
273
|
+
conditioning=conditioning,
|
274
|
+
mode="concat",
|
275
|
+
)
|
276
|
+
self.assertEqual(len(intermediates), 10)
|
277
|
+
|
207
278
|
@parameterized.expand(TEST_CASES)
|
208
279
|
@skipUnless(has_einops, "Requires einops")
|
209
280
|
def test_call_conditioned_concat(self, model_params, input_shape):
|
@@ -231,6 +302,33 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
|
|
231
302
|
)
|
232
303
|
self.assertEqual(sample.shape, input_shape)
|
233
304
|
|
305
|
+
@parameterized.expand(TEST_CASES)
|
306
|
+
@skipUnless(has_einops, "Requires einops")
|
307
|
+
def test_call_conditioned_concat_rflow(self, model_params, input_shape):
|
308
|
+
# copy the model_params dict to prevent from modifying test cases
|
309
|
+
model_params = model_params.copy()
|
310
|
+
n_concat_channel = 2
|
311
|
+
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
|
312
|
+
model_params["cross_attention_dim"] = None
|
313
|
+
model_params["with_conditioning"] = False
|
314
|
+
model = DiffusionModelUNet(**model_params)
|
315
|
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
316
|
+
model.to(device)
|
317
|
+
model.eval()
|
318
|
+
input = torch.randn(input_shape).to(device)
|
319
|
+
noise = torch.randn(input_shape).to(device)
|
320
|
+
conditioning_shape = list(input_shape)
|
321
|
+
conditioning_shape[1] = n_concat_channel
|
322
|
+
conditioning = torch.randn(conditioning_shape).to(device)
|
323
|
+
scheduler = RFlowScheduler(num_train_timesteps=1000)
|
324
|
+
inferer = DiffusionInferer(scheduler=scheduler)
|
325
|
+
scheduler.set_timesteps(num_inference_steps=10)
|
326
|
+
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
|
327
|
+
sample = inferer(
|
328
|
+
inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps, condition=conditioning, mode="concat"
|
329
|
+
)
|
330
|
+
self.assertEqual(sample.shape, input_shape)
|
331
|
+
|
234
332
|
|
235
333
|
if __name__ == "__main__":
|
236
334
|
unittest.main()
|