diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
diffusers/pipelines/consistency_models/pipeline_consistency_models.py
@@ -0,0 +1,337 @@
+ from typing import Callable, List, Optional, Union
+
+ import torch
+
+ from ...models import UNet2DModel
+ from ...schedulers import CMStochasticIterativeScheduler
+ from ...utils import (
+     is_accelerate_available,
+     is_accelerate_version,
+     logging,
+     randn_tensor,
+     replace_example_docstring,
+ )
+ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ EXAMPLE_DOC_STRING = """
+     Examples:
+         ```py
+         >>> import torch
+
+         >>> from diffusers import ConsistencyModelPipeline
+
+         >>> device = "cuda"
+         >>> # Load the cd_imagenet64_l2 checkpoint.
+         >>> model_id_or_path = "openai/diffusers-cd_imagenet64_l2"
+         >>> pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
+         >>> pipe.to(device)
+
+         >>> # Onestep sampling
+         >>> image = pipe(num_inference_steps=1).images[0]
+         >>> image.save("cd_imagenet64_l2_onestep_sample.png")
+
+         >>> # Onestep sampling, class-conditional image generation
+         >>> # ImageNet-64 class label 145 corresponds to king penguins
+         >>> image = pipe(num_inference_steps=1, class_labels=145).images[0]
+         >>> image.save("cd_imagenet64_l2_onestep_sample_penguin.png")
+
+         >>> # Multistep sampling, class-conditional image generation
+         >>> # Timesteps can be explicitly specified; the particular timesteps below are from the original GitHub repo:
+         >>> # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77
+         >>> image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0]
+         >>> image.save("cd_imagenet64_l2_multistep_sample_penguin.png")
+         ```
+ """
+
+
+ class ConsistencyModelPipeline(DiffusionPipeline):
+     r"""
+     Pipeline for consistency models for unconditional or class-conditional image generation, as introduced in [1].
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+     [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models"
+     https://arxiv.org/pdf/2303.01469
+
+     Args:
+         unet ([`UNet2DModel`]):
+             Unconditional or class-conditional U-Net architecture to denoise image latents.
+         scheduler ([`SchedulerMixin`]):
+             A scheduler to be used in combination with `unet` to denoise the image latents. Currently only compatible
+             with [`CMStochasticIterativeScheduler`].
+     """
+
+     def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None:
+         super().__init__()
+
+         self.register_modules(
+             unet=unet,
+             scheduler=scheduler,
+         )
+
+         self.safety_checker = None
+
+     def enable_sequential_cpu_offload(self, gpu_id=0):
+         r"""
+         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+         text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+         `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+         Note that offloading happens on a submodule basis. Memory savings are higher than with
+         `enable_model_cpu_offload`, but performance is lower.
+         """
+         if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
+             from accelerate import cpu_offload
+         else:
+             raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+
+         device = torch.device(f"cuda:{gpu_id}")
+
+         if self.device.type != "cpu":
+             self.to("cpu", silence_dtype_warnings=True)
+             torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+         for cpu_offloaded_model in [self.unet]:
+             cpu_offload(cpu_offloaded_model, device)
+
+         if self.safety_checker is not None:
+             cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+
+     def enable_model_cpu_offload(self, gpu_id=0):
+         r"""
+         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+         method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+         `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+         """
+         if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+             from accelerate import cpu_offload_with_hook
+         else:
+             raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+         device = torch.device(f"cuda:{gpu_id}")
+
+         if self.device.type != "cpu":
+             self.to("cpu", silence_dtype_warnings=True)
+             torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+         hook = None
+         for cpu_offloaded_model in [self.unet]:
+             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+         if self.safety_checker is not None:
+             _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+         # We'll offload the last model manually.
+         self.final_offload_hook = hook
+
+     @property
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+     def _execution_device(self):
+         r"""
+         Returns the device on which the pipeline's models will be executed. After calling
+         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+         hooks.
+         """
+         if not hasattr(self.unet, "_hf_hook"):
+             return self.device
+         for module in self.unet.modules():
+             if (
+                 hasattr(module, "_hf_hook")
+                 and hasattr(module._hf_hook, "execution_device")
+                 and module._hf_hook.execution_device is not None
+             ):
+                 return torch.device(module._hf_hook.execution_device)
+         return self.device
+
+     def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
+         shape = (batch_size, num_channels, height, width)
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         if latents is None:
+             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+         else:
+             latents = latents.to(device=device, dtype=dtype)
+
+         # scale the initial noise by the standard deviation required by the scheduler
+         latents = latents * self.scheduler.init_noise_sigma
+         return latents
+
+     # Follows diffusers.VaeImageProcessor.postprocess
+     def postprocess_image(self, sample: torch.FloatTensor, output_type: str = "pil"):
+         if output_type not in ["pt", "np", "pil"]:
+             raise ValueError(
+                 f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
+             )
+
+         # Equivalent to diffusers.VaeImageProcessor.denormalize
+         sample = (sample / 2 + 0.5).clamp(0, 1)
+         if output_type == "pt":
+             return sample
+
+         # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
+         sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+         if output_type == "np":
+             return sample
+
+         # output_type must be 'pil'
+         sample = self.numpy_to_pil(sample)
+         return sample
+
+     def prepare_class_labels(self, batch_size, device, class_labels=None):
+         if self.unet.config.num_class_embeds is not None:
+             if isinstance(class_labels, list):
+                 class_labels = torch.tensor(class_labels, dtype=torch.int)
+             elif isinstance(class_labels, int):
+                 assert batch_size == 1, "Batch size must be 1 if classes is an int"
+                 class_labels = torch.tensor([class_labels], dtype=torch.int)
+             elif class_labels is None:
+                 # Randomly generate batch_size class labels
+                 # TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils
+                 class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,))
+             class_labels = class_labels.to(device)
+         else:
+             class_labels = None
+         return class_labels
+
+     def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps):
+         if num_inference_steps is None and timesteps is None:
+             raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")
+
+         if num_inference_steps is not None and timesteps is not None:
+             logger.warning(
+                 f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;"
+                 " `timesteps` will be used over `num_inference_steps`."
+             )
+
+         if latents is not None:
+             expected_shape = (batch_size, 3, img_size, img_size)
+             if latents.shape != expected_shape:
+                 raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.")
+
+         if (callback_steps is None) or (
+             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+         ):
+             raise ValueError(
+                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                 f" {type(callback_steps)}."
+             )
+
+     @torch.no_grad()
+     @replace_example_docstring(EXAMPLE_DOC_STRING)
+     def __call__(
+         self,
+         batch_size: int = 1,
+         class_labels: Optional[Union[torch.Tensor, List[int], int]] = None,
+         num_inference_steps: int = 1,
+         timesteps: List[int] = None,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         callback_steps: int = 1,
+     ):
+         r"""
+         Args:
+             batch_size (`int`, *optional*, defaults to 1):
+                 The number of images to generate.
+             class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*):
+                 Optional class labels for conditioning class-conditional consistency models. Will not be used if the
+                 model is not class-conditional.
+             num_inference_steps (`int`, *optional*, defaults to 1):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             timesteps (`List[int]`, *optional*):
+                 Custom timesteps to use for the denoising process. If not defined, equally spaced `num_inference_steps`
+                 timesteps are used. Must be in descending order.
+             generator (`torch.Generator`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor will be generated by sampling using the supplied random `generator`.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between
+                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+             callback (`Callable`, *optional*):
+                 A function that will be called every `callback_steps` steps during inference. The function will be
+                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+             callback_steps (`int`, *optional*, defaults to 1):
+                 The frequency at which the `callback` function will be called. If not specified, the callback will be
+                 called at every step.
+
+         Examples:
+
+         Returns:
+             [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+             True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+         """
+         # 0. Prepare call parameters
+         img_size = self.unet.config.sample_size
+         device = self._execution_device
+
+         # 1. Check inputs
+         self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps)
+
+         # 2. Prepare image latents
+         # Sample image latents x_0 ~ N(0, sigma_0^2 * I)
+         sample = self.prepare_latents(
+             batch_size=batch_size,
+             num_channels=self.unet.config.in_channels,
+             height=img_size,
+             width=img_size,
+             dtype=self.unet.dtype,
+             device=device,
+             generator=generator,
+             latents=latents,
+         )
+
+         # 3. Handle class_labels for class-conditional models
+         class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels)
+
+         # 4. Prepare timesteps
+         if timesteps is not None:
+             self.scheduler.set_timesteps(timesteps=timesteps, device=device)
+             timesteps = self.scheduler.timesteps
+             num_inference_steps = len(timesteps)
+         else:
+             self.scheduler.set_timesteps(num_inference_steps)
+             timesteps = self.scheduler.timesteps
+
+         # 5. Denoising loop
+         # Multistep sampling: implements Algorithm 1 in the paper
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 scaled_sample = self.scheduler.scale_model_input(sample, t)
+                 model_output = self.unet(scaled_sample, t, class_labels=class_labels, return_dict=False)[0]
+
+                 sample = self.scheduler.step(model_output, t, sample, generator=generator)[0]
+
+                 # call the callback, if provided
+                 progress_bar.update()
+                 if callback is not None and i % callback_steps == 0:
+                     callback(i, t, sample)
+
+         # 6. Post-process image sample
+         image = self.postprocess_image(sample, output_type=output_type)
+
+         # Offload last model to CPU
+         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+             self.final_offload_hook.offload()
+
+         if not return_dict:
+             return (image,)
+
+         return ImagePipelineOutput(images=image)
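
Note for the new pipeline above: `check_inputs` requires user-supplied `latents` to match `(batch_size, 3, sample_size, sample_size)`, and `prepare_latents` casts them to the UNet's dtype/device and scales them by `scheduler.init_noise_sigma`. A minimal sketch of passing fixed latents (the checkpoint name comes from the example docstring above; everything else here is illustrative, not part of the diff):

```py
import torch
from diffusers import ConsistencyModelPipeline

pipe = ConsistencyModelPipeline.from_pretrained(
    "openai/diffusers-cd_imagenet64_l2", torch_dtype=torch.float16
)
pipe.to("cuda")

# Shape expected by check_inputs: (batch_size, 3, sample_size, sample_size).
sample_size = pipe.unet.config.sample_size
latents = torch.randn(1, 3, sample_size, sample_size, generator=torch.manual_seed(0))

# prepare_latents moves/casts these latents and multiplies them by scheduler.init_noise_sigma.
image = pipe(batch_size=1, num_inference_steps=1, latents=latents).images[0]
image.save("cd_imagenet64_l2_fixed_latents.png")
```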
diffusers/pipelines/controlnet/multicontrolnet.py
@@ -1,10 +1,15 @@
- from typing import Any, Dict, List, Optional, Tuple, Union
+ import os
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

  import torch
  from torch import nn

  from ...models.controlnet import ControlNetModel, ControlNetOutput
  from ...models.modeling_utils import ModelMixin
+ from ...utils import logging
+
+
+ logger = logging.get_logger(__name__)


  class MultiControlNetModel(ModelMixin):
@@ -64,3 +69,117 @@ class MultiControlNetModel(ModelMixin):
                  mid_block_res_sample += mid_sample

          return down_block_res_samples, mid_block_res_sample
+
+     def save_pretrained(
+         self,
+         save_directory: Union[str, os.PathLike],
+         is_main_process: bool = True,
+         save_function: Callable = None,
+         safe_serialization: bool = False,
+         variant: Optional[str] = None,
+     ):
+         """
+         Save a model and its configuration file to a directory, so that it can be re-loaded using the
+         [`~pipelines.controlnet.MultiControlNetModel.from_pretrained`] class method.
+
+         Arguments:
+             save_directory (`str` or `os.PathLike`):
+                 Directory to which to save. Will be created if it doesn't exist.
+             is_main_process (`bool`, *optional*, defaults to `True`):
+                 Whether the process calling this is the main process or not. Useful during distributed training (e.g.,
+                 on TPUs) when this function needs to be called on all processes. In that case, set
+                 `is_main_process=True` only on the main process to avoid race conditions.
+             save_function (`Callable`):
+                 The function to use to save the state dictionary. Useful during distributed training (e.g., on TPUs)
+                 when one needs to replace `torch.save` with another method. Can be configured with the environment
+                 variable `DIFFUSERS_SAVE_MODE`.
+             safe_serialization (`bool`, *optional*, defaults to `False`):
+                 Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+             variant (`str`, *optional*):
+                 If specified, weights are saved in the format pytorch_model.<variant>.bin.
+         """
+         idx = 0
+         model_path_to_save = save_directory
+         for controlnet in self.nets:
+             controlnet.save_pretrained(
+                 model_path_to_save,
+                 is_main_process=is_main_process,
+                 save_function=save_function,
+                 safe_serialization=safe_serialization,
+                 variant=variant,
+             )
+
+             idx += 1
+             model_path_to_save = model_path_to_save + f"_{idx}"
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
+         r"""
+         Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
+
+         The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+         the model, you should first set it back in training mode with `model.train()`.
+
+         The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+         pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+         task.
+
+         The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+         weights are discarded.
+
+         Parameters:
+             pretrained_model_path (`os.PathLike`):
+                 A path to a *directory* containing model weights saved using
+                 [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
+                 `./my_model_directory/controlnet`.
+             torch_dtype (`str` or `torch.dtype`, *optional*):
+                 Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed, the
+                 dtype will be automatically derived from the model's weights.
+             output_loading_info (`bool`, *optional*, defaults to `False`):
+                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+             device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                 A map that specifies where each submodule should go. It doesn't need to be refined down to each
+                 parameter/buffer name; once a given module name is included, every submodule of it will be sent to the
+                 same device.
+
+                 To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                 more information about each option see [designing a device
+                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+             max_memory (`Dict`, *optional*):
+                 A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory
+                 available for each GPU and the available CPU RAM if unset.
+             low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                 Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                 also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                 model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+                 setting this argument to `True` will raise an error.
+             variant (`str`, *optional*):
+                 If specified, load weights from the `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant`
+                 is ignored when using `from_flax`.
+             use_safetensors (`bool`, *optional*, defaults to `None`):
+                 If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
+                 `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
+                 `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+         """
+         idx = 0
+         controlnets = []
+
+         # load controlnet and append to list until no controlnet directory exists anymore
+         # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_pretrained`
+         # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
+         model_path_to_load = pretrained_model_path
+         while os.path.isdir(model_path_to_load):
+             controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
+             controlnets.append(controlnet)
+
+             idx += 1
+             model_path_to_load = pretrained_model_path + f"_{idx}"
+
+         logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
+
+         if len(controlnets) == 0:
+             raise ValueError(
+                 f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
+             )
+
+         return cls(controlnets)
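
A minimal sketch of the save/load round-trip these new `MultiControlNetModel` methods implement (the two checkpoint names are illustrative placeholders; the directory layout follows the comments in `from_pretrained` above):

```py
from diffusers import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel

# Illustrative checkpoints; any ControlNetModel checkpoints follow the same layout.
canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
depth = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth")
multi = MultiControlNetModel([canny, depth])

# save_pretrained writes the first net to `controlnet` and the second to `controlnet_1`.
multi.save_pretrained("./my_model_directory/controlnet")

# from_pretrained walks `controlnet`, `controlnet_1`, ... until a directory is missing.
multi = MultiControlNetModel.from_pretrained("./my_model_directory/controlnet")
```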
diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -14,7 +14,6 @@


  import inspect
- import os
  import warnings
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -492,6 +491,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
          prompt_embeds=None,
          negative_prompt_embeds=None,
          controlnet_conditioning_scale=1.0,
+         control_guidance_start=0.0,
+         control_guidance_end=1.0,
      ):
          if (callback_steps is None) or (
              callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
@@ -560,7 +561,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                  raise ValueError("A single batch of multiple conditionings are supported at the moment.")
              elif len(image) != len(self.controlnet.nets):
                  raise ValueError(
-                     "For multiple controlnets: `image` must have the same length as the number of controlnets."
+                     f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
                  )

              for image_ in image:
@@ -594,6 +595,27 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
          else:
              assert False

+         if len(control_guidance_start) != len(control_guidance_end):
+             raise ValueError(
+                 f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+             )
+
+         if isinstance(self.controlnet, MultiControlNetModel):
+             if len(control_guidance_start) != len(self.controlnet.nets):
+                 raise ValueError(
+                     f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+                 )
+
+         for start, end in zip(control_guidance_start, control_guidance_end):
+             if start >= end:
+                 raise ValueError(
+                     f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+                 )
+             if start < 0.0:
+                 raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+             if end > 1.0:
+                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
      def check_image(self, image, prompt, prompt_embeds):
          image_is_pil = isinstance(image, PIL.Image.Image)
          image_is_tensor = isinstance(image, torch.Tensor)
@@ -679,18 +701,6 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
          latents = latents * self.scheduler.init_noise_sigma
          return latents

-     # override DiffusionPipeline
-     def save_pretrained(
-         self,
-         save_directory: Union[str, os.PathLike],
-         safe_serialization: bool = False,
-         variant: Optional[str] = None,
-     ):
-         if isinstance(self.controlnet, ControlNetModel):
-             super().save_pretrained(save_directory, safe_serialization, variant)
-         else:
-             raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")
-
      @torch.no_grad()
      @replace_example_docstring(EXAMPLE_DOC_STRING)
      def __call__(
@@ -722,6 +732,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
          cross_attention_kwargs: Optional[Dict[str, Any]] = None,
          controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
          guess_mode: bool = False,
+         control_guidance_start: Union[float, List[float]] = 0.0,
+         control_guidance_end: Union[float, List[float]] = 1.0,
      ):
          r"""
          Function invoked when calling the pipeline for generation.
@@ -797,6 +809,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
              guess_mode (`bool`, *optional*, defaults to `False`):
                  In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
                  you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
+             control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                 The percentage of total steps at which the controlnet starts applying.
+             control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                 The percentage of total steps at which the controlnet stops applying.

          Examples:

@@ -807,6 +823,18 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
              list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
              (nsfw) content, according to the `safety_checker`.
          """
+         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+         # align format for control guidance
+         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+         elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+             mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+             control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
+                 control_guidance_end
+             ]

          # 1. Check inputs. Raise error if not correct
          self.check_inputs(
@@ -817,6 +845,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
              prompt_embeds,
              negative_prompt_embeds,
              controlnet_conditioning_scale,
+             control_guidance_start,
+             control_guidance_end,
          )

          # 2. Define call parameters
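
For reference, the scalar-to-list alignment added two hunks above broadcasts a single `control_guidance_start`/`control_guidance_end` value to one entry per ControlNet before `check_inputs` validates them. A standalone sketch with illustrative values:

```py
# Illustrative: two ControlNets, scalar guidance window.
control_guidance_start, control_guidance_end = 0.0, 0.8
mult = 2  # e.g. len(controlnet.nets) for a MultiControlNetModel

if not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
    control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [control_guidance_end]

print(control_guidance_start, control_guidance_end)  # [0.0, 0.0] [0.8, 0.8]
```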
@@ -833,8 +863,6 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
          # corresponds to doing no classifier free guidance.
          do_classifier_free_guidance = guidance_scale > 1.0

-         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
-
          if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
              controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

@@ -917,6 +945,15 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
          # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
          extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

+         # 7.1 Create tensor stating which controlnets to keep
+         controlnet_keep = []
+         for i in range(len(timesteps)):
+             keeps = [
+                 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                 for s, e in zip(control_guidance_start, control_guidance_end)
+             ]
+             controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
+
          # 8. Denoising loop
          num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
          with self.progress_bar(total=num_inference_steps) as progress_bar:
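
The `controlnet_keep` list computed in section 7.1 zeroes the conditioning scale for steps that fall outside the `[control_guidance_start, control_guidance_end]` window. The same computation in isolation, with illustrative values (ten steps, ControlNet active only for the first half of the schedule):

```py
# Illustrative values: 10 denoising steps, one ControlNet active for the first half only.
num_steps = 10
control_guidance_start, control_guidance_end = [0.0], [0.5]

controlnet_keep = []
for i in range(num_steps):
    keeps = [
        1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)

print(controlnet_keep)  # [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```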
@@ -935,12 +972,17 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                      control_model_input = latent_model_input
                      controlnet_prompt_embeds = prompt_embeds

+                 if isinstance(controlnet_keep[i], list):
+                     cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+                 else:
+                     cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
+
                  down_block_res_samples, mid_block_res_sample = self.controlnet(
                      control_model_input,
                      t,
                      encoder_hidden_states=controlnet_prompt_embeds,
                      controlnet_cond=image,
-                     conditioning_scale=controlnet_conditioning_scale,
+                     conditioning_scale=cond_scale,
                      guess_mode=guess_mode,
                      return_dict=False,
                  )
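
Putting the new arguments together, a minimal usage sketch of the partial guidance window (checkpoint names are illustrative; the blank conditioning image is a placeholder for a real canny edge map, not part of the diff):

```py
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Illustrative checkpoints; the new arguments behave the same for any ControlNet.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.to("cuda")

# Placeholder conditioning image; in practice this would be a canny edge map of the target layout.
canny_image = Image.new("RGB", (512, 512))

# Apply the ControlNet only between 20% and 80% of the denoising schedule.
image = pipe(
    "a bird",
    image=canny_image,
    control_guidance_start=0.2,
    control_guidance_end=0.8,
).images[0]
image.save("controlnet_partial_guidance.png")
```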