diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl
- diffusers/__init__.py +26 -1
- diffusers/configuration_utils.py +34 -29
- diffusers/dependency_versions_table.py +4 -0
- diffusers/image_processor.py +125 -12
- diffusers/loaders.py +169 -203
- diffusers/models/attention.py +24 -1
- diffusers/models/attention_flax.py +10 -5
- diffusers/models/attention_processor.py +3 -0
- diffusers/models/autoencoder_kl.py +114 -33
- diffusers/models/controlnet.py +131 -14
- diffusers/models/controlnet_flax.py +37 -26
- diffusers/models/cross_attention.py +17 -17
- diffusers/models/embeddings.py +67 -0
- diffusers/models/modeling_flax_utils.py +64 -56
- diffusers/models/modeling_utils.py +193 -104
- diffusers/models/prior_transformer.py +207 -37
- diffusers/models/resnet.py +26 -26
- diffusers/models/transformer_2d.py +36 -41
- diffusers/models/transformer_temporal.py +24 -21
- diffusers/models/unet_1d.py +31 -25
- diffusers/models/unet_2d.py +43 -30
- diffusers/models/unet_2d_blocks.py +210 -89
- diffusers/models/unet_2d_blocks_flax.py +12 -12
- diffusers/models/unet_2d_condition.py +172 -64
- diffusers/models/unet_2d_condition_flax.py +38 -24
- diffusers/models/unet_3d_blocks.py +34 -31
- diffusers/models/unet_3d_condition.py +101 -34
- diffusers/models/vae.py +5 -5
- diffusers/models/vae_flax.py +37 -34
- diffusers/models/vq_model.py +23 -14
- diffusers/pipelines/__init__.py +24 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
- diffusers/pipelines/consistency_models/__init__.py +1 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
- diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/kandinsky/__init__.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
- diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_utils.py +124 -146
- diffusers/pipelines/shap_e/__init__.py +27 -0
- diffusers/pipelines/shap_e/camera.py +147 -0
- diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
- diffusers/pipelines/shap_e/renderer.py +709 -0
- diffusers/pipelines/stable_diffusion/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
- diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
- diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
- diffusers/schedulers/__init__.py +3 -0
- diffusers/schedulers/scheduling_consistency_models.py +380 -0
- diffusers/schedulers/scheduling_ddim.py +28 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
- diffusers/schedulers/scheduling_ddpm.py +53 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
- diffusers/schedulers/scheduling_deis_multistep.py +66 -11
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
- diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
- diffusers/schedulers/scheduling_euler_discrete.py +58 -8
- diffusers/schedulers/scheduling_heun_discrete.py +89 -14
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
- diffusers/schedulers/scheduling_lms_discrete.py +57 -8
- diffusers/schedulers/scheduling_pndm.py +46 -10
- diffusers/schedulers/scheduling_repaint.py +19 -4
- diffusers/schedulers/scheduling_sde_ve.py +5 -1
- diffusers/schedulers/scheduling_unclip.py +43 -4
- diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
- diffusers/utils/hub_utils.py +1 -1
- diffusers/utils/import_utils.py +20 -3
- diffusers/utils/logging.py +15 -18
- diffusers/utils/outputs.py +3 -3
- diffusers/utils/testing_utils.py +15 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
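The 0.18 line is dominated by new pipelines visible in the list above: consistency models, Kandinsky 2.2, Shap-E, Stable Diffusion XL, LDM3D, paradigms (parallel sampling) and text-to-video img2img, plus the matching schedulers. As a quick orientation, a minimal sketch of loading one of the new pipelines; the checkpoint id and fp16 settings are illustrative assumptions, not taken from this diff:

import torch
from diffusers import StableDiffusionXLPipeline  # added by pipeline_stable_diffusion_xl.py above

# Checkpoint id is an assumed example; any SDXL-format checkpoint should work.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

image = pipe(prompt="a photo of an astronaut riding a horse").images[0]
image.save("astronaut.png")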
diffusers/models/unet_2d_blocks_flax.py
@@ -33,7 +33,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of attention blocks layers
-        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
             Number of attention heads of each spatial transformer block
         add_downsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add downsampling layer before each final output
@@ -46,7 +46,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
     out_channels: int
     dropout: float = 0.0
     num_layers: int = 1
-    attn_num_head_channels: int = 1
+    num_attention_heads: int = 1
     add_downsample: bool = True
     use_linear_projection: bool = False
     only_cross_attention: bool = False
@@ -70,8 +70,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
 
             attn_block = FlaxTransformer2DModel(
                 in_channels=self.out_channels,
-                n_heads=self.attn_num_head_channels,
-                d_head=self.out_channels // self.attn_num_head_channels,
+                n_heads=self.num_attention_heads,
+                d_head=self.out_channels // self.num_attention_heads,
                 depth=1,
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
@@ -172,7 +172,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of attention blocks layers
-        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
             Number of attention heads of each spatial transformer block
         add_upsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add upsampling layer before each final output
@@ -186,7 +186,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
     prev_output_channel: int
     dropout: float = 0.0
     num_layers: int = 1
-    attn_num_head_channels: int = 1
+    num_attention_heads: int = 1
     add_upsample: bool = True
     use_linear_projection: bool = False
     only_cross_attention: bool = False
@@ -211,8 +211,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
 
             attn_block = FlaxTransformer2DModel(
                 in_channels=self.out_channels,
-                n_heads=self.attn_num_head_channels,
-                d_head=self.out_channels // self.attn_num_head_channels,
+                n_heads=self.num_attention_heads,
+                d_head=self.out_channels // self.num_attention_heads,
                 depth=1,
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
@@ -317,7 +317,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of attention blocks layers
-        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
             Number of attention heads of each spatial transformer block
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
@@ -327,7 +327,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
     in_channels: int
     dropout: float = 0.0
     num_layers: int = 1
-    attn_num_head_channels: int = 1
+    num_attention_heads: int = 1
     use_linear_projection: bool = False
     use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
@@ -348,8 +348,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
         for _ in range(self.num_layers):
            attn_block = FlaxTransformer2DModel(
                 in_channels=self.in_channels,
-                n_heads=self.attn_num_head_channels,
-                d_head=self.in_channels // self.attn_num_head_channels,
+                n_heads=self.num_attention_heads,
+                d_head=self.in_channels // self.num_attention_heads,
                 depth=1,
                 use_linear_projection=self.use_linear_projection,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
diffusers/models/unet_2d_condition.py
@@ -25,6 +25,9 @@ from .activations import get_activation
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import (
     GaussianFourierProjection,
+    ImageHintTimeEmbedding,
+    ImageProjection,
+    ImageTimeEmbedding,
     TextImageProjection,
     TextImageTimeEmbedding,
     TextTimeEmbedding,
@@ -50,27 +53,29 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 @dataclass
 class UNet2DConditionOutput(BaseOutput):
     """
+    The output of [`UNet2DConditionModel`].
+
     Args:
         sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
     """
 
-    sample: torch.FloatTensor
+    sample: torch.FloatTensor = None
 
 
 class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     r"""
-    UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a
-    timestep and returns sample shaped output.
+    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+    shaped output.
 
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the
-    library implements for all the models (such as downloading or saving, etc.)
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).
 
     Parameters:
         sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
             Height and width of input/output sample.
-        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
-        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
         flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
             Whether to flip the sin to cos in the time embedding.
@@ -78,9 +83,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
             The tuple of downsample blocks to use.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
-            The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
-            mid block layer if `None`.
-        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
+            Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
+            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
             The tuple of upsample blocks to use.
         only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
             Whether to include self-attention in the basic transformer blocks, see
@@ -92,50 +97,58 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
-            If `None`, it will skip the normalization and activation layers in post-processing
+            If `None`, normalization and activation layers is skipped in post-processing.
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
             The dimension of the cross attention features.
+        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
         encoder_hid_dim (`int`, *optional*, defaults to None):
             If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
             dimension to `cross_attention_dim`.
-        encoder_hid_dim_type (`str`, *optional*, defaults to None):
-            If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
+        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+        num_attention_heads (`int`, *optional*):
+            The number of attention heads. If not defined, defaults to `attention_head_dim`
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
-            for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
-        class_embed_type (`str`, *optional*, defaults to None):
+            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+        class_embed_type (`str`, *optional*, defaults to `None`):
             The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
             `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
-        addition_embed_type (`str`, *optional*, defaults to None):
+        addition_embed_type (`str`, *optional*, defaults to `None`):
             Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
             "text". "text" will use the `TextTimeEmbedding` layer.
-        num_class_embeds (`int`, *optional*, defaults to None):
+        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+            Dimension for the timestep embeddings.
+        num_class_embeds (`int`, *optional*, defaults to `None`):
             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
             class conditioning with `class_embed_type` equal to `None`.
-        time_embedding_type (`str`, *optional*, default to `positional`):
+        time_embedding_type (`str`, *optional*, defaults to `positional`):
             The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
-        time_embedding_dim (`int`, *optional*, default to `None`):
+        time_embedding_dim (`int`, *optional*, defaults to `None`):
             An optional override for the dimension of the projected time embedding.
-        time_embedding_act_fn (`str`, *optional*, default to `None`):
-            Optional activation function to use on the time embeddings only one time before they as passed to the rest
-            of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
-        timestep_post_act (`str, *optional*, default to `None`):
+        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+            Optional activation function to use only once on the time embeddings before they are passed to the rest of
+            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+        timestep_post_act (`str`, *optional*, defaults to `None`):
             The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
-        time_cond_proj_dim (`int`, *optional*, default to `None`):
-            The dimension of `cond_proj` layer in timestep embedding.
+        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+            The dimension of `cond_proj` layer in the timestep embedding.
         conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
         conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
-            using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
+            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
             embeddings with the class embeddings.
         mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
             Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
-            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
-            `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
-            default to False.
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+            otherwise.
     """
 
     _supports_gradient_checkpointing = True
@@ -166,13 +179,16 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: Union[int, Tuple[int]] = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
         addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
@@ -195,6 +211,19 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         self.sample_size = sample_size
 
+        if num_attention_heads is not None:
+            raise ValueError(
+                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+            )
+
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
+
         # Check inputs
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
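The net effect of the guard and fallback above: configs keep declaring `attention_head_dim`, and the model silently reuses that value as the per-block number of attention heads, while passing `num_attention_heads` directly is rejected until the rename lands. A minimal sketch of that behaviour (the tiny block sizes are arbitrary, chosen only so the model constructs quickly):

from diffusers import UNet2DConditionModel

# Explicit `num_attention_heads` trips the new guard shown above.
try:
    UNet2DConditionModel(num_attention_heads=8)
except ValueError as err:
    print("rejected:", err)

# Existing configs keep passing `attention_head_dim`; internally it becomes the
# per-block head count handed to the down/mid/up blocks as `num_attention_heads`.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)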
@@ -211,6 +240,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
             )
 
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+            )
+
         if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
             raise ValueError(
                 f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
@@ -280,7 +314,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 image_embed_dim=cross_attention_dim,
                 cross_attention_dim=cross_attention_dim,
             )
-
+        elif encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2
+            self.encoder_hid_proj = ImageProjection(
+                image_embed_dim=encoder_hid_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
         elif encoder_hid_dim_type is not None:
             raise ValueError(
                 f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
@@ -333,6 +372,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             self.add_embedding = TextImageTimeEmbedding(
                 text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
             )
+        elif addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif addition_embed_type == "image":
+            # Kandinsky 2.2
+            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+        elif addition_embed_type == "image_hint":
+            # Kandinsky 2.2 ControlNet
+            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
         elif addition_embed_type is not None:
             raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
 
@@ -353,6 +401,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if mid_block_only_cross_attention is None:
             mid_block_only_cross_attention = False
 
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)
 
@@ -362,6 +413,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if isinstance(layers_per_block, int):
             layers_per_block = [layers_per_block] * len(down_block_types)
 
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
         if class_embeddings_concat:
             # The time embeddings are concatenated with the class embeddings. The dimension of the
             # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
@@ -380,6 +434,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             down_block = get_down_block(
                 down_block_type,
                 num_layers=layers_per_block[i],
+                transformer_layers_per_block=transformer_layers_per_block[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 temb_channels=blocks_time_embed_dim,
@@ -388,7 +443,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim[i],
-                attn_num_head_channels=attention_head_dim[i],
+                num_attention_heads=num_attention_heads[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
@@ -398,12 +453,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
                 cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
             )
             self.down_blocks.append(down_block)
 
         # mid
         if mid_block_type == "UNetMidBlock2DCrossAttn":
             self.mid_block = UNetMidBlock2DCrossAttn(
+                transformer_layers_per_block=transformer_layers_per_block[-1],
                 in_channels=block_out_channels[-1],
                 temb_channels=blocks_time_embed_dim,
                 resnet_eps=norm_eps,
@@ -411,7 +468,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 output_scale_factor=mid_block_scale_factor,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 cross_attention_dim=cross_attention_dim[-1],
-                attn_num_head_channels=attention_head_dim[-1],
+                num_attention_heads=num_attention_heads[-1],
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
@@ -425,7 +482,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_act_fn=act_fn,
                 output_scale_factor=mid_block_scale_factor,
                 cross_attention_dim=cross_attention_dim[-1],
-                attn_num_head_channels=attention_head_dim[-1],
+                attention_head_dim=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
@@ -442,9 +499,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
-        reversed_attention_head_dim = list(reversed(attention_head_dim))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
         reversed_layers_per_block = list(reversed(layers_per_block))
         reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
         only_cross_attention = list(reversed(only_cross_attention))
 
         output_channel = reversed_block_out_channels[0]
@@ -465,6 +523,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             up_block = get_up_block(
                 up_block_type,
                 num_layers=reversed_layers_per_block[i] + 1,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
@@ -474,7 +533,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=reversed_cross_attention_dim[i],
-                attn_num_head_channels=reversed_attention_head_dim[i],
+                num_attention_heads=reversed_num_attention_heads[i],
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
@@ -483,6 +542,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
                 cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -530,11 +590,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
+        Sets the attention processor to use to compute attention.
+
         Parameters:
-            `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                of **all** `Attention` layers.
-            In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
 
         """
         count = len(self.attn_processors.keys())
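For reference, the documented call looks like this in practice; a single processor instance is broadcast to every `Attention` layer, while a dict keyed by module path sets them individually (default-config model used purely for illustration):

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel()  # default configuration, illustration only

# One instance applied to all Attention layers; pass a dict of
# {module_path: processor} to set layers individually.
unet.set_attn_processor(AttnProcessor())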
@@ -568,13 +632,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         r"""
         Enable sliced attention computation.
 
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.
 
         Args:
             slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """
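Continuing the sketch above, sliced attention can then be toggled on the same model with the documented `slice_size` values:

unet.set_attention_slice("auto")  # halve the attention input: two steps
unet.set_attention_slice("max")   # one slice at a time: lowest memory, slowest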
@@ -649,29 +713,31 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
+        The [`UNet2DConditionModel`] forward method.
+
         Args:
-            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
-            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
-            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch, channel, height, width)`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.FloatTensor`):
+                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             encoder_attention_mask (`torch.Tensor`):
-                (batch, sequence_length) cross-attention mask, applied to `encoder_hidden_states`. True = keep, False =
-                discard. Mask will be converted into a bias, which adds large negative values to the attention scores
-                corresponding to "discard" tokens.
+                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+                which adds large negative values to the attention scores corresponding to "discard" tokens.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
             cross_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
-            added_cond_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified includes additonal conditions that can be used for additonal time
-                embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and
-                `addition_embed_type` for more information.
+                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+            added_cond_kwargs: (`dict`, *optional*):
+                A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
+                are passed along to the UNet blocks.
 
         Returns:
             [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
-            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
-            returning a tuple, the first element is the sample tensor.
+                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+                a `tuple` is returned where the first element is the sample tensor.
         """
         # By default samples have to be AT least a multiple of the overall upsampling factor.
         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
@@ -737,6 +803,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             t_emb = t_emb.to(dtype=sample.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
+        aug_emb = None
 
         if self.class_embedding is not None:
             if class_labels is None:
@@ -758,9 +825,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         if self.config.addition_embed_type == "text":
             aug_emb = self.add_embedding(encoder_hidden_states)
-            emb = emb + aug_emb
         elif self.config.addition_embed_type == "text_image":
-            # Kadinsky 2.1 - style
+            # Kandinsky 2.1 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
                     f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
@@ -768,9 +834,44 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
             image_embs = added_cond_kwargs.get("image_embeds")
             text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
-
             aug_emb = self.add_embedding(text_embs, image_embs)
-            emb = emb + aug_emb
+        elif self.config.addition_embed_type == "text_time":
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
+        elif self.config.addition_embed_type == "image":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            aug_emb = self.add_embedding(image_embs)
+        elif self.config.addition_embed_type == "image_hint":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            hint = added_cond_kwargs.get("hint")
+            aug_emb, hint = self.add_embedding(image_embs, hint)
+            sample = torch.cat([sample, hint], dim=1)
+
+        emb = emb + aug_emb if aug_emb is not None else emb
 
         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)
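The new `"text_time"` branch is what the Stable Diffusion XL UNet relies on: the caller supplies pooled text embeddings and micro-conditioning `time_ids` through `added_cond_kwargs`, which are projected and summed into the time embedding. A runnable sketch with a deliberately tiny, illustrative configuration (the only real constraint is that `projection_class_embeddings_input_dim` equals the pooled-text width plus `addition_time_embed_dim` times the number of time ids, here 32 + 8 * 6 = 80):

import torch
from diffusers import UNet2DConditionModel

# Tiny illustrative UNet wired for addition_embed_type="text_time".
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    addition_embed_type="text_time",
    addition_time_embed_dim=8,
    projection_class_embeddings_input_dim=80,
)

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 32)
added_cond_kwargs = {
    "text_embeds": torch.randn(1, 32),  # pooled text embedding
    "time_ids": torch.randn(1, 6),      # SDXL-style size/crop conditioning ids
}

out = unet(
    sample,
    timestep=10,
    encoder_hidden_states=encoder_hidden_states,
    added_cond_kwargs=added_cond_kwargs,
).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])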
@@ -786,7 +887,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
             image_embeds = added_cond_kwargs.get("image_embeds")
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
-
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
         # 2. pre-process
         sample = self.conv_in(sample)
 
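The matching `"image_proj"` branch above is the Kandinsky 2.2 decoder path: there is no text encoder, so a CLIP image embedding passed via `added_cond_kwargs` is projected into the cross-attention context (and, with `addition_embed_type="image"`, also into the time embedding). A tiny illustrative sketch, with all sizes made up for brevity:

import torch
from diffusers import UNet2DConditionModel

# Small stand-in for a Kandinsky 2.2-style decoder UNet.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    encoder_hid_dim=16,
    encoder_hid_dim_type="image_proj",
    addition_embed_type="image",
)

image_embeds = torch.randn(1, 16)  # stand-in for a CLIP image embedding
out = unet(
    torch.randn(1, 4, 32, 32),
    timestep=10,
    encoder_hidden_states=None,  # replaced internally by the projected image embeds
    added_cond_kwargs={"image_embeds": image_embeds},
).sample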