diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (120) hide show
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,7 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- from typing import Tuple, Union
14
+ from typing import Optional, Tuple, Union
15
15
 
16
16
  import flax
17
17
  import flax.linen as nn
@@ -32,6 +32,14 @@ from .unet_2d_blocks_flax import (
32
32
 
33
33
  @flax.struct.dataclass
34
34
  class FlaxControlNetOutput(BaseOutput):
35
+ """
36
+ The output of [`FlaxControlNetModel`].
37
+
38
+ Args:
39
+ down_block_res_samples (`jnp.ndarray`):
40
+ mid_block_res_sample (`jnp.ndarray`):
41
+ """
42
+
35
43
  down_block_res_samples: jnp.ndarray
36
44
  mid_block_res_sample: jnp.ndarray
37
45
 
@@ -95,21 +103,17 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
95
103
  @flax_register_to_config
96
104
  class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
97
105
  r"""
98
- Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
99
- [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
100
- training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
101
- convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
102
- (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
103
- model) to encode image-space conditions ... into feature maps ..."
104
-
105
- This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
106
- implements for all the models (such as downloading or saving, etc.)
107
-
108
- Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
109
- subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
106
+ A ControlNet model.
107
+
108
+ This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
109
+ implemented for all models (such as downloading or saving).
110
+
111
+ This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
112
+ subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
110
113
  general usage and behavior.
111
114
 
112
- Finally, this model supports inherent JAX features such as:
115
+ Inherent JAX features such as the following are supported:
116
+
113
117
  - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
114
118
  - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
115
119
  - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
@@ -120,15 +124,16 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
120
124
  The size of the input sample.
121
125
  in_channels (`int`, *optional*, defaults to 4):
122
126
  The number of channels in the input sample.
123
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
124
- The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
125
- "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
127
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
128
+ The tuple of downsample blocks to use.
126
129
  block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
127
130
  The tuple of output channels for each block.
128
131
  layers_per_block (`int`, *optional*, defaults to 2):
129
132
  The number of layers per block.
130
133
  attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
131
134
  The dimension of the attention heads.
135
+ num_attention_heads (`int` or `Tuple[int]`, *optional*):
136
+ The number of attention heads.
132
137
  cross_attention_dim (`int`, *optional*, defaults to 768):
133
138
  The dimension of the cross attention features.
134
139
  dropout (`float`, *optional*, defaults to 0):
@@ -137,11 +142,9 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
137
142
  Whether to flip the sin to cos in the time embedding.
138
143
  freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
139
144
  controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
140
- The channel order of conditional image. Will convert it to `rgb` if it's `bgr`
145
+ The channel order of the conditional image. Will convert to `rgb` if it's `bgr`.
141
146
  conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
142
- The tuple of output channel for each block in conditioning_embedding layer
143
-
144
-
147
+ The tuple of output channels for each block in the `conditioning_embedding` layer.
145
148
  """
146
149
  sample_size: int = 32
147
150
  in_channels: int = 4
@@ -155,6 +158,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
155
158
  block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
156
159
  layers_per_block: int = 2
157
160
  attention_head_dim: Union[int, Tuple[int]] = 8
161
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None
158
162
  cross_attention_dim: int = 1280
159
163
  dropout: float = 0.0
160
164
  use_linear_projection: bool = False
@@ -182,6 +186,14 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
182
186
  block_out_channels = self.block_out_channels
183
187
  time_embed_dim = block_out_channels[0] * 4
184
188
 
189
+ # If `num_attention_heads` is not defined (which is the case for most models)
190
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
191
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
192
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
193
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
194
+ # which is why we correct for the naming here.
195
+ num_attention_heads = self.num_attention_heads or self.attention_head_dim
196
+
185
197
  # input
186
198
  self.conv_in = nn.Conv(
187
199
  block_out_channels[0],
@@ -206,9 +218,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
206
218
  if isinstance(only_cross_attention, bool):
207
219
  only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
208
220
 
209
- attention_head_dim = self.attention_head_dim
210
- if isinstance(attention_head_dim, int):
211
- attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
221
+ if isinstance(num_attention_heads, int):
222
+ num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
212
223
 
213
224
  # down
214
225
  down_blocks = []
@@ -237,7 +248,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
237
248
  out_channels=output_channel,
238
249
  dropout=self.dropout,
239
250
  num_layers=self.layers_per_block,
240
- attn_num_head_channels=attention_head_dim[i],
251
+ num_attention_heads=num_attention_heads[i],
241
252
  add_downsample=not is_final_block,
242
253
  use_linear_projection=self.use_linear_projection,
243
254
  only_cross_attention=only_cross_attention[i],
@@ -285,7 +296,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
285
296
  self.mid_block = FlaxUNetMidBlock2DCrossAttn(
286
297
  in_channels=mid_block_channel,
287
298
  dropout=self.dropout,
288
- attn_num_head_channels=attention_head_dim[-1],
299
+ num_attention_heads=num_attention_heads[-1],
289
300
  use_linear_projection=self.use_linear_projection,
290
301
  dtype=self.dtype,
291
302
  )
@@ -29,7 +29,7 @@ from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F
29
29
 
30
30
  deprecate(
31
31
  "cross_attention",
32
- "0.18.0",
32
+ "0.20.0",
33
33
  "Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.",
34
34
  standard_warn=False,
35
35
  )
@@ -40,55 +40,55 @@ AttnProcessor = AttentionProcessor
40
40
 
41
41
  class CrossAttention(Attention):
42
42
  def __init__(self, *args, **kwargs):
43
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
44
- deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
43
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
44
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
45
45
  super().__init__(*args, **kwargs)
46
46
 
47
47
 
48
48
  class CrossAttnProcessor(AttnProcessorRename):
49
49
  def __init__(self, *args, **kwargs):
50
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
51
- deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
50
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
51
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
52
52
  super().__init__(*args, **kwargs)
53
53
 
54
54
 
55
55
  class LoRACrossAttnProcessor(LoRAAttnProcessor):
56
56
  def __init__(self, *args, **kwargs):
57
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
58
- deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
57
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
58
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
59
59
  super().__init__(*args, **kwargs)
60
60
 
61
61
 
62
62
  class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
63
63
  def __init__(self, *args, **kwargs):
64
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
65
- deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
64
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
65
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
66
66
  super().__init__(*args, **kwargs)
67
67
 
68
68
 
69
69
  class XFormersCrossAttnProcessor(XFormersAttnProcessor):
70
70
  def __init__(self, *args, **kwargs):
71
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
72
- deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
71
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
72
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
73
73
  super().__init__(*args, **kwargs)
74
74
 
75
75
 
76
76
  class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
77
77
  def __init__(self, *args, **kwargs):
78
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
79
- deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
78
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
79
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
80
80
  super().__init__(*args, **kwargs)
81
81
 
82
82
 
83
83
  class SlicedCrossAttnProcessor(SlicedAttnProcessor):
84
84
  def __init__(self, *args, **kwargs):
85
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
86
- deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
85
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
86
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
87
87
  super().__init__(*args, **kwargs)
88
88
 
89
89
 
90
90
  class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
91
91
  def __init__(self, *args, **kwargs):
92
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
93
- deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
92
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
93
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
94
94
  super().__init__(*args, **kwargs)
@@ -376,6 +376,29 @@ class TextImageProjection(nn.Module):
376
376
  return torch.cat([image_text_embeds, text_embeds], dim=1)
377
377
 
378
378
 
379
+ class ImageProjection(nn.Module):
380
+ def __init__(
381
+ self,
382
+ image_embed_dim: int = 768,
383
+ cross_attention_dim: int = 768,
384
+ num_image_text_embeds: int = 32,
385
+ ):
386
+ super().__init__()
387
+
388
+ self.num_image_text_embeds = num_image_text_embeds
389
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
390
+ self.norm = nn.LayerNorm(cross_attention_dim)
391
+
392
+ def forward(self, image_embeds: torch.FloatTensor):
393
+ batch_size = image_embeds.shape[0]
394
+
395
+ # image
396
+ image_embeds = self.image_embeds(image_embeds)
397
+ image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
398
+ image_embeds = self.norm(image_embeds)
399
+ return image_embeds
400
+
401
+
379
402
  class CombinedTimestepLabelEmbeddings(nn.Module):
380
403
  def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
381
404
  super().__init__()
@@ -429,6 +452,50 @@ class TextImageTimeEmbedding(nn.Module):
429
452
  return time_image_embeds + time_text_embeds
430
453
 
431
454
 
455
+ class ImageTimeEmbedding(nn.Module):
456
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
457
+ super().__init__()
458
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
459
+ self.image_norm = nn.LayerNorm(time_embed_dim)
460
+
461
+ def forward(self, image_embeds: torch.FloatTensor):
462
+ # image
463
+ time_image_embeds = self.image_proj(image_embeds)
464
+ time_image_embeds = self.image_norm(time_image_embeds)
465
+ return time_image_embeds
466
+
467
+
468
+ class ImageHintTimeEmbedding(nn.Module):
469
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
470
+ super().__init__()
471
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
472
+ self.image_norm = nn.LayerNorm(time_embed_dim)
473
+ self.input_hint_block = nn.Sequential(
474
+ nn.Conv2d(3, 16, 3, padding=1),
475
+ nn.SiLU(),
476
+ nn.Conv2d(16, 16, 3, padding=1),
477
+ nn.SiLU(),
478
+ nn.Conv2d(16, 32, 3, padding=1, stride=2),
479
+ nn.SiLU(),
480
+ nn.Conv2d(32, 32, 3, padding=1),
481
+ nn.SiLU(),
482
+ nn.Conv2d(32, 96, 3, padding=1, stride=2),
483
+ nn.SiLU(),
484
+ nn.Conv2d(96, 96, 3, padding=1),
485
+ nn.SiLU(),
486
+ nn.Conv2d(96, 256, 3, padding=1, stride=2),
487
+ nn.SiLU(),
488
+ nn.Conv2d(256, 4, 3, padding=1),
489
+ )
490
+
491
+ def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
492
+ # image
493
+ time_image_embeds = self.image_proj(image_embeds)
494
+ time_image_embeds = self.image_norm(time_image_embeds)
495
+ hint = self.input_hint_block(hint)
496
+ return time_image_embeds, hint
497
+
498
+
432
499
  class AttentionPooling(nn.Module):
433
500
  # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
434
501
 
@@ -44,10 +44,12 @@ logger = logging.get_logger(__name__)
44
44
 
45
45
  class FlaxModelMixin:
46
46
  r"""
47
- Base class for all flax models.
47
+ Base class for all Flax models.
48
48
 
49
- [`FlaxModelMixin`] takes care of storing the configuration of the models and handles methods for loading,
50
- downloading and saving models.
49
+ [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
50
+ saving models.
51
+
52
+ - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
51
53
  """
52
54
  config_name = CONFIG_NAME
53
55
  _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
@@ -89,15 +91,15 @@ class FlaxModelMixin:
89
91
  Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
90
92
  the `params` in place.
91
93
 
92
- This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
94
+ This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
93
95
  half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
94
96
 
95
97
  Arguments:
96
98
  params (`Union[Dict, FrozenDict]`):
97
99
  A `PyTree` of model parameters.
98
100
  mask (`Union[Dict, FrozenDict]`):
99
- A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
100
- you want to cast, and should be `False` for those you want to skip.
101
+ A `PyTree` with the same structure as the `params` tree. The leaves should be booleans. It should be `True`
102
+ for params you want to cast, and `False` for those you want to skip.
101
103
 
102
104
  Examples:
103
105
 
@@ -132,8 +134,8 @@ class FlaxModelMixin:
132
134
  params (`Union[Dict, FrozenDict]`):
133
135
  A `PyTree` of model parameters.
134
136
  mask (`Union[Dict, FrozenDict]`):
135
- A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
136
- you want to cast, and should be `False` for those you want to skip
137
+ A `PyTree` with the same structure as the `params` tree. The leaves should be booleans. It should be `True`
138
+ for params you want to cast, and `False` for those you want to skip.
137
139
 
138
140
  Examples:
139
141
 
@@ -155,15 +157,15 @@ class FlaxModelMixin:
155
157
  Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
156
158
  `params` in place.
157
159
 
158
- This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
160
+ This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
159
161
  half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
160
162
 
161
163
  Arguments:
162
164
  params (`Union[Dict, FrozenDict]`):
163
165
  A `PyTree` of model parameters.
164
166
  mask (`Union[Dict, FrozenDict]`):
165
- A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
166
- you want to cast, and should be `False` for those you want to skip
167
+ A `PyTree` with the same structure as the `params` tree. The leaves should be booleans. It should be `True`
168
+ for params you want to cast, and `False` for those you want to skip.
167
169
 
168
170
  Examples:
169
171
 
@@ -201,71 +203,68 @@ class FlaxModelMixin:
201
203
  **kwargs,
202
204
  ):
203
205
  r"""
204
- Instantiate a pretrained flax model from a pre-trained model configuration.
205
-
206
- The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
207
- pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
208
- task.
209
-
210
- The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
211
- weights are discarded.
206
+ Instantiate a pretrained Flax model from a pretrained model configuration.
212
207
 
213
208
  Parameters:
214
209
  pretrained_model_name_or_path (`str` or `os.PathLike`):
215
210
  Can be either:
216
211
 
217
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
218
- Valid model ids are namespaced under a user or organization name, like
219
- `runwayml/stable-diffusion-v1-5`.
220
- - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
221
- e.g., `./my_model_directory/`.
212
+ - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
213
+ hosted on the Hub.
214
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
215
+ using [`~FlaxModelMixin.save_pretrained`].
222
216
  dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
223
217
  The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
224
218
  `jax.numpy.bfloat16` (on TPUs).
225
219
 
226
220
  This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
227
- specified all the computation will be performed with the given `dtype`.
221
+ specified, all the computation will be performed with the given `dtype`.
222
+
223
+ <Tip>
224
+
225
+ This only specifies the dtype of the *computation* and does not influence the dtype of model
226
+ parameters.
228
227
 
229
- **Note that this only specifies the dtype of the computation and does not influence the dtype of model
230
- parameters.**
228
+ If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
229
+ [`~FlaxModelMixin.to_bf16`].
230
+
231
+ </Tip>
231
232
 
232
- If you wish to change the dtype of the model parameters, see [`~ModelMixin.to_fp16`] and
233
- [`~ModelMixin.to_bf16`].
234
233
  model_args (sequence of positional arguments, *optional*):
235
- All remaining positional arguments will be passed to the underlying model's `__init__` method.
234
+ All remaining positional arguments are passed to the underlying model's `__init__` method.
236
235
  cache_dir (`Union[str, os.PathLike]`, *optional*):
237
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
238
- standard cache should not be used.
236
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
237
+ is not used.
239
238
  force_download (`bool`, *optional*, defaults to `False`):
240
239
  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
241
240
  cached versions if they exist.
242
241
  resume_download (`bool`, *optional*, defaults to `False`):
243
- Whether or not to delete incompletely received files. Will attempt to resume the download if such a
244
- file exists.
242
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
243
+ incompletely downloaded files are deleted.
245
244
  proxies (`Dict[str, str]`, *optional*):
246
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
245
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
247
246
  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
248
247
  local_files_only(`bool`, *optional*, defaults to `False`):
249
- Whether or not to only look at local files (i.e., do not try to download the model).
248
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
249
+ won't be downloaded from the Hub.
250
250
  revision (`str`, *optional*, defaults to `"main"`):
251
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
252
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
253
- identifier allowed by git.
251
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
252
+ allowed by Git.
254
253
  from_pt (`bool`, *optional*, defaults to `False`):
255
254
  Load the model weights from a PyTorch checkpoint save file.
256
255
  kwargs (remaining dictionary of keyword arguments, *optional*):
257
- Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
258
- `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
256
+ Can be used to update the configuration object (after it is loaded) and initiate the model (for
257
+ example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
259
258
  automatically loaded:
260
259
 
261
- - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
262
- underlying model's `__init__` method (we assume all relevant updates to the configuration have
263
- already been done)
264
- - If a configuration is not provided, `kwargs` will be first passed to the configuration class
265
- initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that corresponds to
266
- a configuration attribute will be used to override said attribute with the supplied `kwargs`
267
- value. Remaining keys that do not correspond to any configuration attribute will be passed to the
268
- underlying model's `__init__` function.
260
+ - If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
261
+ model's `__init__` method (we assume all relevant updates to the configuration have already been
262
+ done).
263
+ - If a configuration is not provided, `kwargs` are first passed to the configuration class
264
+ initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
265
+ to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
266
+ Remaining keys that do not correspond to any configuration attribute are passed to the underlying
267
+ model's `__init__` function.
269
268
 
270
269
  Examples:
271
270
 
@@ -276,7 +275,16 @@ class FlaxModelMixin:
276
275
  >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
277
276
  >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
278
277
  >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
279
- ```"""
278
+ ```
279
+
280
+ If you get the error message below, you need to finetune the weights for your downstream task:
281
+
282
+ ```bash
283
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
284
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
285
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
286
+ ```
287
+ """
280
288
  config = kwargs.pop("config", None)
281
289
  cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
282
290
  force_download = kwargs.pop("force_download", False)
@@ -491,18 +499,18 @@ class FlaxModelMixin:
491
499
  is_main_process: bool = True,
492
500
  ):
493
501
  """
494
- Save a model and its configuration file to a directory, so that it can be re-loaded using the
495
- `[`~FlaxModelMixin.from_pretrained`]` class method
502
+ Save a model and its configuration file to a directory so that it can be reloaded using the
503
+ [`~FlaxModelMixin.from_pretrained`] class method.
496
504
 
497
505
  Arguments:
498
506
  save_directory (`str` or `os.PathLike`):
499
- Directory to which to save. Will be created if it doesn't exist.
507
+ Directory to save a model and its configuration file to. Will be created if it doesn't exist.
500
508
  params (`Union[Dict, FrozenDict]`):
501
509
  A `PyTree` of model parameters.
502
510
  is_main_process (`bool`, *optional*, defaults to `True`):
503
- Whether the process calling this is the main process or not. Useful when in distributed training like
504
- TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
505
- the main process to avoid race conditions.
511
+ Whether the process calling this is the main process or not. Useful during distributed training when you
512
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
513
+ process to avoid race conditions.
506
514
  """
507
515
  if os.path.isfile(save_directory):
508
516
  logger.error(f"Provided path ({save_directory}) should be a directory, not a file")