monai-weekly 1.5.dev2521__py3-none-any.whl → 1.5.dev2522__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/inferers/inferer.py +60 -9
- {monai_weekly-1.5.dev2521.dist-info → monai_weekly-1.5.dev2522.dist-info}/METADATA +10 -1
- {monai_weekly-1.5.dev2521.dist-info → monai_weekly-1.5.dev2522.dist-info}/RECORD +12 -12
- {monai_weekly-1.5.dev2521.dist-info → monai_weekly-1.5.dev2522.dist-info}/WHEEL +1 -1
- tests/bundle/test_bundle_download.py +0 -14
- tests/inferers/test_controlnet_inferers.py +14 -10
- tests/inferers/test_diffusion_inferer.py +53 -0
- tests/inferers/test_latent_diffusion_inferer.py +49 -0
- {monai_weekly-1.5.dev2521.dist-info → monai_weekly-1.5.dev2522.dist-info}/licenses/LICENSE +0 -0
- {monai_weekly-1.5.dev2521.dist-info → monai_weekly-1.5.dev2522.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
```diff
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2025-
+ "date": "2025-06-01T02:46:06+0000",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "1.5.dev2521"
+ "full-revisionid": "3e75ca0d8f4cf7ce1e4b3e01ef6058c56af87fdd",
+ "version": "1.5.dev2522"
 }
 '''  # END VERSION_JSON
 
```
monai/inferers/inferer.py
CHANGED
```diff
@@ -839,6 +839,7 @@ class DiffusionInferer(Inferer):
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -851,6 +852,7 @@ class DiffusionInferer(Inferer):
             mode: Conditioning mode for the network.
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
         """
         if mode not in ["crossattn", "concat"]:
             raise NotImplementedError(f"{mode} condition is not supported")
@@ -877,15 +879,31 @@ class DiffusionInferer(Inferer):
                 if isinstance(diffusion_model, SPADEDiffusionModelUNet)
                 else diffusion_model
             )
-            if mode == "concat" and conditioning is not None:
-                model_input = torch.cat([image, conditioning], dim=1)
+            if (
+                cfg is not None
+            ):  # if classifier-free guidance is used, a conditioned and unconditioned bit is generated.
+                model_input = torch.cat([image] * 2, dim=0)
+                if conditioning is not None:
+                    uncondition = torch.ones_like(conditioning)
+                    uncondition.fill_(-1)
+                    conditioning_input = torch.cat([uncondition, conditioning], dim=0)
+                else:
+                    conditioning_input = None
+            else:
+                model_input = image
+                conditioning_input = conditioning
+            if mode == "concat" and conditioning_input is not None:
+                model_input = torch.cat([model_input, conditioning_input], dim=1)
                 model_output = diffusion_model(
                     model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None
                 )
             else:
                 model_output = diffusion_model(
-                    image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
+                    model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning_input
                 )
+            if cfg is not None:
+                model_output_uncond, model_output_cond = model_output.chunk(2)
+                model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)
 
             # 2. compute previous image: x_t -> x_t-1
             if not isinstance(scheduler, RFlowScheduler):
```
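The guidance step added above is the standard classifier-free guidance blend: the batch is doubled, a `-1`-filled tensor stands in for the empty conditioning, and the two halves of the prediction are recombined. A minimal sketch of that recombination on dummy tensors (all names here are illustrative, not part of the MONAI API):

```python
import torch

# Stand-in for the doubled-batch prediction: first half unconditional,
# second half conditional, exactly as the new sampling loop arranges it.
model_output = torch.randn(4, 1, 8, 8)

model_output_uncond, model_output_cond = model_output.chunk(2)
cfg = 5.0
guided = model_output_uncond + cfg * (model_output_cond - model_output_uncond)

# cfg == 1.0 recovers the purely conditional prediction
recovered = model_output_uncond + 1.0 * (model_output_cond - model_output_uncond)
assert torch.allclose(recovered, model_output_cond)
```

With `cfg=1` the sampler behaves as before; values above 1 extrapolate away from the unconditional prediction, strengthening the conditioning at the cost of diversity.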
```diff
@@ -1166,6 +1184,7 @@ class LatentDiffusionInferer(DiffusionInferer):
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1180,6 +1199,7 @@ class LatentDiffusionInferer(DiffusionInferer):
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
                 is instance of SPADEAutoencoderKL, segmentation must be provided.
+            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
         """
 
         if (
@@ -1203,6 +1223,7 @@ class LatentDiffusionInferer(DiffusionInferer):
             mode=mode,
             verbose=verbose,
             seg=seg,
+            cfg=cfg,
         )
 
         if save_intermediates:
```
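`LatentDiffusionInferer.sample` only forwards the new keyword to the base `DiffusionInferer.sample`, so enabling guidance is a one-argument change at the call site. A hedged usage sketch, assuming `autoencoder`, `unet`, and `latent_noise` are an already-constructed autoencoder, diffusion UNet, and latent-shaped noise tensor (the test cases later in this diff build concrete ones):

```python
from monai.inferers import LatentDiffusionInferer
from monai.networks.schedulers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50)
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)

sample = inferer.sample(
    input_noise=latent_noise,       # assumed: noise matching the latent shape
    autoencoder_model=autoencoder,  # assumed: trained AutoencoderKL or VQVAE
    diffusion_model=unet,           # assumed: trained DiffusionModelUNet
    scheduler=scheduler,
    cfg=5.0,  # new in this release; forwarded unchanged to DiffusionInferer.sample
)
```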
```diff
@@ -1381,6 +1402,7 @@ class ControlNetDiffusionInferer(DiffusionInferer):
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1395,6 +1417,7 @@ class ControlNetDiffusionInferer(DiffusionInferer):
             mode: Conditioning mode for the network.
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
         """
         if mode not in ["crossattn", "concat"]:
             raise NotImplementedError(f"{mode} condition is not supported")
@@ -1413,14 +1436,31 @@ class ControlNetDiffusionInferer(DiffusionInferer):
             progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
         intermediates = []
 
+        if cfg is not None:
+            cn_cond = torch.cat([cn_cond] * 2, dim=0)
+
         for t, next_t in progress_bar:
+            # Controlnet prediction
+            if cfg is not None:
+                model_input = torch.cat([image] * 2, dim=0)
+                if conditioning is not None:
+                    uncondition = torch.ones_like(conditioning)
+                    uncondition.fill_(-1)
+                    conditioning_input = torch.cat([uncondition, conditioning], dim=0)
+                else:
+                    conditioning_input = None
+            else:
+                model_input = image
+                conditioning_input = conditioning
+
+            # Diffusion model prediction
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)
 
-            if mode == "concat" and conditioning is not None:
+            if mode == "concat" and conditioning_input is not None:
                 # 1. Conditioning
-                model_input = torch.cat([image, conditioning], dim=1)
+                model_input = torch.cat([model_input, conditioning_input], dim=1)
                 # 2. ControlNet forward
                 down_block_res_samples, mid_block_res_sample = controlnet(
                     x=model_input,
@@ -1437,20 +1477,28 @@ class ControlNetDiffusionInferer(DiffusionInferer):
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                # 1. Controlnet forward
                 down_block_res_samples, mid_block_res_sample = controlnet(
-                    x=image,
+                    x=model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
                     controlnet_cond=cn_cond,
-                    context=conditioning,
+                    context=conditioning_input,
                 )
+                # 2. predict noise model_output
                 model_output = diffuse(
-                    image,
+                    model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
-                    context=conditioning,
+                    context=conditioning_input,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                 )
 
+            # If classifier-free guidance isn't None, we split and compute the weighting between
+            # conditioned and unconditioned output.
+            if cfg is not None:
+                model_output_uncond, model_output_cond = model_output.chunk(2)
+                model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)
+
             # 3. compute previous image: x_t -> x_t-1
             if not isinstance(scheduler, RFlowScheduler):
                 image, _ = scheduler.step(model_output, t, image)  # type: ignore
```
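Note the asymmetry in the ControlNet path: `cn_cond` is doubled once before the sampling loop, while the latent `model_input` is rebuilt (and doubled) at every timestep, so the ControlNet and the UNet always see matching batch sizes and the chunk-and-blend afterwards stays valid. A small shape-bookkeeping sketch with illustrative tensors only:

```python
import torch

image = torch.randn(2, 1, 16, 16)    # current latents, batch of 2
cn_cond = torch.randn(2, 1, 16, 16)  # ControlNet conditioning, e.g. a mask

cn_cond = torch.cat([cn_cond] * 2, dim=0)  # doubled once, outside the loop

for _t in range(3):  # stand-in for the timestep loop
    model_input = torch.cat([image] * 2, dim=0)  # doubled at every step
    assert model_input.shape[0] == cn_cond.shape[0] == 4
```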
```diff
@@ -1714,6 +1762,7 @@ class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer):
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1730,6 +1779,7 @@ class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer):
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
                 is instance of SPADEAutoencoderKL, segmentation must be provided.
+            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
         """
 
         if (
@@ -1757,6 +1807,7 @@ class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer):
             mode=mode,
             verbose=verbose,
             seg=seg,
+            cfg=cfg,
         )
 
         if save_intermediates:
```
{monai_weekly-1.5.dev2521.dist-info → monai_weekly-1.5.dev2522.dist-info}/METADATA
CHANGED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: monai-weekly
-Version: 1.5.dev2521
+Version: 1.5.dev2522
 Summary: AI Toolkit for Healthcare Imaging
 Home-page: https://monai.io/
 Author: MONAI Consortium
@@ -194,6 +194,15 @@ Its ambitions are as follows:
 - customizable design for varying user expertise;
 - multi-GPU multi-node data parallelism support.
 
+## Requirements
+
+MONAI works with the [currently supported versions of Python](https://devguide.python.org/versions), and depends directly on NumPy and PyTorch with many optional dependencies.
+
+* Major releases of MONAI will have dependency versions stated for them. The current state of the `dev` branch in this repository is the unreleased development version of MONAI which typically will support current versions of dependencies and include updates and bug fixes to do so.
+* PyTorch support covers [the current version](https://github.com/pytorch/pytorch/releases) plus three previous minor versions. If compatibility issues with a PyTorch version and other dependencies arise, support for a version may be delayed until a major release.
+* Our support policy for other dependencies adheres for the most part to [SPEC0](https://scientific-python.org/specs/spec-0000), where dependency versions are supported where possible for up to two years. Discovered vulnerabilities or defects may require certain versions to be explicitly not supported.
+* See the `requirements*.txt` files for dependency version information.
+
 ## Installation
 
 To install [the current release](https://pypi.org/project/monai/), you can simply run:
```
{monai_weekly-1.5.dev2521.dist-info → monai_weekly-1.5.dev2522.dist-info}/RECORD
CHANGED
```diff
@@ -1,5 +1,5 @@
-monai/__init__.py,sha256=
-monai/_version.py,sha256=
+monai/__init__.py,sha256=eZQIXyDYcb7TrjIfc4e844rVm7n1JSyFVCrSaNR63h8,4095
+monai/_version.py,sha256=7wXHOrAWUBKXGlWLCp1a8MS05DKSCeFNQSp2QC6OdFY,503
 monai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 monai/_extensions/__init__.py,sha256=NEBPreRhQ8H9gVvgrLr_y52_TmqB96u_u4VQmeNT93I,642
 monai/_extensions/loader.py,sha256=7SiKw36q-nOzH8CRbBurFrz7GM40GCu7rc93Tm8XpnI,3643
@@ -196,7 +196,7 @@ monai/handlers/trt_handler.py,sha256=uWFdgC8QKRkcNwWfKIbQMdK6-MX_1ON0mKabeIn1ltI
 monai/handlers/utils.py,sha256=Ib1u-PLrtIkiLqTfREnrCWpN4af1btdNzkyMZuuuYyU,10239
 monai/handlers/validation_handler.py,sha256=NZO21c6zzXbmAgJZHkkdoZQSQIHwuxh94QD3PLUldGU,3674
 monai/inferers/__init__.py,sha256=K74t_RCeUPdEZvHzIPzVAwZ9DtmouLqhb3qDEmFBWs4,1107
-monai/inferers/inferer.py,sha256=
+monai/inferers/inferer.py,sha256=kfrxbSRAug_2oO943gQNHPfhfllu-DrseimT3vEb7uE,97280
 monai/inferers/merger.py,sha256=JxSLdlXTKW1xug11UWQNi6dNtpqVRbGCLc-ifj06g8U,16613
 monai/inferers/splitter.py,sha256=_hTnFdvDNRckkA7ZGQehVsNZw83oXoGFWyk5VXNqgJg,21149
 monai/inferers/utils.py,sha256=dvZBCAjaPa8xXcJuXRzNQ-fBzteauzkKbxE5YZdGBGY,20374
@@ -426,7 +426,7 @@ monai/visualize/img2tensorboard.py,sha256=n4ztSa5BQAUxSTGvi2tp45v-F7-RNgSlbsdy-9
 monai/visualize/occlusion_sensitivity.py,sha256=0SwhLO7ePDfIXJj67_UmXDZLxXItMeM-uNrPaCE0xXg,18159
 monai/visualize/utils.py,sha256=B-MhTVs7sQbIqYS3yPnpBwPw2K82rE2PBtGIfpwZtWM,9894
 monai/visualize/visualizer.py,sha256=qckyaMZCbezYUwE20k5yc-Pb7UozVavMDbrmyQwfYHY,1377
-monai_weekly-1.5.dev2521.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+monai_weekly-1.5.dev2522.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
 tests/apps/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/apps/test_auto3dseg_bundlegen.py,sha256=FpTJo9Lfe8vdhGuWeZ9y1BQmqYwTt-s8mDVtoLGAz_I,5594
 tests/apps/test_check_hash.py,sha256=MuZslW2DDCxHKEo6-PiL7hnbxGuZRRYf6HOh3ZQv1qQ,1761
@@ -508,7 +508,7 @@ tests/apps/vista3d/test_vista3d_sampler.py,sha256=-luQCe3Hhle2PC9AkFCUgK8gozOD0O
 tests/apps/vista3d/test_vista3d_transforms.py,sha256=nAPiDBNWeXLoW7ax3HHL63t5jqzQ3HFa-6wTvdyqVJk,3280
 tests/bundle/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/bundle/test_bundle_ckpt_export.py,sha256=VnpigCoBAAc2lo0rWOpVMg0IYGB6vbHXL8xLtB1Pkio,4622
-tests/bundle/test_bundle_download.py,sha256=
+tests/bundle/test_bundle_download.py,sha256=ER3JgkeRFr4ae6_swRnvFLQwhNv2vXwp_9cNqqDf45M,19674
 tests/bundle/test_bundle_get_data.py,sha256=lQh321mev_7fsLXRg0Tq5uEjuQILethDHRKzB6VV0o4,3667
 tests/bundle/test_bundle_push_to_hf_hub.py,sha256=Zjl6xDwRKgkS6jvO5dzMBaTLEd4EXyMXp0_wzDNSY3g,1740
 tests/bundle/test_bundle_trt_export.py,sha256=png-2SGjBSt46LXSz-PLprOXwJ0WkC_3YLR3Ibk_WBc,6344
@@ -575,9 +575,9 @@ tests/handlers/test_trt_compile.py,sha256=p8Gr2CJmBo6gG8w7bGlAO--nDHtQvy9Ld3jtua
 tests/handlers/test_write_metrics_reports.py,sha256=oKGYR1plj1hSAu-ntbxkw_TD4O5hKPwVH_BS3MdHIbs,3027
 tests/inferers/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/inferers/test_avg_merger.py,sha256=lMR2PcNGFD6sfF6CjJTkahrAiMA5m5LUs5A11P6h8n0,5952
-tests/inferers/test_controlnet_inferers.py,sha256=
-tests/inferers/test_diffusion_inferer.py,sha256=
-tests/inferers/test_latent_diffusion_inferer.py,sha256=
+tests/inferers/test_controlnet_inferers.py,sha256=cSBhxb9tMaoRuVmL634sv46cbhxPFZOs6a1ghCtQTOE,52922
+tests/inferers/test_diffusion_inferer.py,sha256=NbcxyqdaFzHG5VR8usLt2mhqFlzVSP3EohGVu-LNBUE,16848
+tests/inferers/test_latent_diffusion_inferer.py,sha256=ZU-ULLM52Z8Ukvl8qtXum_QRx_HaQxv0T_W98aNY4uM,35642
 tests/inferers/test_patch_inferer.py,sha256=LkYXWVn71vWinP-OJsIvq3FPH3jr36T7nKRIH5PzaqY,9878
 tests/inferers/test_saliency_inferer.py,sha256=7miHRbA4yb_WGcxql6za9uXXoZlql_7y23f7IzsyIps,1949
 tests/inferers/test_slice_inferer.py,sha256=kzaJjjTnf2rAiR75l8A_J-Kie4NaLj2bogi-aJ5L5mk,1897
@@ -1189,7 +1189,7 @@ tests/visualize/test_vis_gradcam.py,sha256=WpA-pvTB75eZs7JoIc5qyvOV9PwgkzWI8-Vow
 tests/visualize/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/visualize/utils/test_blend_images.py,sha256=RVs2p_8RWQDfhLHDNNtZaMig27v8o0km7XxNa-zWjKE,2274
 tests/visualize/utils/test_matshow3d.py,sha256=wXYj77L5Jvnp0f6DvL1rsi_-YlCxS0HJ9hiPmrbpuP8,5021
-monai_weekly-1.5.dev2521.dist-info/METADATA,sha256=
-monai_weekly-1.5.dev2521.dist-info/WHEEL,sha256=
-monai_weekly-1.5.dev2521.dist-info/top_level.txt,sha256=hn2Y6P9xBf2R8faMeVMHhPMvrdDKxMsIOwMDYI0yTjs,12
-monai_weekly-1.5.dev2521.dist-info/RECORD,,
+monai_weekly-1.5.dev2522.dist-info/METADATA,sha256=SECyXqQSdsULPcpimKHFcaPJko0nMWSVYjg31Za4OHU,13152
+monai_weekly-1.5.dev2522.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+monai_weekly-1.5.dev2522.dist-info/top_level.txt,sha256=hn2Y6P9xBf2R8faMeVMHhPMvrdDKxMsIOwMDYI0yTjs,12
+monai_weekly-1.5.dev2522.dist-info/RECORD,,
```
|
@@ -316,20 +316,6 @@ class TestLoad(unittest.TestCase):
|
|
316
316
|
output_2 = model_2.forward(input_tensor)
|
317
317
|
assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False)
|
318
318
|
|
319
|
-
model_3 = load(
|
320
|
-
name=bundle_name,
|
321
|
-
model_file=model_file,
|
322
|
-
bundle_dir=tempdir,
|
323
|
-
progress=False,
|
324
|
-
device=device,
|
325
|
-
net_name=model_name,
|
326
|
-
source="github",
|
327
|
-
**net_args,
|
328
|
-
)
|
329
|
-
model_3.eval()
|
330
|
-
output_3 = model_3.forward(input_tensor)
|
331
|
-
assert_allclose(output_3, expected_output, atol=1e-4, rtol=1e-4, type_test=False)
|
332
|
-
|
333
319
|
@parameterized.expand([TEST_CASE_8])
|
334
320
|
@skip_if_quick
|
335
321
|
@skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.")
|
tests/inferers/test_controlnet_inferers.py
CHANGED
```diff
@@ -482,16 +482,20 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
         scheduler = DDPMScheduler(num_train_timesteps=10)
         inferer = ControlNetDiffusionInferer(scheduler=scheduler)
         scheduler.set_timesteps(num_inference_steps=10)
-        sample, intermediates = inferer.sample(
-            input_noise=noise,
-            diffusion_model=model,
-            scheduler=scheduler,
-            controlnet=controlnet,
-            cn_cond=mask,
-            save_intermediates=True,
-            intermediate_steps=1,
-        )
-        self.assertEqual(len(intermediates), 10)
+
+        for cfg in [5, None]:
+            sample, intermediates = inferer.sample(
+                input_noise=noise,
+                diffusion_model=model,
+                scheduler=scheduler,
+                controlnet=controlnet,
+                cn_cond=mask,
+                save_intermediates=True,
+                intermediate_steps=1,
+                cfg=cfg,
+            )
+
+            self.assertEqual(len(intermediates), 10)
 
     @parameterized.expand(CNDM_TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
```
tests/inferers/test_diffusion_inferer.py
CHANGED
```diff
@@ -88,6 +88,27 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
         )
         self.assertEqual(len(intermediates), 10)
 
+    @parameterized.expand(TEST_CASES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_cfg(self, model_params, input_shape):
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        noise = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        sample, intermediates = inferer.sample(
+            input_noise=noise,
+            diffusion_model=model,
+            scheduler=scheduler,
+            save_intermediates=True,
+            intermediate_steps=1,
+            cfg=5,
+        )
+        self.assertEqual(sample.shape, noise.shape)
+
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
     def test_ddpm_sampler(self, model_params, input_shape):
@@ -244,6 +265,38 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
         )
         self.assertEqual(len(intermediates), 10)
 
+    @parameterized.expand(TEST_CASES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sampler_conditioned_concat_cfg(self, model_params, input_shape):
+        # copy the model_params dict to prevent from modifying test cases
+        model_params = model_params.copy()
+        n_concat_channel = 2
+        model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
+        model_params["cross_attention_dim"] = None
+        model_params["with_conditioning"] = False
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        noise = torch.randn(input_shape).to(device)
+        conditioning_shape = list(input_shape)
+        conditioning_shape[1] = n_concat_channel
+        conditioning = torch.randn(conditioning_shape).to(device)
+        scheduler = DDIMScheduler(num_train_timesteps=1000)
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        sample, intermediates = inferer.sample(
+            input_noise=noise,
+            diffusion_model=model,
+            scheduler=scheduler,
+            save_intermediates=True,
+            intermediate_steps=1,
+            conditioning=conditioning,
+            mode="concat",
+            cfg=5,
+        )
+        self.assertEqual(len(intermediates), 10)
+
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
     def test_sampler_conditioned_concat_rflow(self, model_params, input_shape):
```
tests/inferers/test_latent_diffusion_inferer.py
CHANGED
```diff
@@ -414,6 +414,55 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
         )
         self.assertEqual(sample.shape, input_shape)
 
+    @parameterized.expand(TEST_CASES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_shape_with_cfg(
+        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+    ):
+        stage_1 = None
+
+        if ae_model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if ae_model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        if dm_model_type == "SPADEDiffusionModelUNet":
+            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+        else:
+            stage_2 = DiffusionModelUNet(**stage_2_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        noise = torch.randn(latent_shape).to(device)
+
+        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
+            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+            scheduler.set_timesteps(num_inference_steps=10)
+
+            if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
+                input_shape_seg = list(input_shape)
+                if "label_nc" in stage_2_params.keys():
+                    input_shape_seg[1] = stage_2_params["label_nc"]
+                else:
+                    input_shape_seg[1] = autoencoder_params["label_nc"]
+                input_seg = torch.randn(input_shape_seg).to(device)
+                sample = inferer.sample(
+                    input_noise=noise,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    scheduler=scheduler,
+                    seg=input_seg,
+                    cfg=5,
+                )
+            else:
+                sample = inferer.sample(
+                    input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler, cfg=5
+                )
+            self.assertEqual(sample.shape, input_shape)
+
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
     def test_sample_intermediates(
```
{monai_weekly-1.5.dev2521.dist-info → monai_weekly-1.5.dev2522.dist-info}/licenses/LICENSE
File without changes

{monai_weekly-1.5.dev2521.dist-info → monai_weekly-1.5.dev2522.dist-info}/top_level.txt
File without changes