monai-weekly 1.5.dev2521__py3-none-any.whl → 1.5.dev2522__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package as they appear in their respective public registries; it is provided for informational purposes only.
monai/__init__.py CHANGED
@@ -136,4 +136,4 @@ except BaseException:
 
     if MONAIEnvVars.debug():
         raise
-__commit_id__ = "222d5094ac6ce07f68ce8b07100b0a517fd2506e"
+__commit_id__ = "c3a317d2bcb486199f40bda0d722a41e3869712a"
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2025-05-25T02:37:26+0000",
+ "date": "2025-06-01T02:46:06+0000",
  "dirty": false,
  "error": null,
- "full-revisionid": "b82b3ff93efc25aff2eee158e3acaab717973590",
- "version": "1.5.dev2521"
+ "full-revisionid": "3e75ca0d8f4cf7ce1e4b3e01ef6058c56af87fdd",
+ "version": "1.5.dev2522"
 }
 ''' # END VERSION_JSON
 
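
The JSON blob above is what versioneer bakes into the wheel at build time, and it is what the package reports at runtime. A quick check of which weekly build is installed (a sketch, assuming the standard versioneer-generated get_versions() helper in monai._version):

    import monai
    from monai._version import get_versions

    print(monai.__version__)                   # "1.5.dev2522" for this build
    print(get_versions()["full-revisionid"])   # the revision id recorded above
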
monai/inferers/inferer.py CHANGED
@@ -839,6 +839,7 @@ class DiffusionInferer(Inferer):
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -851,6 +852,7 @@
             mode: Conditioning mode for the network.
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
         """
         if mode not in ["crossattn", "concat"]:
             raise NotImplementedError(f"{mode} condition is not supported")
@@ -877,15 +879,31 @@
                 if isinstance(diffusion_model, SPADEDiffusionModelUNet)
                 else diffusion_model
             )
-            if mode == "concat" and conditioning is not None:
-                model_input = torch.cat([image, conditioning], dim=1)
+            if (
+                cfg is not None
+            ):  # if classifier-free guidance is used, a conditioned and unconditioned bit is generated.
+                model_input = torch.cat([image] * 2, dim=0)
+                if conditioning is not None:
+                    uncondition = torch.ones_like(conditioning)
+                    uncondition.fill_(-1)
+                    conditioning_input = torch.cat([uncondition, conditioning], dim=0)
+                else:
+                    conditioning_input = None
+            else:
+                model_input = image
+                conditioning_input = conditioning
+            if mode == "concat" and conditioning_input is not None:
+                model_input = torch.cat([model_input, conditioning_input], dim=1)
                 model_output = diffusion_model(
                     model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None
                 )
             else:
                 model_output = diffusion_model(
-                    image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
+                    model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning_input
                 )
+            if cfg is not None:
+                model_output_uncond, model_output_cond = model_output.chunk(2)
+                model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)
 
             # 2. compute previous image: x_t -> x_t-1
             if not isinstance(scheduler, RFlowScheduler):
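
For readers skimming the hunk above: classifier-free guidance is implemented by running the network once over a doubled batch, where the unconditioned half receives a conditioning tensor filled with -1, and then extrapolating from the unconditioned prediction toward the conditioned one. A standalone sketch of that arithmetic, with illustrative tensor shapes (not MONAI API):

    import torch

    cfg = 5.0
    image = torch.randn(2, 1, 16, 16)         # current noisy sample x_t
    conditioning = torch.randn(2, 3, 16, 16)  # illustrative conditioning tensor

    # Double the batch: first half unconditioned (-1 filled), second half conditioned.
    model_input = torch.cat([image] * 2, dim=0)
    conditioning_input = torch.cat([torch.full_like(conditioning, -1), conditioning], dim=0)

    # Stand-in for the diffusion network's noise prediction on the doubled batch.
    model_output = torch.randn_like(model_input)

    # uncond + cfg * (cond - uncond): cfg=1 recovers the conditioned prediction,
    # cfg>1 strengthens the conditioning signal.
    uncond, cond = model_output.chunk(2)
    guided = uncond + cfg * (cond - uncond)
    assert guided.shape == image.shape
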
@@ -1166,6 +1184,7 @@ class LatentDiffusionInferer(DiffusionInferer):
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1180,6 +1199,7 @@
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
                 is instance of SPADEAutoencoderKL, segmentation must be provided.
+            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
         """
 
         if (
@@ -1203,6 +1223,7 @@
             mode=mode,
             verbose=verbose,
             seg=seg,
+            cfg=cfg,
         )
 
         if save_intermediates:
@@ -1381,6 +1402,7 @@ class ControlNetDiffusionInferer(DiffusionInferer):
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1395,6 +1417,7 @@
             mode: Conditioning mode for the network.
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
         """
         if mode not in ["crossattn", "concat"]:
             raise NotImplementedError(f"{mode} condition is not supported")
@@ -1413,14 +1436,31 @@
         progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
         intermediates = []
 
+        if cfg is not None:
+            cn_cond = torch.cat([cn_cond] * 2, dim=0)
+
         for t, next_t in progress_bar:
+            # Controlnet prediction
+            if cfg is not None:
+                model_input = torch.cat([image] * 2, dim=0)
+                if conditioning is not None:
+                    uncondition = torch.ones_like(conditioning)
+                    uncondition.fill_(-1)
+                    conditioning_input = torch.cat([uncondition, conditioning], dim=0)
+                else:
+                    conditioning_input = None
+            else:
+                model_input = image
+                conditioning_input = conditioning
+
+            # Diffusion model prediction
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)
 
-            if mode == "concat" and conditioning is not None:
+            if mode == "concat" and conditioning_input is not None:
                 # 1. Conditioning
-                model_input = torch.cat([image, conditioning], dim=1)
+                model_input = torch.cat([model_input, conditioning_input], dim=1)
                 # 2. ControlNet forward
                 down_block_res_samples, mid_block_res_sample = controlnet(
                     x=model_input,
@@ -1437,20 +1477,28 @@
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                # 1. Controlnet forward
                 down_block_res_samples, mid_block_res_sample = controlnet(
-                    x=image,
+                    x=model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
                     controlnet_cond=cn_cond,
-                    context=conditioning,
+                    context=conditioning_input,
                 )
+                # 2. predict noise model_output
                 model_output = diffuse(
-                    image,
+                    model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
-                    context=conditioning,
+                    context=conditioning_input,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                 )
 
+            # If classifier-free guidance isn't None, we split and compute the weighting between
+            # conditioned and unconditioned output.
+            if cfg is not None:
+                model_output_uncond, model_output_cond = model_output.chunk(2)
+                model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)
+
             # 3. compute previous image: x_t -> x_t-1
             if not isinstance(scheduler, RFlowScheduler):
                 image, _ = scheduler.step(model_output, t, image)  # type: ignore
@@ -1714,6 +1762,7 @@ class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer):
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1730,6 +1779,7 @@
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
                 is instance of SPADEAutoencoderKL, segmentation must be provided.
+            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
         """
 
         if (
@@ -1757,6 +1807,7 @@
             mode=mode,
             verbose=verbose,
             seg=seg,
+            cfg=cfg,
        )
 
        if save_intermediates:
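
Taken together, the new cfg argument threads through DiffusionInferer, LatentDiffusionInferer, ControlNetDiffusionInferer and ControlNetLatentDiffusionInferer, with cfg=None preserving the previous behaviour. A minimal call mirroring the test_sample_cfg unit test added below (the toy model hyperparameters here are illustrative, not from the package):

    import torch
    from monai.inferers import DiffusionInferer
    from monai.networks.nets import DiffusionModelUNet
    from monai.networks.schedulers import DDPMScheduler

    # Small 2D toy network; real configurations come from the test cases.
    model = DiffusionModelUNet(
        spatial_dims=2, in_channels=1, out_channels=1,
        channels=(8, 8), attention_levels=(False, True),
        norm_num_groups=8, num_res_blocks=1, num_head_channels=8,
    )
    model.eval()

    scheduler = DDPMScheduler(num_train_timesteps=10)
    scheduler.set_timesteps(num_inference_steps=10)
    inferer = DiffusionInferer(scheduler=scheduler)

    noise = torch.randn(1, 1, 16, 16)
    with torch.no_grad():
        # cfg=5.0 applies classifier-free guidance during sampling.
        sample = inferer.sample(
            input_noise=noise, diffusion_model=model, scheduler=scheduler, verbose=False, cfg=5.0
        )
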
monai_weekly-{1.5.dev2521 → 1.5.dev2522}.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: monai-weekly
-Version: 1.5.dev2521
+Version: 1.5.dev2522
 Summary: AI Toolkit for Healthcare Imaging
 Home-page: https://monai.io/
 Author: MONAI Consortium
@@ -194,6 +194,15 @@ Its ambitions are as follows:
 - customizable design for varying user expertise;
 - multi-GPU multi-node data parallelism support.
 
+## Requirements
+
+MONAI works with the [currently supported versions of Python](https://devguide.python.org/versions), and depends directly on NumPy and PyTorch with many optional dependencies.
+
+* Major releases of MONAI will have dependency versions stated for them. The current state of the `dev` branch in this repository is the unreleased development version of MONAI which typically will support current versions of dependencies and include updates and bug fixes to do so.
+* PyTorch support covers [the current version](https://github.com/pytorch/pytorch/releases) plus three previous minor versions. If compatibility issues with a PyTorch version and other dependencies arise, support for a version may be delayed until a major release.
+* Our support policy for other dependencies adheres for the most part to [SPEC0](https://scientific-python.org/specs/spec-0000), where dependency versions are supported where possible for up to two years. Discovered vulnerabilities or defects may require certain versions to be explicitly not supported.
+* See the `requirements*.txt` files for dependency version information.
+
 ## Installation
 
 To install [the current release](https://pypi.org/project/monai/), you can simply run:
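
For anyone checking their environment against the new requirements policy above, MONAI's own config report is the usual way to list installed dependency versions (a short sketch using the long-standing monai.config.print_config helper):

    from monai.config import print_config

    # Prints MONAI, Python, NumPy and PyTorch versions, plus the status of
    # the optional dependencies referenced by the requirements files.
    print_config()
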
monai_weekly-{1.5.dev2521 → 1.5.dev2522}.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
-monai/__init__.py,sha256=0oI9wDt1DDKPLqB0ZYtLX6PJxGhcLlKv38t3wsrg7aQ,4095
-monai/_version.py,sha256=iCrMQdxtRTq_xr7ATYU7pyuyBg-64AXid_jWWAwcZ9g,503
+monai/__init__.py,sha256=eZQIXyDYcb7TrjIfc4e844rVm7n1JSyFVCrSaNR63h8,4095
+monai/_version.py,sha256=7wXHOrAWUBKXGlWLCp1a8MS05DKSCeFNQSp2QC6OdFY,503
 monai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 monai/_extensions/__init__.py,sha256=NEBPreRhQ8H9gVvgrLr_y52_TmqB96u_u4VQmeNT93I,642
 monai/_extensions/loader.py,sha256=7SiKw36q-nOzH8CRbBurFrz7GM40GCu7rc93Tm8XpnI,3643
@@ -196,7 +196,7 @@ monai/handlers/trt_handler.py,sha256=uWFdgC8QKRkcNwWfKIbQMdK6-MX_1ON0mKabeIn1ltI
 monai/handlers/utils.py,sha256=Ib1u-PLrtIkiLqTfREnrCWpN4af1btdNzkyMZuuuYyU,10239
 monai/handlers/validation_handler.py,sha256=NZO21c6zzXbmAgJZHkkdoZQSQIHwuxh94QD3PLUldGU,3674
 monai/inferers/__init__.py,sha256=K74t_RCeUPdEZvHzIPzVAwZ9DtmouLqhb3qDEmFBWs4,1107
-monai/inferers/inferer.py,sha256=rgAI5qnLpszoXiSj3HCaqYiMAxymvqYO0Ltujq_lJUo,94617
+monai/inferers/inferer.py,sha256=kfrxbSRAug_2oO943gQNHPfhfllu-DrseimT3vEb7uE,97280
 monai/inferers/merger.py,sha256=JxSLdlXTKW1xug11UWQNi6dNtpqVRbGCLc-ifj06g8U,16613
 monai/inferers/splitter.py,sha256=_hTnFdvDNRckkA7ZGQehVsNZw83oXoGFWyk5VXNqgJg,21149
 monai/inferers/utils.py,sha256=dvZBCAjaPa8xXcJuXRzNQ-fBzteauzkKbxE5YZdGBGY,20374
@@ -426,7 +426,7 @@ monai/visualize/img2tensorboard.py,sha256=n4ztSa5BQAUxSTGvi2tp45v-F7-RNgSlbsdy-9
 monai/visualize/occlusion_sensitivity.py,sha256=0SwhLO7ePDfIXJj67_UmXDZLxXItMeM-uNrPaCE0xXg,18159
 monai/visualize/utils.py,sha256=B-MhTVs7sQbIqYS3yPnpBwPw2K82rE2PBtGIfpwZtWM,9894
 monai/visualize/visualizer.py,sha256=qckyaMZCbezYUwE20k5yc-Pb7UozVavMDbrmyQwfYHY,1377
-monai_weekly-1.5.dev2521.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+monai_weekly-1.5.dev2522.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
 tests/apps/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/apps/test_auto3dseg_bundlegen.py,sha256=FpTJo9Lfe8vdhGuWeZ9y1BQmqYwTt-s8mDVtoLGAz_I,5594
 tests/apps/test_check_hash.py,sha256=MuZslW2DDCxHKEo6-PiL7hnbxGuZRRYf6HOh3ZQv1qQ,1761
@@ -508,7 +508,7 @@ tests/apps/vista3d/test_vista3d_sampler.py,sha256=-luQCe3Hhle2PC9AkFCUgK8gozOD0O
 tests/apps/vista3d/test_vista3d_transforms.py,sha256=nAPiDBNWeXLoW7ax3HHL63t5jqzQ3HFa-6wTvdyqVJk,3280
 tests/bundle/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/bundle/test_bundle_ckpt_export.py,sha256=VnpigCoBAAc2lo0rWOpVMg0IYGB6vbHXL8xLtB1Pkio,4622
-tests/bundle/test_bundle_download.py,sha256=snf7bfFbiLaQoXOC9nR3w7RVYQv1t2l1qMjSlzyIBDE,20213
+tests/bundle/test_bundle_download.py,sha256=ER3JgkeRFr4ae6_swRnvFLQwhNv2vXwp_9cNqqDf45M,19674
 tests/bundle/test_bundle_get_data.py,sha256=lQh321mev_7fsLXRg0Tq5uEjuQILethDHRKzB6VV0o4,3667
 tests/bundle/test_bundle_push_to_hf_hub.py,sha256=Zjl6xDwRKgkS6jvO5dzMBaTLEd4EXyMXp0_wzDNSY3g,1740
 tests/bundle/test_bundle_trt_export.py,sha256=png-2SGjBSt46LXSz-PLprOXwJ0WkC_3YLR3Ibk_WBc,6344
@@ -575,9 +575,9 @@ tests/handlers/test_trt_compile.py,sha256=p8Gr2CJmBo6gG8w7bGlAO--nDHtQvy9Ld3jtua
 tests/handlers/test_write_metrics_reports.py,sha256=oKGYR1plj1hSAu-ntbxkw_TD4O5hKPwVH_BS3MdHIbs,3027
 tests/inferers/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/inferers/test_avg_merger.py,sha256=lMR2PcNGFD6sfF6CjJTkahrAiMA5m5LUs5A11P6h8n0,5952
-tests/inferers/test_controlnet_inferers.py,sha256=pGseHgnfMH-UOoAoUsKXdqka-IZc8X83ThauSanH--o,52825
-tests/inferers/test_diffusion_inferer.py,sha256=U6zNPnem9_cY9bDxMh6L2hThsmla7sDq9ivWQEyqNAk,14613
-tests/inferers/test_latent_diffusion_inferer.py,sha256=4cnS77I5YpFX1wKcTrlPfKVP3g6UHOkbuADgiXrScks,33544
+tests/inferers/test_controlnet_inferers.py,sha256=cSBhxb9tMaoRuVmL634sv46cbhxPFZOs6a1ghCtQTOE,52922
+tests/inferers/test_diffusion_inferer.py,sha256=NbcxyqdaFzHG5VR8usLt2mhqFlzVSP3EohGVu-LNBUE,16848
+tests/inferers/test_latent_diffusion_inferer.py,sha256=ZU-ULLM52Z8Ukvl8qtXum_QRx_HaQxv0T_W98aNY4uM,35642
 tests/inferers/test_patch_inferer.py,sha256=LkYXWVn71vWinP-OJsIvq3FPH3jr36T7nKRIH5PzaqY,9878
 tests/inferers/test_saliency_inferer.py,sha256=7miHRbA4yb_WGcxql6za9uXXoZlql_7y23f7IzsyIps,1949
 tests/inferers/test_slice_inferer.py,sha256=kzaJjjTnf2rAiR75l8A_J-Kie4NaLj2bogi-aJ5L5mk,1897
@@ -1189,7 +1189,7 @@ tests/visualize/test_vis_gradcam.py,sha256=WpA-pvTB75eZs7JoIc5qyvOV9PwgkzWI8-Vow
 tests/visualize/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/visualize/utils/test_blend_images.py,sha256=RVs2p_8RWQDfhLHDNNtZaMig27v8o0km7XxNa-zWjKE,2274
 tests/visualize/utils/test_matshow3d.py,sha256=wXYj77L5Jvnp0f6DvL1rsi_-YlCxS0HJ9hiPmrbpuP8,5021
-monai_weekly-1.5.dev2521.dist-info/METADATA,sha256=XhDzd0-niJZ3AvX990KJ8IZ-xOwEFanBCfNyUrGab80,12033
-monai_weekly-1.5.dev2521.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
-monai_weekly-1.5.dev2521.dist-info/top_level.txt,sha256=hn2Y6P9xBf2R8faMeVMHhPMvrdDKxMsIOwMDYI0yTjs,12
-monai_weekly-1.5.dev2521.dist-info/RECORD,,
+monai_weekly-1.5.dev2522.dist-info/METADATA,sha256=SECyXqQSdsULPcpimKHFcaPJko0nMWSVYjg31Za4OHU,13152
+monai_weekly-1.5.dev2522.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+monai_weekly-1.5.dev2522.dist-info/top_level.txt,sha256=hn2Y6P9xBf2R8faMeVMHhPMvrdDKxMsIOwMDYI0yTjs,12
+monai_weekly-1.5.dev2522.dist-info/RECORD,,
monai_weekly-{1.5.dev2521 → 1.5.dev2522}.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.8.0)
+Generator: setuptools (80.9.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
tests/bundle/test_bundle_download.py CHANGED
@@ -316,20 +316,6 @@ class TestLoad(unittest.TestCase):
         output_2 = model_2.forward(input_tensor)
         assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False)
 
-        model_3 = load(
-            name=bundle_name,
-            model_file=model_file,
-            bundle_dir=tempdir,
-            progress=False,
-            device=device,
-            net_name=model_name,
-            source="github",
-            **net_args,
-        )
-        model_3.eval()
-        output_3 = model_3.forward(input_tensor)
-        assert_allclose(output_3, expected_output, atol=1e-4, rtol=1e-4, type_test=False)
-
     @parameterized.expand([TEST_CASE_8])
     @skip_if_quick
     @skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.")
tests/inferers/test_controlnet_inferers.py CHANGED
@@ -482,16 +482,20 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
         scheduler = DDPMScheduler(num_train_timesteps=10)
         inferer = ControlNetDiffusionInferer(scheduler=scheduler)
         scheduler.set_timesteps(num_inference_steps=10)
-        sample, intermediates = inferer.sample(
-            input_noise=noise,
-            diffusion_model=model,
-            scheduler=scheduler,
-            controlnet=controlnet,
-            cn_cond=mask,
-            save_intermediates=True,
-            intermediate_steps=1,
-        )
-        self.assertEqual(len(intermediates), 10)
+
+        for cfg in [5, None]:
+            sample, intermediates = inferer.sample(
+                input_noise=noise,
+                diffusion_model=model,
+                scheduler=scheduler,
+                controlnet=controlnet,
+                cn_cond=mask,
+                save_intermediates=True,
+                intermediate_steps=1,
+                cfg=cfg,
+            )
+
+            self.assertEqual(len(intermediates), 10)
 
     @parameterized.expand(CNDM_TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
tests/inferers/test_diffusion_inferer.py CHANGED
@@ -88,6 +88,27 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
         )
         self.assertEqual(len(intermediates), 10)
 
+    @parameterized.expand(TEST_CASES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_cfg(self, model_params, input_shape):
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        noise = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        sample, intermediates = inferer.sample(
+            input_noise=noise,
+            diffusion_model=model,
+            scheduler=scheduler,
+            save_intermediates=True,
+            intermediate_steps=1,
+            cfg=5,
+        )
+        self.assertEqual(sample.shape, noise.shape)
+
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
     def test_ddpm_sampler(self, model_params, input_shape):
@@ -244,6 +265,38 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
         )
         self.assertEqual(len(intermediates), 10)
 
+    @parameterized.expand(TEST_CASES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sampler_conditioned_concat_cfg(self, model_params, input_shape):
+        # copy the model_params dict to prevent from modifying test cases
+        model_params = model_params.copy()
+        n_concat_channel = 2
+        model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
+        model_params["cross_attention_dim"] = None
+        model_params["with_conditioning"] = False
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        noise = torch.randn(input_shape).to(device)
+        conditioning_shape = list(input_shape)
+        conditioning_shape[1] = n_concat_channel
+        conditioning = torch.randn(conditioning_shape).to(device)
+        scheduler = DDIMScheduler(num_train_timesteps=1000)
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        sample, intermediates = inferer.sample(
+            input_noise=noise,
+            diffusion_model=model,
+            scheduler=scheduler,
+            save_intermediates=True,
+            intermediate_steps=1,
+            conditioning=conditioning,
+            mode="concat",
+            cfg=5,
+        )
+        self.assertEqual(len(intermediates), 10)
+
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
     def test_sampler_conditioned_concat_rflow(self, model_params, input_shape):
tests/inferers/test_latent_diffusion_inferer.py CHANGED
@@ -414,6 +414,55 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
         )
         self.assertEqual(sample.shape, input_shape)
 
+    @parameterized.expand(TEST_CASES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_shape_with_cfg(
+        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+    ):
+        stage_1 = None
+
+        if ae_model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if ae_model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        if dm_model_type == "SPADEDiffusionModelUNet":
+            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+        else:
+            stage_2 = DiffusionModelUNet(**stage_2_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        noise = torch.randn(latent_shape).to(device)
+
+        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
+            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+            scheduler.set_timesteps(num_inference_steps=10)
+
+            if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
+                input_shape_seg = list(input_shape)
+                if "label_nc" in stage_2_params.keys():
+                    input_shape_seg[1] = stage_2_params["label_nc"]
+                else:
+                    input_shape_seg[1] = autoencoder_params["label_nc"]
+                input_seg = torch.randn(input_shape_seg).to(device)
+                sample = inferer.sample(
+                    input_noise=noise,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    scheduler=scheduler,
+                    seg=input_seg,
+                    cfg=5,
+                )
+            else:
+                sample = inferer.sample(
+                    input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler, cfg=5
+                )
+            self.assertEqual(sample.shape, input_shape)
+
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
     def test_sample_intermediates(