diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +26 -1
- diffusers/configuration_utils.py +34 -29
- diffusers/dependency_versions_table.py +4 -0
- diffusers/image_processor.py +125 -12
- diffusers/loaders.py +169 -203
- diffusers/models/attention.py +24 -1
- diffusers/models/attention_flax.py +10 -5
- diffusers/models/attention_processor.py +3 -0
- diffusers/models/autoencoder_kl.py +114 -33
- diffusers/models/controlnet.py +131 -14
- diffusers/models/controlnet_flax.py +37 -26
- diffusers/models/cross_attention.py +17 -17
- diffusers/models/embeddings.py +67 -0
- diffusers/models/modeling_flax_utils.py +64 -56
- diffusers/models/modeling_utils.py +193 -104
- diffusers/models/prior_transformer.py +207 -37
- diffusers/models/resnet.py +26 -26
- diffusers/models/transformer_2d.py +36 -41
- diffusers/models/transformer_temporal.py +24 -21
- diffusers/models/unet_1d.py +31 -25
- diffusers/models/unet_2d.py +43 -30
- diffusers/models/unet_2d_blocks.py +210 -89
- diffusers/models/unet_2d_blocks_flax.py +12 -12
- diffusers/models/unet_2d_condition.py +172 -64
- diffusers/models/unet_2d_condition_flax.py +38 -24
- diffusers/models/unet_3d_blocks.py +34 -31
- diffusers/models/unet_3d_condition.py +101 -34
- diffusers/models/vae.py +5 -5
- diffusers/models/vae_flax.py +37 -34
- diffusers/models/vq_model.py +23 -14
- diffusers/pipelines/__init__.py +24 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
- diffusers/pipelines/consistency_models/__init__.py +1 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
- diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/kandinsky/__init__.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
- diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_utils.py +124 -146
- diffusers/pipelines/shap_e/__init__.py +27 -0
- diffusers/pipelines/shap_e/camera.py +147 -0
- diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
- diffusers/pipelines/shap_e/renderer.py +709 -0
- diffusers/pipelines/stable_diffusion/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
- diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
- diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
- diffusers/schedulers/__init__.py +3 -0
- diffusers/schedulers/scheduling_consistency_models.py +380 -0
- diffusers/schedulers/scheduling_ddim.py +28 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
- diffusers/schedulers/scheduling_ddpm.py +53 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
- diffusers/schedulers/scheduling_deis_multistep.py +66 -11
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
- diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
- diffusers/schedulers/scheduling_euler_discrete.py +58 -8
- diffusers/schedulers/scheduling_heun_discrete.py +89 -14
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
- diffusers/schedulers/scheduling_lms_discrete.py +57 -8
- diffusers/schedulers/scheduling_pndm.py +46 -10
- diffusers/schedulers/scheduling_repaint.py +19 -4
- diffusers/schedulers/scheduling_sde_ve.py +5 -1
- diffusers/schedulers/scheduling_unclip.py +43 -4
- diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
- diffusers/utils/hub_utils.py +1 -1
- diffusers/utils/import_utils.py +20 -3
- diffusers/utils/logging.py +15 -18
- diffusers/utils/outputs.py +3 -3
- diffusers/utils/testing_utils.py +15 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,7 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
from typing import Tuple, Union
|
14
|
+
from typing import Optional, Tuple, Union
|
15
15
|
|
16
16
|
import flax
|
17
17
|
import flax.linen as nn
|
@@ -32,6 +32,14 @@ from .unet_2d_blocks_flax import (
|
|
32
32
|
|
33
33
|
@flax.struct.dataclass
|
34
34
|
class FlaxControlNetOutput(BaseOutput):
|
35
|
+
"""
|
36
|
+
The output of [`FlaxControlNetModel`].
|
37
|
+
|
38
|
+
Args:
|
39
|
+
down_block_res_samples (`jnp.ndarray`):
|
40
|
+
mid_block_res_sample (`jnp.ndarray`):
|
41
|
+
"""
|
42
|
+
|
35
43
|
down_block_res_samples: jnp.ndarray
|
36
44
|
mid_block_res_sample: jnp.ndarray
|
37
45
|
|
@@ -95,21 +103,17 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
|
|
95
103
|
@flax_register_to_config
|
96
104
|
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
97
105
|
r"""
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
model
|
104
|
-
|
105
|
-
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
|
106
|
-
implements for all the models (such as downloading or saving, etc.)
|
107
|
-
|
108
|
-
Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
109
|
-
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
|
106
|
+
A ControlNet model.
|
107
|
+
|
108
|
+
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
|
109
|
+
implemented for all models (such as downloading or saving).
|
110
|
+
|
111
|
+
This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
112
|
+
subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
|
110
113
|
general usage and behavior.
|
111
114
|
|
112
|
-
|
115
|
+
Inherent JAX features such as the following are supported:
|
116
|
+
|
113
117
|
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
114
118
|
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
115
119
|
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
@@ -120,15 +124,16 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
|
120
124
|
The size of the input sample.
|
121
125
|
in_channels (`int`, *optional*, defaults to 4):
|
122
126
|
The number of channels in the input sample.
|
123
|
-
down_block_types (`Tuple[str]`, *optional*, defaults to `("
|
124
|
-
The tuple of downsample blocks to use.
|
125
|
-
"FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
|
127
|
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
|
128
|
+
The tuple of downsample blocks to use.
|
126
129
|
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
127
130
|
The tuple of output channels for each block.
|
128
131
|
layers_per_block (`int`, *optional*, defaults to 2):
|
129
132
|
The number of layers per block.
|
130
133
|
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
|
131
134
|
The dimension of the attention heads.
|
135
|
+
num_attention_heads (`int` or `Tuple[int]`, *optional*):
|
136
|
+
The number of attention heads.
|
132
137
|
cross_attention_dim (`int`, *optional*, defaults to 768):
|
133
138
|
The dimension of the cross attention features.
|
134
139
|
dropout (`float`, *optional*, defaults to 0):
|
@@ -137,11 +142,9 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
|
137
142
|
Whether to flip the sin to cos in the time embedding.
|
138
143
|
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
139
144
|
controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
|
140
|
-
The channel order of conditional image. Will convert
|
145
|
+
The channel order of the conditional image. Will convert to `rgb` if it's `bgr`.
|
141
146
|
conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
|
142
|
-
The tuple of output channel for each block in conditioning_embedding layer
|
143
|
-
|
144
|
-
|
147
|
+
The tuple of output channels for each block in the `conditioning_embedding` layer.
|
145
148
|
"""
|
146
149
|
sample_size: int = 32
|
147
150
|
in_channels: int = 4
|
@@ -155,6 +158,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
|
155
158
|
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
|
156
159
|
layers_per_block: int = 2
|
157
160
|
attention_head_dim: Union[int, Tuple[int]] = 8
|
161
|
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None
|
158
162
|
cross_attention_dim: int = 1280
|
159
163
|
dropout: float = 0.0
|
160
164
|
use_linear_projection: bool = False
|
@@ -182,6 +186,14 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
|
182
186
|
block_out_channels = self.block_out_channels
|
183
187
|
time_embed_dim = block_out_channels[0] * 4
|
184
188
|
|
189
|
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
190
|
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
191
|
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
192
|
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
193
|
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
194
|
+
# which is why we correct for the naming here.
|
195
|
+
num_attention_heads = self.num_attention_heads or self.attention_head_dim
|
196
|
+
|
185
197
|
# input
|
186
198
|
self.conv_in = nn.Conv(
|
187
199
|
block_out_channels[0],
|
@@ -206,9 +218,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
|
206
218
|
if isinstance(only_cross_attention, bool):
|
207
219
|
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
|
208
220
|
|
209
|
-
|
210
|
-
|
211
|
-
attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
|
221
|
+
if isinstance(num_attention_heads, int):
|
222
|
+
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
|
212
223
|
|
213
224
|
# down
|
214
225
|
down_blocks = []
|
@@ -237,7 +248,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
|
237
248
|
out_channels=output_channel,
|
238
249
|
dropout=self.dropout,
|
239
250
|
num_layers=self.layers_per_block,
|
240
|
-
|
251
|
+
num_attention_heads=num_attention_heads[i],
|
241
252
|
add_downsample=not is_final_block,
|
242
253
|
use_linear_projection=self.use_linear_projection,
|
243
254
|
only_cross_attention=only_cross_attention[i],
|
@@ -285,7 +296,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
|
285
296
|
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
|
286
297
|
in_channels=mid_block_channel,
|
287
298
|
dropout=self.dropout,
|
288
|
-
|
299
|
+
num_attention_heads=num_attention_heads[-1],
|
289
300
|
use_linear_projection=self.use_linear_projection,
|
290
301
|
dtype=self.dtype,
|
291
302
|
)
|
@@ -29,7 +29,7 @@ from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F
|
|
29
29
|
|
30
30
|
deprecate(
|
31
31
|
"cross_attention",
|
32
|
-
"0.
|
32
|
+
"0.20.0",
|
33
33
|
"Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.",
|
34
34
|
standard_warn=False,
|
35
35
|
)
|
@@ -40,55 +40,55 @@ AttnProcessor = AttentionProcessor
|
|
40
40
|
|
41
41
|
class CrossAttention(Attention):
|
42
42
|
def __init__(self, *args, **kwargs):
|
43
|
-
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.
|
44
|
-
deprecate("cross_attention", "0.
|
43
|
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
44
|
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
45
45
|
super().__init__(*args, **kwargs)
|
46
46
|
|
47
47
|
|
48
48
|
class CrossAttnProcessor(AttnProcessorRename):
|
49
49
|
def __init__(self, *args, **kwargs):
|
50
|
-
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.
|
51
|
-
deprecate("cross_attention", "0.
|
50
|
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
51
|
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
52
52
|
super().__init__(*args, **kwargs)
|
53
53
|
|
54
54
|
|
55
55
|
class LoRACrossAttnProcessor(LoRAAttnProcessor):
|
56
56
|
def __init__(self, *args, **kwargs):
|
57
|
-
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.
|
58
|
-
deprecate("cross_attention", "0.
|
57
|
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
58
|
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
59
59
|
super().__init__(*args, **kwargs)
|
60
60
|
|
61
61
|
|
62
62
|
class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
|
63
63
|
def __init__(self, *args, **kwargs):
|
64
|
-
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.
|
65
|
-
deprecate("cross_attention", "0.
|
64
|
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
65
|
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
66
66
|
super().__init__(*args, **kwargs)
|
67
67
|
|
68
68
|
|
69
69
|
class XFormersCrossAttnProcessor(XFormersAttnProcessor):
|
70
70
|
def __init__(self, *args, **kwargs):
|
71
|
-
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.
|
72
|
-
deprecate("cross_attention", "0.
|
71
|
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
72
|
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
73
73
|
super().__init__(*args, **kwargs)
|
74
74
|
|
75
75
|
|
76
76
|
class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
|
77
77
|
def __init__(self, *args, **kwargs):
|
78
|
-
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.
|
79
|
-
deprecate("cross_attention", "0.
|
78
|
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
79
|
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
80
80
|
super().__init__(*args, **kwargs)
|
81
81
|
|
82
82
|
|
83
83
|
class SlicedCrossAttnProcessor(SlicedAttnProcessor):
|
84
84
|
def __init__(self, *args, **kwargs):
|
85
|
-
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.
|
86
|
-
deprecate("cross_attention", "0.
|
85
|
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
86
|
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
87
87
|
super().__init__(*args, **kwargs)
|
88
88
|
|
89
89
|
|
90
90
|
class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
|
91
91
|
def __init__(self, *args, **kwargs):
|
92
|
-
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.
|
93
|
-
deprecate("cross_attention", "0.
|
92
|
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
93
|
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
94
94
|
super().__init__(*args, **kwargs)
|
diffusers/models/embeddings.py
CHANGED
@@ -376,6 +376,29 @@ class TextImageProjection(nn.Module):
|
|
376
376
|
return torch.cat([image_text_embeds, text_embeds], dim=1)
|
377
377
|
|
378
378
|
|
379
|
+
class ImageProjection(nn.Module):
|
380
|
+
def __init__(
|
381
|
+
self,
|
382
|
+
image_embed_dim: int = 768,
|
383
|
+
cross_attention_dim: int = 768,
|
384
|
+
num_image_text_embeds: int = 32,
|
385
|
+
):
|
386
|
+
super().__init__()
|
387
|
+
|
388
|
+
self.num_image_text_embeds = num_image_text_embeds
|
389
|
+
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
|
390
|
+
self.norm = nn.LayerNorm(cross_attention_dim)
|
391
|
+
|
392
|
+
def forward(self, image_embeds: torch.FloatTensor):
|
393
|
+
batch_size = image_embeds.shape[0]
|
394
|
+
|
395
|
+
# image
|
396
|
+
image_embeds = self.image_embeds(image_embeds)
|
397
|
+
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
398
|
+
image_embeds = self.norm(image_embeds)
|
399
|
+
return image_embeds
|
400
|
+
|
401
|
+
|
379
402
|
class CombinedTimestepLabelEmbeddings(nn.Module):
|
380
403
|
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
|
381
404
|
super().__init__()
|
@@ -429,6 +452,50 @@ class TextImageTimeEmbedding(nn.Module):
|
|
429
452
|
return time_image_embeds + time_text_embeds
|
430
453
|
|
431
454
|
|
455
|
+
class ImageTimeEmbedding(nn.Module):
|
456
|
+
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
457
|
+
super().__init__()
|
458
|
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
459
|
+
self.image_norm = nn.LayerNorm(time_embed_dim)
|
460
|
+
|
461
|
+
def forward(self, image_embeds: torch.FloatTensor):
|
462
|
+
# image
|
463
|
+
time_image_embeds = self.image_proj(image_embeds)
|
464
|
+
time_image_embeds = self.image_norm(time_image_embeds)
|
465
|
+
return time_image_embeds
|
466
|
+
|
467
|
+
|
468
|
+
class ImageHintTimeEmbedding(nn.Module):
|
469
|
+
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
470
|
+
super().__init__()
|
471
|
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
472
|
+
self.image_norm = nn.LayerNorm(time_embed_dim)
|
473
|
+
self.input_hint_block = nn.Sequential(
|
474
|
+
nn.Conv2d(3, 16, 3, padding=1),
|
475
|
+
nn.SiLU(),
|
476
|
+
nn.Conv2d(16, 16, 3, padding=1),
|
477
|
+
nn.SiLU(),
|
478
|
+
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
479
|
+
nn.SiLU(),
|
480
|
+
nn.Conv2d(32, 32, 3, padding=1),
|
481
|
+
nn.SiLU(),
|
482
|
+
nn.Conv2d(32, 96, 3, padding=1, stride=2),
|
483
|
+
nn.SiLU(),
|
484
|
+
nn.Conv2d(96, 96, 3, padding=1),
|
485
|
+
nn.SiLU(),
|
486
|
+
nn.Conv2d(96, 256, 3, padding=1, stride=2),
|
487
|
+
nn.SiLU(),
|
488
|
+
nn.Conv2d(256, 4, 3, padding=1),
|
489
|
+
)
|
490
|
+
|
491
|
+
def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
|
492
|
+
# image
|
493
|
+
time_image_embeds = self.image_proj(image_embeds)
|
494
|
+
time_image_embeds = self.image_norm(time_image_embeds)
|
495
|
+
hint = self.input_hint_block(hint)
|
496
|
+
return time_image_embeds, hint
|
497
|
+
|
498
|
+
|
432
499
|
class AttentionPooling(nn.Module):
|
433
500
|
# Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
|
434
501
|
|
@@ -44,10 +44,12 @@ logger = logging.get_logger(__name__)
|
|
44
44
|
|
45
45
|
class FlaxModelMixin:
|
46
46
|
r"""
|
47
|
-
Base class for all
|
47
|
+
Base class for all Flax models.
|
48
48
|
|
49
|
-
[`FlaxModelMixin`] takes care of storing the configuration
|
50
|
-
|
49
|
+
[`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
|
50
|
+
saving models.
|
51
|
+
|
52
|
+
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
|
51
53
|
"""
|
52
54
|
config_name = CONFIG_NAME
|
53
55
|
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
@@ -89,15 +91,15 @@ class FlaxModelMixin:
|
|
89
91
|
Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
|
90
92
|
the `params` in place.
|
91
93
|
|
92
|
-
This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
|
94
|
+
This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
|
93
95
|
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
|
94
96
|
|
95
97
|
Arguments:
|
96
98
|
params (`Union[Dict, FrozenDict]`):
|
97
99
|
A `PyTree` of model parameters.
|
98
100
|
mask (`Union[Dict, FrozenDict]`):
|
99
|
-
A `PyTree` with same structure as the `params` tree. The leaves should be booleans
|
100
|
-
you want to cast, and
|
101
|
+
A `PyTree` with the same structure as the `params` tree. The leaves should be booleans. It should be `True`
|
102
|
+
for params you want to cast, and `False` for those you want to skip.
|
101
103
|
|
102
104
|
Examples:
|
103
105
|
|
@@ -132,8 +134,8 @@ class FlaxModelMixin:
|
|
132
134
|
params (`Union[Dict, FrozenDict]`):
|
133
135
|
A `PyTree` of model parameters.
|
134
136
|
mask (`Union[Dict, FrozenDict]`):
|
135
|
-
A `PyTree` with same structure as the `params` tree. The leaves should be booleans
|
136
|
-
you want to cast, and
|
137
|
+
A `PyTree` with the same structure as the `params` tree. The leaves should be booleans. It should be `True`
|
138
|
+
for params you want to cast, and `False` for those you want to skip.
|
137
139
|
|
138
140
|
Examples:
|
139
141
|
|
@@ -155,15 +157,15 @@ class FlaxModelMixin:
|
|
155
157
|
Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
|
156
158
|
`params` in place.
|
157
159
|
|
158
|
-
This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
|
160
|
+
This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
|
159
161
|
half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
|
160
162
|
|
161
163
|
Arguments:
|
162
164
|
params (`Union[Dict, FrozenDict]`):
|
163
165
|
A `PyTree` of model parameters.
|
164
166
|
mask (`Union[Dict, FrozenDict]`):
|
165
|
-
A `PyTree` with same structure as the `params` tree. The leaves should be booleans
|
166
|
-
you want to cast, and
|
167
|
+
A `PyTree` with the same structure as the `params` tree. The leaves should be booleans. It should be `True`
|
168
|
+
for params you want to cast, and `False` for those you want to skip.
|
167
169
|
|
168
170
|
Examples:
|
169
171
|
|
@@ -201,71 +203,68 @@ class FlaxModelMixin:
|
|
201
203
|
**kwargs,
|
202
204
|
):
|
203
205
|
r"""
|
204
|
-
Instantiate a pretrained
|
205
|
-
|
206
|
-
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
207
|
-
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
208
|
-
task.
|
209
|
-
|
210
|
-
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
211
|
-
weights are discarded.
|
206
|
+
Instantiate a pretrained Flax model from a pretrained model configuration.
|
212
207
|
|
213
208
|
Parameters:
|
214
209
|
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
215
210
|
Can be either:
|
216
211
|
|
217
|
-
- A string, the *model id* of a pretrained model
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
e.g., `./my_model_directory/`.
|
212
|
+
- A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
|
213
|
+
hosted on the Hub.
|
214
|
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
215
|
+
using [`~FlaxModelMixin.save_pretrained`].
|
222
216
|
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
223
217
|
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
224
218
|
`jax.numpy.bfloat16` (on TPUs).
|
225
219
|
|
226
220
|
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
227
|
-
specified all the computation will be performed with the given `dtype`.
|
221
|
+
specified, all the computation will be performed with the given `dtype`.
|
222
|
+
|
223
|
+
<Tip>
|
224
|
+
|
225
|
+
This only specifies the dtype of the *computation* and does not influence the dtype of model
|
226
|
+
parameters.
|
228
227
|
|
229
|
-
|
230
|
-
|
228
|
+
If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
|
229
|
+
[`~FlaxModelMixin.to_bf16`].
|
230
|
+
|
231
|
+
</Tip>
|
231
232
|
|
232
|
-
If you wish to change the dtype of the model parameters, see [`~ModelMixin.to_fp16`] and
|
233
|
-
[`~ModelMixin.to_bf16`].
|
234
233
|
model_args (sequence of positional arguments, *optional*):
|
235
|
-
All remaining positional arguments
|
234
|
+
All remaining positional arguments are passed to the underlying model's `__init__` method.
|
236
235
|
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
237
|
-
Path to a directory
|
238
|
-
|
236
|
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
237
|
+
is not used.
|
239
238
|
force_download (`bool`, *optional*, defaults to `False`):
|
240
239
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
241
240
|
cached versions if they exist.
|
242
241
|
resume_download (`bool`, *optional*, defaults to `False`):
|
243
|
-
Whether or not to
|
244
|
-
|
242
|
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
243
|
+
incompletely downloaded files are deleted.
|
245
244
|
proxies (`Dict[str, str]`, *optional*):
|
246
|
-
A dictionary of proxy servers to use by protocol or endpoint,
|
245
|
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
247
246
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
248
247
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
249
|
-
Whether
|
248
|
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
249
|
+
won't be downloaded from the Hub.
|
250
250
|
revision (`str`, *optional*, defaults to `"main"`):
|
251
|
-
The specific model version to use. It can be a branch name, a tag name,
|
252
|
-
|
253
|
-
identifier allowed by git.
|
251
|
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
252
|
+
allowed by Git.
|
254
253
|
from_pt (`bool`, *optional*, defaults to `False`):
|
255
254
|
Load the model weights from a PyTorch checkpoint save file.
|
256
255
|
kwargs (remaining dictionary of keyword arguments, *optional*):
|
257
|
-
Can be used to update the configuration object (after it
|
258
|
-
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
256
|
+
Can be used to update the configuration object (after it is loaded) and initiate the model (for
|
257
|
+
example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
259
258
|
automatically loaded:
|
260
259
|
|
261
|
-
- If a configuration is provided with `config`,
|
262
|
-
|
263
|
-
|
264
|
-
- If a configuration is not provided, `kwargs`
|
265
|
-
initialization function
|
266
|
-
a configuration attribute
|
267
|
-
|
268
|
-
|
260
|
+
- If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
|
261
|
+
model's `__init__` method (we assume all relevant updates to the configuration have already been
|
262
|
+
done).
|
263
|
+
- If a configuration is not provided, `kwargs` are first passed to the configuration class
|
264
|
+
initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
|
265
|
+
to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
|
266
|
+
Remaining keys that do not correspond to any configuration attribute are passed to the underlying
|
267
|
+
model's `__init__` function.
|
269
268
|
|
270
269
|
Examples:
|
271
270
|
|
@@ -276,7 +275,16 @@ class FlaxModelMixin:
|
|
276
275
|
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
277
276
|
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
|
278
277
|
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
|
279
|
-
```
|
278
|
+
```
|
279
|
+
|
280
|
+
If you get the error message below, you need to finetune the weights for your downstream task:
|
281
|
+
|
282
|
+
```bash
|
283
|
+
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
284
|
+
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
285
|
+
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
286
|
+
```
|
287
|
+
"""
|
280
288
|
config = kwargs.pop("config", None)
|
281
289
|
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
282
290
|
force_download = kwargs.pop("force_download", False)
|
@@ -491,18 +499,18 @@ class FlaxModelMixin:
|
|
491
499
|
is_main_process: bool = True,
|
492
500
|
):
|
493
501
|
"""
|
494
|
-
Save a model and its configuration file to a directory
|
495
|
-
|
502
|
+
Save a model and its configuration file to a directory so that it can be reloaded using the
|
503
|
+
[`~FlaxModelMixin.from_pretrained`] class method.
|
496
504
|
|
497
505
|
Arguments:
|
498
506
|
save_directory (`str` or `os.PathLike`):
|
499
|
-
Directory to
|
507
|
+
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
|
500
508
|
params (`Union[Dict, FrozenDict]`):
|
501
509
|
A `PyTree` of model parameters.
|
502
510
|
is_main_process (`bool`, *optional*, defaults to `True`):
|
503
|
-
Whether the process calling this is the main process or not. Useful
|
504
|
-
|
505
|
-
|
511
|
+
Whether the process calling this is the main process or not. Useful during distributed training when you
|
512
|
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
513
|
+
process to avoid race conditions.
|
506
514
|
"""
|
507
515
|
if os.path.isfile(save_directory):
|
508
516
|
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|