diffusers 0.19.3__py3-none-any.whl → 0.20.1__py3-none-any.whl

Files changed (114)
  1. diffusers/__init__.py +3 -1
  2. diffusers/commands/fp16_safetensors.py +2 -7
  3. diffusers/configuration_utils.py +23 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/loaders.py +62 -64
  6. diffusers/models/__init__.py +1 -0
  7. diffusers/models/activations.py +2 -0
  8. diffusers/models/attention.py +45 -1
  9. diffusers/models/autoencoder_tiny.py +193 -0
  10. diffusers/models/controlnet.py +1 -1
  11. diffusers/models/embeddings.py +56 -0
  12. diffusers/models/lora.py +0 -6
  13. diffusers/models/modeling_flax_utils.py +28 -2
  14. diffusers/models/modeling_utils.py +33 -16
  15. diffusers/models/transformer_2d.py +26 -9
  16. diffusers/models/unet_1d.py +2 -2
  17. diffusers/models/unet_2d_blocks.py +106 -56
  18. diffusers/models/unet_2d_condition.py +20 -5
  19. diffusers/models/vae.py +106 -1
  20. diffusers/pipelines/__init__.py +1 -0
  21. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +10 -3
  22. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -3
  23. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  24. diffusers/pipelines/auto_pipeline.py +33 -43
  25. diffusers/pipelines/controlnet/multicontrolnet.py +4 -2
  26. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -4
  27. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +15 -7
  28. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +14 -4
  29. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +157 -10
  30. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -10
  31. diffusers/pipelines/deepfloyd_if/pipeline_if.py +1 -1
  32. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +1 -1
  33. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1 -1
  34. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1 -1
  35. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1 -1
  36. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +1 -1
  37. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +43 -2
  38. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +44 -2
  39. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  40. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  41. diffusers/pipelines/pipeline_flax_utils.py +41 -4
  42. diffusers/pipelines/pipeline_utils.py +60 -16
  43. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +2 -2
  44. diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  45. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +81 -37
  46. diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +10 -3
  47. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -3
  48. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -3
  49. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +10 -3
  50. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +12 -5
  51. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +832 -0
  52. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -3
  53. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +10 -3
  54. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +10 -3
  55. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +9 -2
  56. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +17 -8
  57. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +10 -3
  58. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +10 -3
  59. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +10 -3
  60. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +10 -3
  61. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +10 -3
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +10 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +10 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +10 -3
  65. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +3 -5
  66. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +75 -3
  67. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +76 -6
  68. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +1 -2
  69. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +10 -3
  70. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +10 -3
  71. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +11 -4
  72. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +1 -1
  73. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +131 -28
  74. diffusers/schedulers/scheduling_consistency_models.py +70 -57
  75. diffusers/schedulers/scheduling_ddim.py +76 -71
  76. diffusers/schedulers/scheduling_ddim_inverse.py +76 -44
  77. diffusers/schedulers/scheduling_ddim_parallel.py +11 -8
  78. diffusers/schedulers/scheduling_ddpm.py +68 -67
  79. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -15
  80. diffusers/schedulers/scheduling_deis_multistep.py +93 -85
  81. diffusers/schedulers/scheduling_dpmsolver_multistep.py +118 -120
  82. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +116 -109
  83. diffusers/schedulers/scheduling_dpmsolver_sde.py +57 -43
  84. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +122 -121
  85. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +54 -44
  86. diffusers/schedulers/scheduling_euler_discrete.py +63 -56
  87. diffusers/schedulers/scheduling_heun_discrete.py +57 -45
  88. diffusers/schedulers/scheduling_ipndm.py +27 -22
  89. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +54 -41
  90. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +52 -41
  91. diffusers/schedulers/scheduling_karras_ve.py +55 -45
  92. diffusers/schedulers/scheduling_lms_discrete.py +58 -52
  93. diffusers/schedulers/scheduling_pndm.py +77 -62
  94. diffusers/schedulers/scheduling_repaint.py +56 -38
  95. diffusers/schedulers/scheduling_sde_ve.py +62 -50
  96. diffusers/schedulers/scheduling_sde_vp.py +32 -11
  97. diffusers/schedulers/scheduling_unclip.py +3 -3
  98. diffusers/schedulers/scheduling_unipc_multistep.py +131 -91
  99. diffusers/schedulers/scheduling_utils.py +41 -35
  100. diffusers/schedulers/scheduling_utils_flax.py +8 -2
  101. diffusers/schedulers/scheduling_vq_diffusion.py +39 -68
  102. diffusers/utils/__init__.py +2 -2
  103. diffusers/utils/dummy_pt_objects.py +15 -0
  104. diffusers/utils/dummy_torch_and_transformers_objects.py +15 -0
  105. diffusers/utils/hub_utils.py +105 -2
  106. diffusers/utils/import_utils.py +0 -4
  107. diffusers/utils/pil_utils.py +19 -0
  108. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/METADATA +5 -7
  109. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/RECORD +113 -112
  110. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/WHEEL +1 -1
  111. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/entry_points.txt +0 -1
  112. diffusers/models/cross_attention.py +0 -94
  113. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/LICENSE +0 -0
  114. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/top_level.txt +0 -0
diffusers/models/autoencoder_tiny.py ADDED
@@ -0,0 +1,193 @@
+ # Copyright 2023 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from dataclasses import dataclass
+ from typing import Tuple, Union
+
+ import torch
+
+ from ..configuration_utils import ConfigMixin, register_to_config
+ from ..utils import BaseOutput, apply_forward_hook
+ from .modeling_utils import ModelMixin
+ from .vae import DecoderOutput, DecoderTiny, EncoderTiny
+
+
+ @dataclass
+ class AutoencoderTinyOutput(BaseOutput):
+     """
+     Output of AutoencoderTiny encoding method.
+
+     Args:
+         latents (`torch.Tensor`): Encoded outputs of the `Encoder`.
+
+     """
+
+     latents: torch.Tensor
+
+
+ class AutoencoderTiny(ModelMixin, ConfigMixin):
+     r"""
+     A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
+
+     [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.
+
+     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
+     all models (such as downloading or saving).
+
+     Parameters:
+         in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
+         out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
+         encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
+             Tuple of integers representing the number of output channels for each encoder block. The length of the
+             tuple should be equal to the number of encoder blocks.
+         decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
+             Tuple of integers representing the number of output channels for each decoder block. The length of the
+             tuple should be equal to the number of decoder blocks.
+         act_fn (`str`, *optional*, defaults to `"relu"`):
+             Activation function to be used throughout the model.
+         latent_channels (`int`, *optional*, defaults to 4):
+             Number of channels in the latent representation. The latent space acts as a compressed representation of
+             the input image.
+         upsampling_scaling_factor (`int`, *optional*, defaults to 2):
+             Scaling factor for upsampling in the decoder. It determines the size of the output image during the
+             upsampling process.
+         num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
+             Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
+             length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
+             number of encoder blocks.
+         num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
+             Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
+             length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
+             number of decoder blocks.
+         latent_magnitude (`int`, *optional*, defaults to 3):
+             Magnitude of the latent representation. This parameter scales the latent representation values to control
+             the extent of information preservation.
+         latent_shift (`float`, *optional*, defaults to 0.5):
+             Shift applied to the latent representation. This parameter controls the center of the latent space.
+         scaling_factor (`float`, *optional*, defaults to 1.0):
+             The component-wise standard deviation of the trained latent space computed using the first batch of the
+             training set. This is used to scale the latent space to have unit variance when training the diffusion
+             model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+             diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+             / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+             Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
+             however, no such scaling factor was used, hence the value of 1.0 as the default.
+         force_upcast (`bool`, *optional*, defaults to `False`):
+             If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
+             can be fine-tuned / trained to a lower range without losing too much precision, in which case
+             `force_upcast` can be set to `False` (see this fp16-friendly
+             [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
+     """
+     _supports_gradient_checkpointing = True
+
+     @register_to_config
+     def __init__(
+         self,
+         in_channels=3,
+         out_channels=3,
+         encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
+         decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
+         act_fn: str = "relu",
+         latent_channels: int = 4,
+         upsampling_scaling_factor: int = 2,
+         num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
+         num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
+         latent_magnitude: int = 3,
+         latent_shift: float = 0.5,
+         force_upcast: bool = False,
+         scaling_factor: float = 1.0,
+     ):
+         super().__init__()
+
+         if len(encoder_block_out_channels) != len(num_encoder_blocks):
+             raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
+         if len(decoder_block_out_channels) != len(num_decoder_blocks):
+             raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")
+
+         self.encoder = EncoderTiny(
+             in_channels=in_channels,
+             out_channels=latent_channels,
+             num_blocks=num_encoder_blocks,
+             block_out_channels=encoder_block_out_channels,
+             act_fn=act_fn,
+         )
+
+         self.decoder = DecoderTiny(
+             in_channels=latent_channels,
+             out_channels=out_channels,
+             num_blocks=num_decoder_blocks,
+             block_out_channels=decoder_block_out_channels,
+             upsampling_scaling_factor=upsampling_scaling_factor,
+             act_fn=act_fn,
+         )
+
+         self.latent_magnitude = latent_magnitude
+         self.latent_shift = latent_shift
+         self.scaling_factor = scaling_factor
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if isinstance(module, (EncoderTiny, DecoderTiny)):
+             module.gradient_checkpointing = value
+
+     def scale_latents(self, x):
+         """raw latents -> [0, 1]"""
+         return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
+
+     def unscale_latents(self, x):
+         """[0, 1] -> raw latents"""
+         return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
+
+     @apply_forward_hook
+     def encode(
+         self, x: torch.FloatTensor, return_dict: bool = True
+     ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
+         output = self.encoder(x)
+
+         if not return_dict:
+             return (output,)
+
+         return AutoencoderTinyOutput(latents=output)
+
+     @apply_forward_hook
+     def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+         output = self.decoder(x)
+         # Refer to the following discussion to know why this is needed.
+         # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
+         output = output.mul_(2).sub_(1)
+
+         if not return_dict:
+             return (output,)
+
+         return DecoderOutput(sample=output)
+
+     def forward(
+         self,
+         sample: torch.FloatTensor,
+         return_dict: bool = True,
+     ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+         r"""
+         Args:
+             sample (`torch.FloatTensor`): Input sample.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+         """
+         enc = self.encode(sample).latents
+         scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
+         unscaled_enc = self.unscale_latents(scaled_enc)
+         dec = self.decode(unscaled_enc)
+
+         if not return_dict:
+             return (dec,)
+         return DecoderOutput(sample=dec)
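
A quick note on how the new model is typically used (not part of the diff): `AutoencoderTiny` can be dropped into an existing Stable Diffusion pipeline as a faster VAE. The sketch below assumes the `madebyollin/taesd` checkpoint and a CUDA device, both of which are illustrative.

import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline

# Load a standard SD 1.5 pipeline, then swap in the tiny VAE (example repo ids).
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("astronaut.png")
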
diffusers/models/controlnet.py CHANGED
@@ -723,7 +723,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
              class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
              emb = emb + class_emb
 
-         if "addition_embed_type" in self.config:
+         if self.config.addition_embed_type is not None:
              if self.config.addition_embed_type == "text":
                  aug_emb = self.add_embedding(encoder_hidden_states)
 
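
The one-line change above exists because a key registered via `register_to_config` is always present in `config`, even when its value is `None`, so the old membership test never actually filtered anything. A small illustration (the checkpoint id is just an example):

from diffusers import ControlNetModel

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")  # example checkpoint
print("addition_embed_type" in controlnet.config)      # True: the key is registered with its default
print(controlnet.config.addition_embed_type is None)   # True here, so the new check skips the branch
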
diffusers/models/embeddings.py CHANGED
@@ -544,3 +544,59 @@ class AttentionPooling(nn.Module):
          a = a.reshape(bs, -1, 1).transpose(1, 2)
 
          return a[:, 0, :]  # cls_token
+
+
+ class FourierEmbedder(nn.Module):
+     def __init__(self, num_freqs=64, temperature=100):
+         super().__init__()
+
+         self.num_freqs = num_freqs
+         self.temperature = temperature
+
+         freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
+         freq_bands = freq_bands[None, None, None]
+         self.register_buffer("freq_bands", freq_bands, persistent=False)
+
+     def __call__(self, x):
+         x = self.freq_bands * x.unsqueeze(-1)
+         return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
+
+
+ class PositionNet(nn.Module):
+     def __init__(self, positive_len, out_dim, fourier_freqs=8):
+         super().__init__()
+         self.positive_len = positive_len
+         self.out_dim = out_dim
+
+         self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
+         self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy
+
+         if isinstance(out_dim, tuple):
+             out_dim = out_dim[0]
+         self.linears = nn.Sequential(
+             nn.Linear(self.positive_len + self.position_dim, 512),
+             nn.SiLU(),
+             nn.Linear(512, 512),
+             nn.SiLU(),
+             nn.Linear(512, out_dim),
+         )
+
+         self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+         self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
+
+     def forward(self, boxes, masks, positive_embeddings):
+         masks = masks.unsqueeze(-1)
+
+         # embedding position (it may include padding as a placeholder)
+         xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C
+
+         # learnable null embedding
+         positive_null = self.null_positive_feature.view(1, 1, -1)
+         xyxy_null = self.null_position_feature.view(1, 1, -1)
+
+         # replace padding with learnable null embedding
+         positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
+         xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
+
+         objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
+         return objs
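
A small sketch of how the new `FourierEmbedder` is called (shapes follow the `forward()` comments above: `boxes` is a B x N x 4 tensor of xyxy coordinates; the concrete numbers are made up):

import torch
from diffusers.models.embeddings import FourierEmbedder

embedder = FourierEmbedder(num_freqs=8)
boxes = torch.rand(2, 30, 4)   # batch of 2, 30 normalized xyxy boxes each
emb = embedder(boxes)          # sin/cos features: (2, 30, 8 * 2 * 4) == (2, 30, 64)
print(emb.shape)
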
diffusers/models/lora.py CHANGED
@@ -22,9 +22,6 @@ class LoRALinearLayer(nn.Module):
      def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
          super().__init__()
 
-         if rank > min(in_features, out_features):
-             raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
-
          self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
          self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
          # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
@@ -54,9 +51,6 @@ class LoRAConv2dLayer(nn.Module):
      ):
          super().__init__()
 
-         if rank > min(in_features, out_features):
-             raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
-
          self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
          # according to the official kohya_ss trainer kernel_size are always fixed for the up layer
          # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
diffusers/models/modeling_flax_utils.py CHANGED
@@ -23,7 +23,7 @@ import msgpack.exceptions
  from flax.core.frozen_dict import FrozenDict, unfreeze
  from flax.serialization import from_bytes, to_bytes
  from flax.traverse_util import flatten_dict, unflatten_dict
- from huggingface_hub import hf_hub_download
+ from huggingface_hub import create_repo, hf_hub_download
  from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
  from requests import HTTPError
 
@@ -34,6 +34,7 @@ from ..utils import (
      FLAX_WEIGHTS_NAME,
      HUGGINGFACE_CO_RESOLVE_ENDPOINT,
      WEIGHTS_NAME,
+     PushToHubMixin,
      logging,
  )
  from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
@@ -42,7+43,7 @@ from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
  logger = logging.get_logger(__name__)
 
 
- class FlaxModelMixin:
+ class FlaxModelMixin(PushToHubMixin):
      r"""
      Base class for all Flax models.
 
@@ -497,6 +498,8 @@ class FlaxModelMixin:
          save_directory: Union[str, os.PathLike],
          params: Union[Dict, FrozenDict],
          is_main_process: bool = True,
+         push_to_hub: bool = False,
+         **kwargs,
      ):
          """
          Save a model and its configuration file to a directory so that it can be reloaded using the
@@ -511,6 +514,12 @@ class FlaxModelMixin:
                  Whether the process calling this is the main process or not. Useful during distributed training and you
                  need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                  process to avoid race conditions.
+             push_to_hub (`bool`, *optional*, defaults to `False`):
+                 Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                 namespace).
+             kwargs (`Dict[str, Any]`, *optional*):
+                 Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
          """
          if os.path.isfile(save_directory):
              logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -518,6 +527,14 @@ class FlaxModelMixin:
 
          os.makedirs(save_directory, exist_ok=True)
 
+         if push_to_hub:
+             commit_message = kwargs.pop("commit_message", None)
+             private = kwargs.pop("private", False)
+             create_pr = kwargs.pop("create_pr", False)
+             token = kwargs.pop("token", None)
+             repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+             repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
          model_to_save = self
 
          # Attach architecture to the config
@@ -532,3 +549,12 @@ class FlaxModelMixin:
              f.write(model_bytes)
 
          logger.info(f"Model weights saved in {output_model_file}")
+
+         if push_to_hub:
+             self._upload_folder(
+                 save_directory,
+                 repo_id,
+                 token=token,
+                 commit_message=commit_message,
+                 create_pr=create_pr,
+             )
diffusers/models/modeling_utils.py CHANGED
@@ -21,7 +21,9 @@ import re
  from functools import partial
  from typing import Any, Callable, List, Optional, Tuple, Union
 
+ import safetensors
  import torch
+ from huggingface_hub import create_repo
  from torch import Tensor, device, nn
 
  from .. import __version__
@@ -36,10 +38,10 @@ from ..utils import (
      _get_model_file,
      deprecate,
      is_accelerate_available,
-     is_safetensors_available,
      is_torch_version,
      logging,
  )
+ from ..utils.hub_utils import PushToHubMixin
 
 
  logger = logging.get_logger(__name__)
@@ -56,9 +58,6 @@ if is_accelerate_available():
      from accelerate.utils import set_module_tensor_to_device
      from accelerate.utils.versions import is_torch_version
 
-     if is_safetensors_available():
-         import safetensors
-
 
  def get_parameter_device(parameter: torch.nn.Module):
      try:
@@ -150,7 +149,7 @@ def _load_state_dict_into_model(model_to_load, state_dict):
      return error_msgs
 
 
- class ModelMixin(torch.nn.Module):
+ class ModelMixin(torch.nn.Module, PushToHubMixin):
      r"""
      Base class for all models.
 
@@ -273,8 +272,10 @@ class ModelMixin(torch.nn.Module):
          save_directory: Union[str, os.PathLike],
          is_main_process: bool = True,
          save_function: Callable = None,
-         safe_serialization: bool = False,
+         safe_serialization: bool = True,
          variant: Optional[str] = None,
+         push_to_hub: bool = False,
+         **kwargs,
      ):
          """
          Save a model and its configuration file to a directory so that it can be reloaded using the
@@ -291,20 +292,32 @@ class ModelMixin(torch.nn.Module):
                  The function to use to save the state dictionary. Useful during distributed training when you need to
                  replace `torch.save` with another method. Can be configured with the environment variable
                  `DIFFUSERS_SAVE_MODE`.
-             safe_serialization (`bool`, *optional*, defaults to `False`):
+             safe_serialization (`bool`, *optional*, defaults to `True`):
                  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
              variant (`str`, *optional*):
                  If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+             push_to_hub (`bool`, *optional*, defaults to `False`):
+                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                 namespace).
+             kwargs (`Dict[str, Any]`, *optional*):
+                 Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
          """
-         if safe_serialization and not is_safetensors_available():
-             raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
-
          if os.path.isfile(save_directory):
              logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
              return
 
          os.makedirs(save_directory, exist_ok=True)
 
+         if push_to_hub:
+             commit_message = kwargs.pop("commit_message", None)
+             private = kwargs.pop("private", False)
+             create_pr = kwargs.pop("create_pr", False)
+             token = kwargs.pop("token", None)
+             repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+             repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
+         # Only save the model itself if we are using distributed training
          model_to_save = self
 
          # Attach architecture to the config
@@ -328,6 +341,15 @@ class ModelMixin(torch.nn.Module):
 
          logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
 
+         if push_to_hub:
+             self._upload_folder(
+                 save_directory,
+                 repo_id,
+                 token=token,
+                 commit_message=commit_message,
+                 create_pr=create_pr,
+             )
+
      @classmethod
      def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
          r"""
@@ -454,14 +476,9 @@ class ModelMixin(torch.nn.Module):
          variant = kwargs.pop("variant", None)
          use_safetensors = kwargs.pop("use_safetensors", None)
 
-         if use_safetensors and not is_safetensors_available():
-             raise ValueError(
-                 "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
-             )
-
          allow_pickle = False
          if use_safetensors is None:
-             use_safetensors = is_safetensors_available()
+             use_safetensors = True
              allow_pickle = True
 
          if low_cpu_mem_usage and not is_accelerate_available():
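
Taken together, the `modeling_utils.py` changes make safetensors the default serialization format and let `save_pretrained` mirror the saved folder to the Hub in one call. A minimal sketch (repo id, folder, and Hub credentials are placeholders):

from diffusers import UNet2DConditionModel

model = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Saves diffusion_pytorch_model.safetensors by default now (safe_serialization=True).
model.save_pretrained("./my-unet")

# Optionally create/update a Hub repo with the same folder in the same call.
model.save_pretrained(
    "./my-unet",
    push_to_hub=True,
    repo_id="your-username/my-unet",   # placeholder; defaults to the folder name
    private=True,
    commit_message="Initial upload",
)
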
diffusers/models/transformer_2d.py CHANGED
@@ -91,6 +91,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
          upcast_attention: bool = False,
          norm_type: str = "layer_norm",
          norm_elementwise_affine: bool = True,
+         attention_type: str = "default",
      ):
          super().__init__()
          self.use_linear_projection = use_linear_projection
@@ -183,6 +184,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                      upcast_attention=upcast_attention,
                      norm_type=norm_type,
                      norm_elementwise_affine=norm_elementwise_affine,
+                     attention_type=attention_type,
                  )
                  for d in range(num_layers)
              ]
@@ -204,6 +206,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
              self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
              self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
 
+         self.gradient_checkpointing = False
+
      def forward(
          self,
          hidden_states: torch.Tensor,
@@ -289,15 +293,28 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
          # 2. Blocks
          for block in self.transformer_blocks:
-             hidden_states = block(
-                 hidden_states,
-                 attention_mask=attention_mask,
-                 encoder_hidden_states=encoder_hidden_states,
-                 encoder_attention_mask=encoder_attention_mask,
-                 timestep=timestep,
-                 cross_attention_kwargs=cross_attention_kwargs,
-                 class_labels=class_labels,
-             )
+             if self.training and self.gradient_checkpointing:
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     block,
+                     hidden_states,
+                     attention_mask,
+                     encoder_hidden_states,
+                     encoder_attention_mask,
+                     timestep,
+                     cross_attention_kwargs,
+                     class_labels,
+                     use_reentrant=False,
+                 )
+             else:
+                 hidden_states = block(
+                     hidden_states,
+                     attention_mask=attention_mask,
+                     encoder_hidden_states=encoder_hidden_states,
+                     encoder_attention_mask=encoder_attention_mask,
+                     timestep=timestep,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                     class_labels=class_labels,
+                 )
 
          # 3. Output
          if self.is_input_continuous:
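
The `gradient_checkpointing` flag added above is what `forward()` consults at training time. A sketch of exercising it directly on a small continuous-input `Transformer2DModel` (the config values are illustrative; in practice the flag is usually toggled through the parent model's `enable_gradient_checkpointing()` helper):

import torch
from diffusers.models.transformer_2d import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=32,
    num_layers=1,
    cross_attention_dim=32,
)
model.gradient_checkpointing = True   # flag added in this release
model.train()                         # checkpointing only kicks in while training

hidden_states = torch.randn(1, 32, 8, 8, requires_grad=True)
encoder_hidden_states = torch.randn(1, 4, 32)
out = model(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
out.mean().backward()                 # block activations are recomputed during backward
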
diffusers/models/unet_1d.py CHANGED
@@ -56,9 +56,9 @@ class UNet1DModel(ModelMixin, ConfigMixin):
          freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
          flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
              Whether to flip sin to cos for Fourier time embedding.
-         down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`):
+         down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
              Tuple of downsample block types.
-         up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`):
+         up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
              Tuple of upsample block types.
          block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
              Tuple of block output channels.