diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
diffusers/quantizers/__init__.py
@@ -12,183 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- import inspect
- from typing import Dict, List, Optional, Union

- from ..utils import is_transformers_available, logging
  from .auto import DiffusersAutoQuantizer
  from .base import DiffusersQuantizer
- from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
-
-
- try:
-     from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
- except ImportError:
-
-     class TransformersQuantConfigMixin:
-         pass
-
-
- logger = logging.get_logger(__name__)
-
-
- class PipelineQuantizationConfig:
-     """
-     Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
-
-     Args:
-         quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
-             is available to both `diffusers` and `transformers`.
-         quant_kwargs (`dict`): Params to initialize the quantization backend class.
-         components_to_quantize (`list`): Components of a pipeline to be quantized.
-         quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
-             components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`,
-             and `components_to_quantize`.
-     """
-
-     def __init__(
-         self,
-         quant_backend: str = None,
-         quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
-         components_to_quantize: Optional[List[str]] = None,
-         quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
-     ):
-         self.quant_backend = quant_backend
-         # Initialize kwargs to be {} to set to the defaults.
-         self.quant_kwargs = quant_kwargs or {}
-         self.components_to_quantize = components_to_quantize
-         self.quant_mapping = quant_mapping
-
-         self.post_init()
-
-     def post_init(self):
-         quant_mapping = self.quant_mapping
-         self.is_granular = True if quant_mapping is not None else False
-
-         self._validate_init_args()
-
-     def _validate_init_args(self):
-         if self.quant_backend and self.quant_mapping:
-             raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
-
-         if not self.quant_mapping and not self.quant_backend:
-             raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
-
-         if not self.quant_kwargs and not self.quant_mapping:
-             raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
-
-         if self.quant_backend is not None:
-             self._validate_init_kwargs_in_backends()
-
-         if self.quant_mapping is not None:
-             self._validate_quant_mapping_args()
-
-     def _validate_init_kwargs_in_backends(self):
-         quant_backend = self.quant_backend
-
-         self._check_backend_availability(quant_backend)
-
-         quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
-
-         if quant_config_mapping_transformers is not None:
-             init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
-             init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
-         else:
-             init_kwargs_transformers = None
-
-         init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
-         init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
-
-         if init_kwargs_transformers != init_kwargs_diffusers:
-             raise ValueError(
-                 "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
-                 f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
-                 "this mapping would look like."
-             )
-
-     def _validate_quant_mapping_args(self):
-         quant_mapping = self.quant_mapping
-         transformers_map, diffusers_map = self._get_quant_config_list()
-
-         available_transformers = list(transformers_map.values()) if transformers_map else None
-         available_diffusers = list(diffusers_map.values())
-
-         for module_name, config in quant_mapping.items():
-             if any(isinstance(config, cfg) for cfg in available_diffusers):
-                 continue
-
-             if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
-                 continue
-
-             if available_transformers:
-                 raise ValueError(
-                     f"Provided config for module_name={module_name} could not be found. "
-                     f"Available diffusers configs: {available_diffusers}; "
-                     f"Available transformers configs: {available_transformers}."
-                 )
-             else:
-                 raise ValueError(
-                     f"Provided config for module_name={module_name} could not be found. "
-                     f"Available diffusers configs: {available_diffusers}."
-                 )
-
-     def _check_backend_availability(self, quant_backend: str):
-         quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
-
-         available_backends_transformers = (
-             list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
-         )
-         available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
-
-         if (
-             available_backends_transformers and quant_backend not in available_backends_transformers
-         ) or quant_backend not in quant_config_mapping_diffusers:
-             error_message = f"Provided quant_backend={quant_backend} was not found."
-             if available_backends_transformers:
-                 error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
-             error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
-             raise ValueError(error_message)
-
-     def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
-         quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
-
-         quant_mapping = self.quant_mapping
-         components_to_quantize = self.components_to_quantize
-
-         # Granular case
-         if self.is_granular and module_name in quant_mapping:
-             logger.debug(f"Initializing quantization config class for {module_name}.")
-             config = quant_mapping[module_name]
-             return config
-
-         # Global config case
-         else:
-             should_quantize = False
-             # Only quantize the modules requested for.
-             if components_to_quantize and module_name in components_to_quantize:
-                 should_quantize = True
-             # No specification for `components_to_quantize` means all modules should be quantized.
-             elif not self.is_granular and not components_to_quantize:
-                 should_quantize = True
-
-             if should_quantize:
-                 logger.debug(f"Initializing quantization config class for {module_name}.")
-                 mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
-                 quant_config_cls = mapping_to_use[self.quant_backend]
-                 quant_kwargs = self.quant_kwargs
-                 return quant_config_cls(**quant_kwargs)
-
-         # Fallback: no applicable configuration found.
-         return None
-
-     def _get_quant_config_list(self):
-         if is_transformers_available():
-             from transformers.quantizers.auto import (
-                 AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
-             )
-         else:
-             quant_config_mapping_transformers = None
-
-         from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
-
-         return quant_config_mapping_transformers, quant_config_mapping_diffusers
+ from .pipe_quant_config import PipelineQuantizationConfig
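Note: the class body moves to the new `pipe_quant_config.py` module (its full diff appears further below); the re-export above keeps the existing import path intact. A quick sanity check, under the assumption that the top-level re-export in `diffusers/__init__.py` is unchanged:

    from diffusers import PipelineQuantizationConfig as top_level
    from diffusers.quantizers import PipelineQuantizationConfig as from_quantizers

    # Both names should resolve to the same class after the move.
    assert top_level is from_quantizers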
diffusers/quantizers/base.py
@@ -209,6 +209,17 @@ class DiffusersQuantizer(ABC):

          return model

+     def get_cuda_warm_up_factor(self):
+         """
+         The factor to be used in `caching_allocator_warmup` to get the number of bytes to pre-allocate to warm up cuda.
+         A factor of 2 means we allocate all bytes in the empty model (since we allocate in fp16), a factor of 4 means
+         we allocate half the memory of the weights residing in the empty model, etc...
+         """
+         # By default we return 4, i.e. half the model size (this corresponds to the case where the model is not
+         # really pre-processed, i.e. we do not have the info that weights are going to be 8 bits before actual
+         # weight loading)
+         return 4
+
      def _dequantize(self, model):
          raise NotImplementedError(
              f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
diffusers/quantizers/gguf/utils.py
@@ -12,15 +12,15 @@
  # # See the License for the specific language governing permissions and
  # # limitations under the License.

-
  import inspect
+ import os
  from contextlib import nullcontext

  import gguf
  import torch
  import torch.nn as nn

- from ...utils import is_accelerate_available
+ from ...utils import is_accelerate_available, is_kernels_available


  if is_accelerate_available():
@@ -29,6 +29,82 @@ if is_accelerate_available():
      from accelerate.hooks import add_hook_to_module, remove_hook_from_module


+ can_use_cuda_kernels = (
+     os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]
+     and torch.cuda.is_available()
+     and torch.cuda.get_device_capability()[0] >= 7
+ )
+ if can_use_cuda_kernels and is_kernels_available():
+     from kernels import get_kernel
+
+     ops = get_kernel("Isotr0py/ggml")
+ else:
+     ops = None
+
+ UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
+ STANDARD_QUANT_TYPES = {
+     gguf.GGMLQuantizationType.Q4_0,
+     gguf.GGMLQuantizationType.Q4_1,
+     gguf.GGMLQuantizationType.Q5_0,
+     gguf.GGMLQuantizationType.Q5_1,
+     gguf.GGMLQuantizationType.Q8_0,
+     gguf.GGMLQuantizationType.Q8_1,
+ }
+ KQUANT_TYPES = {
+     gguf.GGMLQuantizationType.Q2_K,
+     gguf.GGMLQuantizationType.Q3_K,
+     gguf.GGMLQuantizationType.Q4_K,
+     gguf.GGMLQuantizationType.Q5_K,
+     gguf.GGMLQuantizationType.Q6_K,
+ }
+ IMATRIX_QUANT_TYPES = {
+     gguf.GGMLQuantizationType.IQ1_M,
+     gguf.GGMLQuantizationType.IQ1_S,
+     gguf.GGMLQuantizationType.IQ2_XXS,
+     gguf.GGMLQuantizationType.IQ2_XS,
+     gguf.GGMLQuantizationType.IQ2_S,
+     gguf.GGMLQuantizationType.IQ3_XXS,
+     gguf.GGMLQuantizationType.IQ3_S,
+     gguf.GGMLQuantizationType.IQ4_XS,
+     gguf.GGMLQuantizationType.IQ4_NL,
+ }
+ # TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
+ # Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+ # MMQ kernel for I-Matrix quantization.
+ DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+ MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+ MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+
+
+ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
+     # there is no need to call any kernel for fp16/bf16
+     if qweight_type in UNQUANTIZED_TYPES:
+         return x @ qweight.T
+
+     # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
+     # contiguous batching and inefficient with diffusers' batching,
+     # so we disabled it now.
+
+     # elif qweight_type in MMVQ_QUANT_TYPES:
+     #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+     # elif qweight_type in MMQ_QUANT_TYPES:
+     #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+
+     # If there is no available MMQ kernel, fallback to dequantize
+     if qweight_type in DEQUANT_TYPES:
+         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
+         weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
+         y = x @ weight.to(x.dtype).T
+     else:
+         # Raise an error if the quantization type is not supported.
+         # Might be useful if llama.cpp adds a new quantization type.
+         # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
+         qweight_type = gguf.GGMLQuantizationType(qweight_type)
+         raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
+     return y.as_tensor()
+
+
  # Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
  def _create_accelerate_new_hook(old_hook):
      r"""
@@ -451,11 +527,24 @@ class GGUFLinear(nn.Linear):
      ) -> None:
          super().__init__(in_features, out_features, bias, device)
          self.compute_dtype = compute_dtype
+         self.device = device
+
+     def forward(self, inputs: torch.Tensor):
+         if ops is not None and self.weight.is_cuda and inputs.is_cuda:
+             return self.forward_cuda(inputs)
+         return self.forward_native(inputs)

-     def forward(self, inputs):
+     def forward_native(self, inputs: torch.Tensor):
          weight = dequantize_gguf_tensor(self.weight)
          weight = weight.to(self.compute_dtype)
          bias = self.bias.to(self.compute_dtype) if self.bias is not None else None

          output = torch.nn.functional.linear(inputs, weight, bias)
          return output
+
+     def forward_cuda(self, inputs: torch.Tensor):
+         quant_type = self.weight.quant_type
+         output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+         if self.bias is not None:
+             output += self.bias.to(self.compute_dtype)
+         return output
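The fused GGUF path is opt-in: it only activates when `DIFFUSERS_GGUF_CUDA_KERNELS` is set, the `kernels` package is installed, and the GPU has compute capability 7.0 or higher; otherwise `forward_native` dequantizes as before. A hedged usage sketch (the checkpoint path is a placeholder; any GGUF single-file checkpoint diffusers supports should behave the same way):

    import os

    os.environ["DIFFUSERS_GGUF_CUDA_KERNELS"] = "1"  # must be set before diffusers' GGUF utils are imported

    import torch
    from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

    ckpt_path = "path/to/model-Q4_K.gguf"  # placeholder checkpoint
    transformer = FluxTransformer2DModel.from_single_file(
        ckpt_path,
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    ).to("cuda")  # GGUFLinear.forward dispatches to forward_cuda when weights and inputs are on CUDA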
diffusers/quantizers/pipe_quant_config.py (new file)
@@ -0,0 +1,202 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from typing import Dict, List, Optional, Union
+
+ from ..utils import is_transformers_available, logging
+ from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
+
+
+ try:
+     from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
+ except ImportError:
+
+     class TransformersQuantConfigMixin:
+         pass
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class PipelineQuantizationConfig:
+     """
+     Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
+
+     Args:
+         quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
+             is available to both `diffusers` and `transformers`.
+         quant_kwargs (`dict`): Params to initialize the quantization backend class.
+         components_to_quantize (`list`): Components of a pipeline to be quantized.
+         quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
+             components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`,
+             and `components_to_quantize`.
+     """
+
+     def __init__(
+         self,
+         quant_backend: str = None,
+         quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
+         components_to_quantize: Optional[List[str]] = None,
+         quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
+     ):
+         self.quant_backend = quant_backend
+         # Initialize kwargs to be {} to set to the defaults.
+         self.quant_kwargs = quant_kwargs or {}
+         self.components_to_quantize = components_to_quantize
+         self.quant_mapping = quant_mapping
+         self.config_mapping = {}  # book-keeping Example: `{module_name: quant_config}`
+         self.post_init()
+
+     def post_init(self):
+         quant_mapping = self.quant_mapping
+         self.is_granular = True if quant_mapping is not None else False
+
+         self._validate_init_args()
+
+     def _validate_init_args(self):
+         if self.quant_backend and self.quant_mapping:
+             raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
+
+         if not self.quant_mapping and not self.quant_backend:
+             raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
+
+         if not self.quant_kwargs and not self.quant_mapping:
+             raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
+
+         if self.quant_backend is not None:
+             self._validate_init_kwargs_in_backends()
+
+         if self.quant_mapping is not None:
+             self._validate_quant_mapping_args()
+
+     def _validate_init_kwargs_in_backends(self):
+         quant_backend = self.quant_backend
+
+         self._check_backend_availability(quant_backend)
+
+         quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+         if quant_config_mapping_transformers is not None:
+             init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
+             init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
+         else:
+             init_kwargs_transformers = None
+
+         init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
+         init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
+
+         if init_kwargs_transformers != init_kwargs_diffusers:
+             raise ValueError(
+                 "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
+                 f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
+                 "this mapping would look like."
+             )
+
+     def _validate_quant_mapping_args(self):
+         quant_mapping = self.quant_mapping
+         transformers_map, diffusers_map = self._get_quant_config_list()
+
+         available_transformers = list(transformers_map.values()) if transformers_map else None
+         available_diffusers = list(diffusers_map.values())
+
+         for module_name, config in quant_mapping.items():
+             if any(isinstance(config, cfg) for cfg in available_diffusers):
+                 continue
+
+             if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
+                 continue
+
+             if available_transformers:
+                 raise ValueError(
+                     f"Provided config for module_name={module_name} could not be found. "
+                     f"Available diffusers configs: {available_diffusers}; "
+                     f"Available transformers configs: {available_transformers}."
+                 )
+             else:
+                 raise ValueError(
+                     f"Provided config for module_name={module_name} could not be found. "
+                     f"Available diffusers configs: {available_diffusers}."
+                 )
+
+     def _check_backend_availability(self, quant_backend: str):
+         quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+         available_backends_transformers = (
+             list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
+         )
+         available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
+
+         if (
+             available_backends_transformers and quant_backend not in available_backends_transformers
+         ) or quant_backend not in quant_config_mapping_diffusers:
+             error_message = f"Provided quant_backend={quant_backend} was not found."
+             if available_backends_transformers:
+                 error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
+             error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
+             raise ValueError(error_message)
+
+     def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
+         quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+         quant_mapping = self.quant_mapping
+         components_to_quantize = self.components_to_quantize
+
+         # Granular case
+         if self.is_granular and module_name in quant_mapping:
+             logger.debug(f"Initializing quantization config class for {module_name}.")
+             config = quant_mapping[module_name]
+             self.config_mapping.update({module_name: config})
+             return config
+
+         # Global config case
+         else:
+             should_quantize = False
+             # Only quantize the modules requested for.
+             if components_to_quantize and module_name in components_to_quantize:
+                 should_quantize = True
+             # No specification for `components_to_quantize` means all modules should be quantized.
+             elif not self.is_granular and not components_to_quantize:
+                 should_quantize = True
+
+             if should_quantize:
+                 logger.debug(f"Initializing quantization config class for {module_name}.")
+                 mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
+                 quant_config_cls = mapping_to_use[self.quant_backend]
+                 quant_kwargs = self.quant_kwargs
+                 quant_obj = quant_config_cls(**quant_kwargs)
+                 self.config_mapping.update({module_name: quant_obj})
+                 return quant_obj
+
+         # Fallback: no applicable configuration found.
+         return None
+
+     def _get_quant_config_list(self):
+         if is_transformers_available():
+             from transformers.quantizers.auto import (
+                 AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
+             )
+         else:
+             quant_config_mapping_transformers = None
+
+         from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
+
+         return quant_config_mapping_transformers, quant_config_mapping_diffusers
+
+     def __repr__(self):
+         out = ""
+         config_mapping = dict(sorted(self.config_mapping.copy().items()))
+         for module_name, config in config_mapping.items():
+             out += f"{module_name} {config}"
+         return out
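For reference, a hedged usage sketch of the class defined in this new file, following the pipeline-level quantization documentation; the model id and component names are illustrative:

    import torch
    from diffusers import DiffusionPipeline
    from diffusers.quantizers import PipelineQuantizationConfig

    pipeline_quant_config = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
        components_to_quantize=["transformer", "text_encoder_2"],
    )
    pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # illustrative model id
        quantization_config=pipeline_quant_config,
        torch_dtype=torch.bfloat16,
    )
    print(pipeline_quant_config)  # the new __repr__ lists the per-module configs recorded in config_mapping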
diffusers/quantizers/torchao/torchao_quantizer.py
@@ -19,6 +19,7 @@ https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac17

  import importlib
  import types
+ from fnmatch import fnmatch
  from typing import TYPE_CHECKING, Any, Dict, List, Union

  from packaging import version
@@ -278,6 +279,31 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
          module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
          quantize_(module, self.quantization_config.get_apply_tensor_subclass())

+     def get_cuda_warm_up_factor(self):
+         """
+         This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
+         - A factor of 2 means we pre-allocate the full memory footprint of the model.
+         - A factor of 4 means we pre-allocate half of that, and so on
+
+         However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
+         the correct size for quantized weights (like int4 or int8) That's because TorchAO internally represents
+         quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
+         torch_dtype not the actual bit-width of the quantized data.
+
+         To correct for this:
+         - Use a division factor of 8 for int4 weights
+         - Use a division factor of 4 for int8 weights
+         """
+         # Original mapping for non-AOBaseConfig types
+         # For the uint types, this is a best guess. Once these types become more used
+         # we can look into their nuances.
+         map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
+         quant_type = self.quantization_config.quant_type
+         for pattern, target_dtype in map_to_target_dtype.items():
+             if fnmatch(quant_type, pattern):
+                 return target_dtype
+         raise ValueError(f"Unsupported quant_type: {quant_type!r}")
+
      def _process_model_before_weight_loading(
          self,
          model: "ModelMixin",
diffusers/schedulers/scheduling_deis_multistep.py
@@ -153,6 +153,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
          flow_shift: Optional[float] = 1.0,
          timestep_spacing: str = "linspace",
          steps_offset: int = 0,
+         use_dynamic_shifting: bool = False,
+         time_shift_type: str = "exponential",
      ):
          if self.config.use_beta_sigmas and not is_scipy_available():
              raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -232,7 +234,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
          """
          self._begin_index = begin_index

-     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+     def set_timesteps(
+         self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
+     ):
          """
          Sets the discrete timesteps used for the diffusion chain (to be run before inference).

@@ -242,6 +246,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
              device (`str` or `torch.device`, *optional*):
                  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
          """
+         if mu is not None:
+             assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+             self.config.flow_shift = np.exp(mu)
          # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
          if self.config.timestep_spacing == "linspace":
              timesteps = (
diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -230,6 +230,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
          timestep_spacing: str = "linspace",
          steps_offset: int = 0,
          rescale_betas_zero_snr: bool = False,
+         use_dynamic_shifting: bool = False,
+         time_shift_type: str = "exponential",
      ):
          if self.config.use_beta_sigmas and not is_scipy_available():
              raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -330,6 +332,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
          self,
          num_inference_steps: int = None,
          device: Union[str, torch.device] = None,
+         mu: Optional[float] = None,
          timesteps: Optional[List[int]] = None,
      ):
          """
@@ -345,6 +348,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                  based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
                  must be `None`, and `timestep_spacing` attribute will be ignored.
          """
+         if mu is not None:
+             assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+             self.config.flow_shift = np.exp(mu)
          if num_inference_steps is None and timesteps is None:
              raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
          if num_inference_steps is not None and timesteps is not None:
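The DEIS and DPM-Solver schedulers now accept a resolution-dependent `mu` in `set_timesteps`; when `use_dynamic_shifting` is enabled with the exponential shift type, the value is converted into `flow_shift = exp(mu)`. A hedged sketch of the call pattern (pipelines normally compute `mu` from the latent sequence length rather than hard-coding it):

    from diffusers import DPMSolverMultistepScheduler

    scheduler = DPMSolverMultistepScheduler(
        use_flow_sigmas=True,           # pre-existing flow-matching sigma option
        use_dynamic_shifting=True,      # new flag enabling the `mu` path
        time_shift_type="exponential",  # required by the assert guarding that path
    )
    mu = 0.95  # illustrative value
    scheduler.set_timesteps(num_inference_steps=28, mu=mu)  # internally sets config.flow_shift = exp(mu)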
diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -169,6 +169,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
          final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
          lambda_min_clipped: float = -float("inf"),
          variance_type: Optional[str] = None,
+         use_dynamic_shifting: bool = False,
+         time_shift_type: str = "exponential",
      ):
          if self.config.use_beta_sigmas and not is_scipy_available():
              raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -301,6 +303,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
          self,
          num_inference_steps: int = None,
          device: Union[str, torch.device] = None,
+         mu: Optional[float] = None,
          timesteps: Optional[List[int]] = None,
      ):
          """
@@ -316,6 +319,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                  timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
                  passed, `num_inference_steps` must be `None`.
          """
+         if mu is not None:
+             assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+             self.config.flow_shift = np.exp(mu)
          if num_inference_steps is None and timesteps is None:
              raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
          if num_inference_steps is not None and timesteps is not None:
diffusers/schedulers/scheduling_scm.py
@@ -168,7 +168,6 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
          else:
              # max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
              self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
-             print(f"Set timesteps: {self.timesteps}")

          self._step_index = None
          self._begin_index = None