monai-weekly 1.5.dev2508__py3-none-any.whl → 1.5.dev2509__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
monai/__init__.py CHANGED
@@ -136,4 +136,4 @@ except BaseException:
136
136
 
137
137
  if MONAIEnvVars.debug():
138
138
  raise
139
- __commit_id__ = "a7905909e785d1ef24103c32a2d3a5a36e1059a2"
139
+ __commit_id__ = "a09c1f08461cec3d2131fde3939ef38c3c4ad5fc"
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2025-02-23T02:28:09+0000",
11
+ "date": "2025-03-02T02:29:03+0000",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "e55b5cbfbbba1800a968a9c06b2deaaa5c9bec54",
15
- "version": "1.5.dev2508"
14
+ "full-revisionid": "5f85a7bfd54b91be03213999a7c177bfe2d583b2",
15
+ "version": "1.5.dev2509"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
monai/bundle/scripts.py CHANGED
@@ -15,6 +15,7 @@ import ast
15
15
  import json
16
16
  import os
17
17
  import re
18
+ import urllib
18
19
  import warnings
19
20
  import zipfile
20
21
  from collections.abc import Mapping, Sequence
@@ -58,7 +59,7 @@ from monai.utils import (
58
59
  validate, _ = optional_import("jsonschema", name="validate")
59
60
  ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
60
61
  Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
61
- requests_get, has_requests = optional_import("requests", name="get")
62
+ requests, has_requests = optional_import("requests")
62
63
  onnx, _ = optional_import("onnx")
63
64
  huggingface_hub, _ = optional_import("huggingface_hub")
64
65
 
@@ -206,6 +207,16 @@ def _download_from_monaihosting(download_path: Path, filename: str, version: str
206
207
  extractall(filepath=filepath, output_dir=download_path, has_base=True)
207
208
 
208
209
 
210
+ def _download_from_bundle_info(download_path: Path, filename: str, version: str, progress: bool) -> None:
211
+ bundle_info = get_bundle_info(bundle_name=filename, version=version)
212
+ if not bundle_info:
213
+ raise ValueError(f"Bundle info not found for {filename} v{version}.")
214
+ url = bundle_info["browser_download_url"]
215
+ filepath = download_path / f"{filename}_v{version}.zip"
216
+ download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
217
+ extractall(filepath=filepath, output_dir=download_path, has_base=True)
218
+
219
+
209
220
  def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str:
210
221
  if name.startswith(prefix):
211
222
  return name
@@ -222,7 +233,7 @@ def _get_all_download_files(request_url: str, headers: dict | None = None) -> li
222
233
  if not has_requests:
223
234
  raise ValueError("requests package is required, please install it.")
224
235
  headers = {} if headers is None else headers
225
- response = requests_get(request_url, headers=headers)
236
+ response = requests.get(request_url, headers=headers)
226
237
  response.raise_for_status()
227
238
  model_info = json.loads(response.text)
228
239
 
@@ -266,7 +277,7 @@ def _download_from_ngc_private(
266
277
  request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
267
278
  if has_requests:
268
279
  headers = {} if headers is None else headers
269
- response = requests_get(request_url, headers=headers)
280
+ response = requests.get(request_url, headers=headers)
270
281
  response.raise_for_status()
271
282
  else:
272
283
  raise ValueError("NGC API requires requests package. Please install it.")
@@ -289,7 +300,7 @@ def _get_ngc_token(api_key, retry=0):
289
300
  url = "https://authn.nvidia.com/token?service=ngc"
290
301
  headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key}
291
302
  if has_requests:
292
- response = requests_get(url, headers=headers)
303
+ response = requests.get(url, headers=headers)
293
304
  if not response.ok:
294
305
  # retry 3 times, if failed, raise an error.
295
306
  if retry < 3:
@@ -303,14 +314,17 @@ def _get_ngc_token(api_key, retry=0):
303
314
 
304
315
  def _get_latest_bundle_version_monaihosting(name):
305
316
  full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}"
306
- requests_get, has_requests = optional_import("requests", name="get")
307
317
  if has_requests:
308
- resp = requests_get(full_url)
309
- resp.raise_for_status()
310
- else:
311
- raise ValueError("NGC API requires requests package. Please install it.")
312
- model_info = json.loads(resp.text)
313
- return model_info["model"]["latestVersionIdStr"]
318
+ resp = requests.get(full_url)
319
+ try:
320
+ resp.raise_for_status()
321
+ model_info = json.loads(resp.text)
322
+ return model_info["model"]["latestVersionIdStr"]
323
+ except requests.exceptions.HTTPError:
324
+ # for monaihosting bundles, if cannot find the version, get from model zoo model_info.json
325
+ return get_bundle_versions(name)["latest_version"]
326
+
327
+ raise ValueError("NGC API requires requests package. Please install it.")
314
328
 
315
329
 
316
330
  def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
@@ -388,14 +402,14 @@ def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers:
388
402
  version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements
389
403
  if headers:
390
404
  version_header.update(headers)
391
- resp = requests_get(version_endpoint, headers=version_header)
405
+ resp = requests.get(version_endpoint, headers=version_header)
392
406
  resp.raise_for_status()
393
407
  model_info = json.loads(resp.text)
394
408
  latest_versions = _list_latest_versions(model_info)
395
409
 
396
410
  for version in latest_versions:
397
411
  file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json"
398
- resp = requests_get(file_endpoint, headers=headers)
412
+ resp = requests.get(file_endpoint, headers=headers)
399
413
  metadata = json.loads(resp.text)
400
414
  resp.raise_for_status()
401
415
  # if the package version is not available or the model is compatible with the package version
@@ -585,7 +599,16 @@ def download(
585
599
  name_ver = "_v".join([name_, version_]) if version_ is not None else name_
586
600
  _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
587
601
  elif source_ == "monaihosting":
588
- _download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_)
602
+ try:
603
+ _download_from_monaihosting(
604
+ download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
605
+ )
606
+ except urllib.error.HTTPError:
607
+ # for monaihosting bundles, if cannot download from default host, download according to bundle_info
608
+ _download_from_bundle_info(
609
+ download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
610
+ )
611
+
589
612
  elif source_ == "ngc":
590
613
  _download_from_ngc(
591
614
  download_path=bundle_dir_,
@@ -792,9 +815,9 @@ def _get_all_bundles_info(
792
815
 
793
816
  if auth_token is not None:
794
817
  headers = {"Authorization": f"Bearer {auth_token}"}
795
- resp = requests_get(request_url, headers=headers)
818
+ resp = requests.get(request_url, headers=headers)
796
819
  else:
797
- resp = requests_get(request_url)
820
+ resp = requests.get(request_url)
798
821
  resp.raise_for_status()
799
822
  else:
800
823
  raise ValueError("requests package is required, please install it.")
monai/inferers/inferer.py CHANGED
@@ -1334,13 +1334,15 @@ class ControlNetDiffusionInferer(DiffusionInferer):
1334
1334
  raise NotImplementedError(f"{mode} condition is not supported")
1335
1335
 
1336
1336
  noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
1337
- down_block_res_samples, mid_block_res_sample = controlnet(
1338
- x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
1339
- )
1337
+
1340
1338
  if mode == "concat" and condition is not None:
1341
1339
  noisy_image = torch.cat([noisy_image, condition], dim=1)
1342
1340
  condition = None
1343
1341
 
1342
+ down_block_res_samples, mid_block_res_sample = controlnet(
1343
+ x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
1344
+ )
1345
+
1344
1346
  diffuse = diffusion_model
1345
1347
  if isinstance(diffusion_model, SPADEDiffusionModelUNet):
1346
1348
  diffuse = partial(diffusion_model, seg=seg)
@@ -1396,17 +1398,21 @@ class ControlNetDiffusionInferer(DiffusionInferer):
1396
1398
  progress_bar = iter(scheduler.timesteps)
1397
1399
  intermediates = []
1398
1400
  for t in progress_bar:
1399
- # 1. ControlNet forward
1400
- down_block_res_samples, mid_block_res_sample = controlnet(
1401
- x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
1402
- )
1403
- # 2. predict noise model_output
1404
1401
  diffuse = diffusion_model
1405
1402
  if isinstance(diffusion_model, SPADEDiffusionModelUNet):
1406
1403
  diffuse = partial(diffusion_model, seg=seg)
1407
1404
 
1408
1405
  if mode == "concat" and conditioning is not None:
1406
+ # 1. Conditioning
1409
1407
  model_input = torch.cat([image, conditioning], dim=1)
1408
+ # 2. ControlNet forward
1409
+ down_block_res_samples, mid_block_res_sample = controlnet(
1410
+ x=model_input,
1411
+ timesteps=torch.Tensor((t,)).to(input_noise.device),
1412
+ controlnet_cond=cn_cond,
1413
+ context=None,
1414
+ )
1415
+ # 3. predict noise model_output
1410
1416
  model_output = diffuse(
1411
1417
  model_input,
1412
1418
  timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1415,6 +1421,12 @@ class ControlNetDiffusionInferer(DiffusionInferer):
1415
1421
  mid_block_additional_residual=mid_block_res_sample,
1416
1422
  )
1417
1423
  else:
1424
+ down_block_res_samples, mid_block_res_sample = controlnet(
1425
+ x=image,
1426
+ timesteps=torch.Tensor((t,)).to(input_noise.device),
1427
+ controlnet_cond=cn_cond,
1428
+ context=conditioning,
1429
+ )
1418
1430
  model_output = diffuse(
1419
1431
  image,
1420
1432
  timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1485,9 +1497,6 @@ class ControlNetDiffusionInferer(DiffusionInferer):
1485
1497
  for t in progress_bar:
1486
1498
  timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
1487
1499
  noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
1488
- down_block_res_samples, mid_block_res_sample = controlnet(
1489
- x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
1490
- )
1491
1500
 
1492
1501
  diffuse = diffusion_model
1493
1502
  if isinstance(diffusion_model, SPADEDiffusionModelUNet):
@@ -1495,6 +1504,9 @@ class ControlNetDiffusionInferer(DiffusionInferer):
1495
1504
 
1496
1505
  if mode == "concat" and conditioning is not None:
1497
1506
  noisy_image = torch.cat([noisy_image, conditioning], dim=1)
1507
+ down_block_res_samples, mid_block_res_sample = controlnet(
1508
+ x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None
1509
+ )
1498
1510
  model_output = diffuse(
1499
1511
  noisy_image,
1500
1512
  timesteps=timesteps,
@@ -1503,6 +1515,12 @@ class ControlNetDiffusionInferer(DiffusionInferer):
1503
1515
  mid_block_additional_residual=mid_block_res_sample,
1504
1516
  )
1505
1517
  else:
1518
+ down_block_res_samples, mid_block_res_sample = controlnet(
1519
+ x=noisy_image,
1520
+ timesteps=torch.Tensor((t,)).to(inputs.device),
1521
+ controlnet_cond=cn_cond,
1522
+ context=conditioning,
1523
+ )
1506
1524
  model_output = diffuse(
1507
1525
  x=noisy_image,
1508
1526
  timesteps=timesteps,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: monai-weekly
3
- Version: 1.5.dev2508
3
+ Version: 1.5.dev2509
4
4
  Summary: AI Toolkit for Healthcare Imaging
5
5
  Home-page: https://monai.io/
6
6
  Author: MONAI Consortium
@@ -1,5 +1,5 @@
1
- monai/__init__.py,sha256=jHqt9Fx6mJlpL9TD8eihfJTg6IGs40j8bCpjE3PFrVI,4095
2
- monai/_version.py,sha256=sQZ38u2mKWN9p59gP2DeDhflJxmQX4ckQZtIE_MCnbg,503
1
+ monai/__init__.py,sha256=2QSN66gMNzIDVAeBWVrsS3xgXmpc90Ksxr0j3D3KLiQ,4095
2
+ monai/_version.py,sha256=3pISgTcfhG3j_LA8zhH9EcyDi6PgzKxbNALoD_5HCps,503
3
3
  monai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  monai/_extensions/__init__.py,sha256=NEBPreRhQ8H9gVvgrLr_y52_TmqB96u_u4VQmeNT93I,642
5
5
  monai/_extensions/loader.py,sha256=7SiKw36q-nOzH8CRbBurFrz7GM40GCu7rc93Tm8XpnI,3643
@@ -114,7 +114,7 @@ monai/bundle/config_item.py,sha256=rMjXSGkjJZdi04BwSHwCcIwzIb_TflmC3xDhC3SVJRs,1
114
114
  monai/bundle/config_parser.py,sha256=cGyEn-cqNk0rEEZ1Qiv6UydmIDvtWZcMVljyfVm5i50,23025
115
115
  monai/bundle/properties.py,sha256=iN3K4FVmN9ny1Hw9p5j7_ULcCdSD8PmrR7qXxbNz49k,11582
116
116
  monai/bundle/reference_resolver.py,sha256=GXCMK4iogxxE6VocsmAbUrcXosmC5arnjeG9zYhHKpg,16748
117
- monai/bundle/scripts.py,sha256=D0GnyZF0FCQyHZtqjoX9jen3IiAvnUeM1mtSWa2fu4E,89935
117
+ monai/bundle/scripts.py,sha256=p7wlT0BplTIdW4DbxRPotf_tLsgddvtklW1kcAEPBZQ,91016
118
118
  monai/bundle/utils.py,sha256=t-22uFvLn7Yy-dr1v1U33peNOxgAmU4TJiGAbsBrUKs,10108
119
119
  monai/bundle/workflows.py,sha256=CuhmFq1AWsN3ATiYJCSakPOxrOdGutl6vkpo9sxe8gU,34369
120
120
  monai/config/__init__.py,sha256=CN28CfTdsp301gv8YXfVvkbztCfbAqrLKrJi_C8oP9s,1048
@@ -195,7 +195,7 @@ monai/handlers/trt_handler.py,sha256=uWFdgC8QKRkcNwWfKIbQMdK6-MX_1ON0mKabeIn1ltI
195
195
  monai/handlers/utils.py,sha256=Ib1u-PLrtIkiLqTfREnrCWpN4af1btdNzkyMZuuuYyU,10239
196
196
  monai/handlers/validation_handler.py,sha256=NZO21c6zzXbmAgJZHkkdoZQSQIHwuxh94QD3PLUldGU,3674
197
197
  monai/inferers/__init__.py,sha256=K74t_RCeUPdEZvHzIPzVAwZ9DtmouLqhb3qDEmFBWs4,1107
198
- monai/inferers/inferer.py,sha256=UNZpsb97qpl9c7ylNV32_jk52nsX77BqYySOl0XxDQw,92802
198
+ monai/inferers/inferer.py,sha256=_VPnBIErwYzbrJIA9eMMalSso1pSsc_8cONVUUvPJOw,93549
199
199
  monai/inferers/merger.py,sha256=dZm-FVyXPlFb59q4DG52mbtPm8Iy4cNFWv3un0Z8k0M,16262
200
200
  monai/inferers/splitter.py,sha256=_hTnFdvDNRckkA7ZGQehVsNZw83oXoGFWyk5VXNqgJg,21149
201
201
  monai/inferers/utils.py,sha256=dvZBCAjaPa8xXcJuXRzNQ-fBzteauzkKbxE5YZdGBGY,20374
@@ -570,7 +570,7 @@ tests/handlers/test_trt_compile.py,sha256=p8Gr2CJmBo6gG8w7bGlAO--nDHtQvy9Ld3jtua
570
570
  tests/handlers/test_write_metrics_reports.py,sha256=oKGYR1plj1hSAu-ntbxkw_TD4O5hKPwVH_BS3MdHIbs,3027
571
571
  tests/inferers/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
572
572
  tests/inferers/test_avg_merger.py,sha256=lMR2PcNGFD6sfF6CjJTkahrAiMA5m5LUs5A11P6h8n0,5952
573
- tests/inferers/test_controlnet_inferers.py,sha256=SGluRyDlgwUJ8nm3BEWgXN3eb81fUGOaRXbLglC_ejc,49676
573
+ tests/inferers/test_controlnet_inferers.py,sha256=sWs5vkZHa-D0V3tWJ6149Z-RNq0for_XngDYxZRl_Ao,50285
574
574
  tests/inferers/test_diffusion_inferer.py,sha256=1O2V_bEmifOZ4RvpbZgYUCooiJ97T73avaBuMJPpBs0,9992
575
575
  tests/inferers/test_latent_diffusion_inferer.py,sha256=atJjmfVznUq8z9EjohFIMyA0Q1XT1Ly0Zepf_1xPz5I,32274
576
576
  tests/inferers/test_patch_inferer.py,sha256=LkYXWVn71vWinP-OJsIvq3FPH3jr36T7nKRIH5PzaqY,9878
@@ -1178,8 +1178,8 @@ tests/visualize/test_vis_gradcam.py,sha256=WpA-pvTB75eZs7JoIc5qyvOV9PwgkzWI8-Vow
1178
1178
  tests/visualize/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
1179
1179
  tests/visualize/utils/test_blend_images.py,sha256=RVs2p_8RWQDfhLHDNNtZaMig27v8o0km7XxNa-zWjKE,2274
1180
1180
  tests/visualize/utils/test_matshow3d.py,sha256=wXYj77L5Jvnp0f6DvL1rsi_-YlCxS0HJ9hiPmrbpuP8,5021
1181
- monai_weekly-1.5.dev2508.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
1182
- monai_weekly-1.5.dev2508.dist-info/METADATA,sha256=y-KfkVBP9_LhTnQo37SKpjDYJYsdujQuCCQiZpKdSv8,11909
1183
- monai_weekly-1.5.dev2508.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
1184
- monai_weekly-1.5.dev2508.dist-info/top_level.txt,sha256=hn2Y6P9xBf2R8faMeVMHhPMvrdDKxMsIOwMDYI0yTjs,12
1185
- monai_weekly-1.5.dev2508.dist-info/RECORD,,
1181
+ monai_weekly-1.5.dev2509.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
1182
+ monai_weekly-1.5.dev2509.dist-info/METADATA,sha256=h7L3w9XhzSfoxC5yRoqgKS_NeECPEORKyEX4E1WS6Vc,11909
1183
+ monai_weekly-1.5.dev2509.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
1184
+ monai_weekly-1.5.dev2509.dist-info/top_level.txt,sha256=hn2Y6P9xBf2R8faMeVMHhPMvrdDKxMsIOwMDYI0yTjs,12
1185
+ monai_weekly-1.5.dev2509.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (75.8.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -550,6 +550,8 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
550
550
  def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
551
551
  model_params["with_conditioning"] = True
552
552
  model_params["cross_attention_dim"] = 3
553
+ controlnet_params["with_conditioning"] = True
554
+ controlnet_params["cross_attention_dim"] = 3
553
555
  model = DiffusionModelUNet(**model_params)
554
556
  controlnet = ControlNet(**controlnet_params)
555
557
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -619,8 +621,11 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
619
621
  model_params = model_params.copy()
620
622
  n_concat_channel = 2
621
623
  model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
624
+ controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
622
625
  model_params["cross_attention_dim"] = None
626
+ controlnet_params["cross_attention_dim"] = None
623
627
  model_params["with_conditioning"] = False
628
+ controlnet_params["with_conditioning"] = False
624
629
  model = DiffusionModelUNet(**model_params)
625
630
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
626
631
  model.to(device)
@@ -1023,8 +1028,10 @@ class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):
1023
1028
  if ae_model_type == "SPADEAutoencoderKL":
1024
1029
  stage_1 = SPADEAutoencoderKL(**autoencoder_params)
1025
1030
  stage_2_params = stage_2_params.copy()
1031
+ controlnet_params = controlnet_params.copy()
1026
1032
  n_concat_channel = 3
1027
1033
  stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
1034
+ controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
1028
1035
  if dm_model_type == "SPADEDiffusionModelUNet":
1029
1036
  stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
1030
1037
  else:
@@ -1106,8 +1113,10 @@ class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):
1106
1113
  if ae_model_type == "SPADEAutoencoderKL":
1107
1114
  stage_1 = SPADEAutoencoderKL(**autoencoder_params)
1108
1115
  stage_2_params = stage_2_params.copy()
1116
+ controlnet_params = controlnet_params.copy()
1109
1117
  n_concat_channel = 3
1110
1118
  stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
1119
+ controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
1111
1120
  if dm_model_type == "SPADEDiffusionModelUNet":
1112
1121
  stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
1113
1122
  else: