monai-weekly 1.5.dev2508-py3-none-any.whl → 1.5.dev2509-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/bundle/scripts.py +39 -16
- monai/inferers/inferer.py +29 -11
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2509.dist-info}/METADATA +1 -1
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2509.dist-info}/RECORD +10 -10
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2509.dist-info}/WHEEL +1 -1
- tests/inferers/test_controlnet_inferers.py +9 -0
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2509.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2509.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json

 version_json = '''
 {
- "date": "2025-
+ "date": "2025-03-02T02:29:03+0000",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "1.5.
+ "full-revisionid": "5f85a7bfd54b91be03213999a7c177bfe2d583b2",
+ "version": "1.5.dev2509"
 }
 ''' # END VERSION_JSON

monai/bundle/scripts.py
CHANGED
@@ -15,6 +15,7 @@ import ast
 import json
 import os
 import re
+import urllib
 import warnings
 import zipfile
 from collections.abc import Mapping, Sequence
@@ -58,7 +59,7 @@ from monai.utils import (
 validate, _ = optional_import("jsonschema", name="validate")
 ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
 Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
-
+requests, has_requests = optional_import("requests")
 onnx, _ = optional_import("onnx")
 huggingface_hub, _ = optional_import("huggingface_hub")

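With this change the bundle scripts import requests once at module level through optional_import and guard every network call on has_requests before using requests.get. A minimal, hedged sketch of that guard pattern; the fetch_json helper below is hypothetical and not part of scripts.py:

from __future__ import annotations

from monai.utils import optional_import

# Import "requests" lazily; has_requests is False when the package is missing.
requests, has_requests = optional_import("requests")


def fetch_json(url: str, headers: dict | None = None) -> dict:
    # Hypothetical helper illustrating the guard-then-call pattern used in scripts.py.
    if not has_requests:
        raise ValueError("requests package is required, please install it.")
    response = requests.get(url, headers=headers or {})
    response.raise_for_status()
    return response.json()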
@@ -206,6 +207,16 @@ def _download_from_monaihosting(download_path: Path, filename: str, version: str
     extractall(filepath=filepath, output_dir=download_path, has_base=True)


+def _download_from_bundle_info(download_path: Path, filename: str, version: str, progress: bool) -> None:
+    bundle_info = get_bundle_info(bundle_name=filename, version=version)
+    if not bundle_info:
+        raise ValueError(f"Bundle info not found for {filename} v{version}.")
+    url = bundle_info["browser_download_url"]
+    filepath = download_path / f"{filename}_v{version}.zip"
+    download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
+    extractall(filepath=filepath, output_dir=download_path, has_base=True)
+
+
 def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str:
     if name.startswith(prefix):
         return name
@@ -222,7 +233,7 @@ def _get_all_download_files(request_url: str, headers: dict | None = None) -> li
     if not has_requests:
         raise ValueError("requests package is required, please install it.")
     headers = {} if headers is None else headers
-    response =
+    response = requests.get(request_url, headers=headers)
     response.raise_for_status()
     model_info = json.loads(response.text)

@@ -266,7 +277,7 @@ def _download_from_ngc_private(
     request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
     if has_requests:
         headers = {} if headers is None else headers
-        response =
+        response = requests.get(request_url, headers=headers)
         response.raise_for_status()
     else:
         raise ValueError("NGC API requires requests package. Please install it.")
@@ -289,7 +300,7 @@ def _get_ngc_token(api_key, retry=0):
     url = "https://authn.nvidia.com/token?service=ngc"
     headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key}
     if has_requests:
-        response =
+        response = requests.get(url, headers=headers)
         if not response.ok:
             # retry 3 times, if failed, raise an error.
             if retry < 3:
@@ -303,14 +314,17 @@ def _get_ngc_token(api_key, retry=0):

 def _get_latest_bundle_version_monaihosting(name):
     full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}"
-    requests_get, has_requests = optional_import("requests", name="get")
     if has_requests:
-        resp =
-
-
-
-
-
+        resp = requests.get(full_url)
+        try:
+            resp.raise_for_status()
+            model_info = json.loads(resp.text)
+            return model_info["model"]["latestVersionIdStr"]
+        except requests.exceptions.HTTPError:
+            # for monaihosting bundles, if cannot find the version, get from model zoo model_info.json
+            return get_bundle_versions(name)["latest_version"]
+
+    raise ValueError("NGC API requires requests package. Please install it.")


 def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
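The hunk above adds a fallback: when the hosting endpoint returns an HTTP error, the latest version is read from the model-zoo metadata via get_bundle_versions; the same metadata feeds get_bundle_info, which the new _download_from_bundle_info uses for its download URL. Both helpers are public monai.bundle APIs; a hedged usage sketch (the bundle name is only an example, and network access is required):

from monai.bundle import get_bundle_info, get_bundle_versions

# Resolve a bundle version from the model-zoo model_info.json rather than the hosting API.
versions = get_bundle_versions(bundle_name="spleen_ct_segmentation")
latest = versions["latest_version"]

# The per-version info carries the release asset URL used by the download fallback.
info = get_bundle_info(bundle_name="spleen_ct_segmentation", version=latest)
print(info["browser_download_url"])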
@@ -388,14 +402,14 @@ def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers:
     version_header = {"Accept-Encoding": "gzip, deflate"}  # Excluding 'zstd' to fit NGC requirements
     if headers:
         version_header.update(headers)
-    resp =
+    resp = requests.get(version_endpoint, headers=version_header)
     resp.raise_for_status()
     model_info = json.loads(resp.text)
     latest_versions = _list_latest_versions(model_info)

     for version in latest_versions:
         file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json"
-        resp =
+        resp = requests.get(file_endpoint, headers=headers)
         metadata = json.loads(resp.text)
         resp.raise_for_status()
         # if the package version is not available or the model is compatible with the package version
@@ -585,7 +599,16 @@ def download(
         name_ver = "_v".join([name_, version_]) if version_ is not None else name_
         _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
     elif source_ == "monaihosting":
-
+        try:
+            _download_from_monaihosting(
+                download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
+            )
+        except urllib.error.HTTPError:
+            # for monaihosting bundles, if cannot download from default host, download according to bundle_info
+            _download_from_bundle_info(
+                download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
+            )
+
     elif source_ == "ngc":
         _download_from_ngc(
             download_path=bundle_dir_,
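For callers nothing changes: download() keeps its signature, and the fallback is internal; on an urllib.error.HTTPError from the default monaihosting URL it retries using the release asset recorded in the bundle info. A hedged usage example (bundle name and directory are illustrative):

from monai.bundle import download

# Downloads from the default monaihosting URL, transparently falling back to the
# model-zoo release asset if the default host returns an HTTP error.
download(name="spleen_ct_segmentation", source="monaihosting", bundle_dir="./bundles")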
@@ -792,9 +815,9 @@ def _get_all_bundles_info(

         if auth_token is not None:
             headers = {"Authorization": f"Bearer {auth_token}"}
-            resp =
+            resp = requests.get(request_url, headers=headers)
         else:
-            resp =
+            resp = requests.get(request_url)
         resp.raise_for_status()
     else:
         raise ValueError("requests package is required, please install it.")
monai/inferers/inferer.py
CHANGED
@@ -1334,13 +1334,15 @@ class ControlNetDiffusionInferer(DiffusionInferer):
             raise NotImplementedError(f"{mode} condition is not supported")

         noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-
-            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
-        )
+
         if mode == "concat" and condition is not None:
             noisy_image = torch.cat([noisy_image, condition], dim=1)
             condition = None

+        down_block_res_samples, mid_block_res_sample = controlnet(
+            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
+        )
+
         diffuse = diffusion_model
         if isinstance(diffusion_model, SPADEDiffusionModelUNet):
             diffuse = partial(diffusion_model, seg=seg)
@@ -1396,17 +1398,21 @@ class ControlNetDiffusionInferer(DiffusionInferer):
             progress_bar = iter(scheduler.timesteps)
         intermediates = []
         for t in progress_bar:
-            # 1. ControlNet forward
-            down_block_res_samples, mid_block_res_sample = controlnet(
-                x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
-            )
-            # 2. predict noise model_output
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)

             if mode == "concat" and conditioning is not None:
+                # 1. Conditioning
                 model_input = torch.cat([image, conditioning], dim=1)
+                # 2. ControlNet forward
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=model_input,
+                    timesteps=torch.Tensor((t,)).to(input_noise.device),
+                    controlnet_cond=cn_cond,
+                    context=None,
+                )
+                # 3. predict noise model_output
                 model_output = diffuse(
                     model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1415,6 +1421,12 @@ class ControlNetDiffusionInferer(DiffusionInferer):
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=image,
+                    timesteps=torch.Tensor((t,)).to(input_noise.device),
+                    controlnet_cond=cn_cond,
+                    context=conditioning,
+                )
                 model_output = diffuse(
                     image,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1485,9 +1497,6 @@ class ControlNetDiffusionInferer(DiffusionInferer):
         for t in progress_bar:
             timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
             noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-            down_block_res_samples, mid_block_res_sample = controlnet(
-                x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
-            )

             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
@@ -1495,6 +1504,9 @@ class ControlNetDiffusionInferer(DiffusionInferer):

             if mode == "concat" and conditioning is not None:
                 noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None
+                )
                 model_output = diffuse(
                     noisy_image,
                     timesteps=timesteps,
@@ -1503,6 +1515,12 @@ class ControlNetDiffusionInferer(DiffusionInferer):
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=noisy_image,
+                    timesteps=torch.Tensor((t,)).to(inputs.device),
+                    controlnet_cond=cn_cond,
+                    context=conditioning,
+                )
                 model_output = diffuse(
                     x=noisy_image,
                     timesteps=timesteps,
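Across __call__, sample, and get_likelihood the ControlNet forward pass is moved after the concat-mode handling: in "concat" mode the ControlNet now sees the same channel-concatenated input as the diffusion model with no cross-attention context, and in "crossattn" mode the conditioning is forwarded through the new context argument. A hedged sketch of the resulting flow (tensor shapes and the controlnet callable are assumptions; this is not the inferer's exact code):

import torch


def controlnet_residuals(image, conditioning, t, controlnet, cn_cond, mode="crossattn"):
    # Illustrative only: mirrors the ordering introduced in the hunks above.
    timesteps = torch.Tensor((t,)).to(image.device)
    if mode == "concat" and conditioning is not None:
        # concat mode: ControlNet and UNet share the concatenated input, no cross-attention context.
        model_input = torch.cat([image, conditioning], dim=1)
        context = None
    else:
        # crossattn mode: conditioning is passed to the ControlNet as cross-attention context.
        model_input = image
        context = conditioning
    return controlnet(x=model_input, timesteps=timesteps, controlnet_cond=cn_cond, context=context)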
{monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2509.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
-monai/__init__.py,sha256=
-monai/_version.py,sha256=
+monai/__init__.py,sha256=2QSN66gMNzIDVAeBWVrsS3xgXmpc90Ksxr0j3D3KLiQ,4095
+monai/_version.py,sha256=3pISgTcfhG3j_LA8zhH9EcyDi6PgzKxbNALoD_5HCps,503
 monai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 monai/_extensions/__init__.py,sha256=NEBPreRhQ8H9gVvgrLr_y52_TmqB96u_u4VQmeNT93I,642
 monai/_extensions/loader.py,sha256=7SiKw36q-nOzH8CRbBurFrz7GM40GCu7rc93Tm8XpnI,3643
@@ -114,7 +114,7 @@ monai/bundle/config_item.py,sha256=rMjXSGkjJZdi04BwSHwCcIwzIb_TflmC3xDhC3SVJRs,1
 monai/bundle/config_parser.py,sha256=cGyEn-cqNk0rEEZ1Qiv6UydmIDvtWZcMVljyfVm5i50,23025
 monai/bundle/properties.py,sha256=iN3K4FVmN9ny1Hw9p5j7_ULcCdSD8PmrR7qXxbNz49k,11582
 monai/bundle/reference_resolver.py,sha256=GXCMK4iogxxE6VocsmAbUrcXosmC5arnjeG9zYhHKpg,16748
-monai/bundle/scripts.py,sha256=
+monai/bundle/scripts.py,sha256=p7wlT0BplTIdW4DbxRPotf_tLsgddvtklW1kcAEPBZQ,91016
 monai/bundle/utils.py,sha256=t-22uFvLn7Yy-dr1v1U33peNOxgAmU4TJiGAbsBrUKs,10108
 monai/bundle/workflows.py,sha256=CuhmFq1AWsN3ATiYJCSakPOxrOdGutl6vkpo9sxe8gU,34369
 monai/config/__init__.py,sha256=CN28CfTdsp301gv8YXfVvkbztCfbAqrLKrJi_C8oP9s,1048
@@ -195,7 +195,7 @@ monai/handlers/trt_handler.py,sha256=uWFdgC8QKRkcNwWfKIbQMdK6-MX_1ON0mKabeIn1ltI
 monai/handlers/utils.py,sha256=Ib1u-PLrtIkiLqTfREnrCWpN4af1btdNzkyMZuuuYyU,10239
 monai/handlers/validation_handler.py,sha256=NZO21c6zzXbmAgJZHkkdoZQSQIHwuxh94QD3PLUldGU,3674
 monai/inferers/__init__.py,sha256=K74t_RCeUPdEZvHzIPzVAwZ9DtmouLqhb3qDEmFBWs4,1107
-monai/inferers/inferer.py,sha256=
+monai/inferers/inferer.py,sha256=_VPnBIErwYzbrJIA9eMMalSso1pSsc_8cONVUUvPJOw,93549
 monai/inferers/merger.py,sha256=dZm-FVyXPlFb59q4DG52mbtPm8Iy4cNFWv3un0Z8k0M,16262
 monai/inferers/splitter.py,sha256=_hTnFdvDNRckkA7ZGQehVsNZw83oXoGFWyk5VXNqgJg,21149
 monai/inferers/utils.py,sha256=dvZBCAjaPa8xXcJuXRzNQ-fBzteauzkKbxE5YZdGBGY,20374
@@ -570,7 +570,7 @@ tests/handlers/test_trt_compile.py,sha256=p8Gr2CJmBo6gG8w7bGlAO--nDHtQvy9Ld3jtua
 tests/handlers/test_write_metrics_reports.py,sha256=oKGYR1plj1hSAu-ntbxkw_TD4O5hKPwVH_BS3MdHIbs,3027
 tests/inferers/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/inferers/test_avg_merger.py,sha256=lMR2PcNGFD6sfF6CjJTkahrAiMA5m5LUs5A11P6h8n0,5952
-tests/inferers/test_controlnet_inferers.py,sha256=
+tests/inferers/test_controlnet_inferers.py,sha256=sWs5vkZHa-D0V3tWJ6149Z-RNq0for_XngDYxZRl_Ao,50285
 tests/inferers/test_diffusion_inferer.py,sha256=1O2V_bEmifOZ4RvpbZgYUCooiJ97T73avaBuMJPpBs0,9992
 tests/inferers/test_latent_diffusion_inferer.py,sha256=atJjmfVznUq8z9EjohFIMyA0Q1XT1Ly0Zepf_1xPz5I,32274
 tests/inferers/test_patch_inferer.py,sha256=LkYXWVn71vWinP-OJsIvq3FPH3jr36T7nKRIH5PzaqY,9878
@@ -1178,8 +1178,8 @@ tests/visualize/test_vis_gradcam.py,sha256=WpA-pvTB75eZs7JoIc5qyvOV9PwgkzWI8-Vow
 tests/visualize/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/visualize/utils/test_blend_images.py,sha256=RVs2p_8RWQDfhLHDNNtZaMig27v8o0km7XxNa-zWjKE,2274
 tests/visualize/utils/test_matshow3d.py,sha256=wXYj77L5Jvnp0f6DvL1rsi_-YlCxS0HJ9hiPmrbpuP8,5021
-monai_weekly-1.5.
-monai_weekly-1.5.
-monai_weekly-1.5.
-monai_weekly-1.5.
-monai_weekly-1.5.
+monai_weekly-1.5.dev2509.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+monai_weekly-1.5.dev2509.dist-info/METADATA,sha256=h7L3w9XhzSfoxC5yRoqgKS_NeECPEORKyEX4E1WS6Vc,11909
+monai_weekly-1.5.dev2509.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
+monai_weekly-1.5.dev2509.dist-info/top_level.txt,sha256=hn2Y6P9xBf2R8faMeVMHhPMvrdDKxMsIOwMDYI0yTjs,12
+monai_weekly-1.5.dev2509.dist-info/RECORD,,
tests/inferers/test_controlnet_inferers.py
CHANGED
@@ -550,6 +550,8 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
     def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
         model_params["with_conditioning"] = True
         model_params["cross_attention_dim"] = 3
+        controlnet_params["with_conditioning"] = True
+        controlnet_params["cross_attention_dim"] = 3
         model = DiffusionModelUNet(**model_params)
         controlnet = ControlNet(**controlnet_params)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -619,8 +621,11 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
         model_params = model_params.copy()
         n_concat_channel = 2
         model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         model_params["cross_attention_dim"] = None
+        controlnet_params["cross_attention_dim"] = None
         model_params["with_conditioning"] = False
+        controlnet_params["with_conditioning"] = False
         model = DiffusionModelUNet(**model_params)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
         model.to(device)
@@ -1023,8 +1028,10 @@ class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):
         if ae_model_type == "SPADEAutoencoderKL":
             stage_1 = SPADEAutoencoderKL(**autoencoder_params)
         stage_2_params = stage_2_params.copy()
+        controlnet_params = controlnet_params.copy()
         n_concat_channel = 3
         stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         if dm_model_type == "SPADEDiffusionModelUNet":
             stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
         else:
@@ -1106,8 +1113,10 @@ class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):
         if ae_model_type == "SPADEAutoencoderKL":
             stage_1 = SPADEAutoencoderKL(**autoencoder_params)
         stage_2_params = stage_2_params.copy()
+        controlnet_params = controlnet_params.copy()
         n_concat_channel = 3
         stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         if dm_model_type == "SPADEDiffusionModelUNet":
             stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
         else:
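The test updates all apply one rule: the ControlNet must be configured with the same conditioning-related settings (with_conditioning, cross_attention_dim, in_channels) as the DiffusionModelUNet it accompanies, since both now receive the same input and context. A hedged sketch of that pairing; sync_controlnet_params is a hypothetical helper, not part of the test suite:

def sync_controlnet_params(model_params: dict, controlnet_params: dict) -> dict:
    # Copy first so shared test fixtures are not mutated, then mirror the
    # conditioning-related keys from the UNet parameters onto the ControlNet.
    controlnet_params = controlnet_params.copy()
    for key in ("in_channels", "cross_attention_dim", "with_conditioning"):
        if key in model_params:
            controlnet_params[key] = model_params[key]
    return controlnet_params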
{monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2509.dist-info}/LICENSE
File without changes
{monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2509.dist-info}/top_level.txt
File without changes