diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/quantizers/__init__.py
CHANGED
@@ -12,183 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
-from typing import Dict, List, Optional, Union
 
-from ..utils import is_transformers_available, logging
 from .auto import DiffusersAutoQuantizer
 from .base import DiffusersQuantizer
-from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
-
-
-try:
-    from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
-except ImportError:
-
-    class TransformersQuantConfigMixin:
-        pass
-
-
-logger = logging.get_logger(__name__)
-
-
-class PipelineQuantizationConfig:
-    """
-    Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
-
-    Args:
-        quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
-            is available to both `diffusers` and `transformers`.
-        quant_kwargs (`dict`): Params to initialize the quantization backend class.
-        components_to_quantize (`list`): Components of a pipeline to be quantized.
-        quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
-            components. When using this argument, users are not expected to provide `quant_backend`, `quant_kwargs`,
-            and `components_to_quantize`.
-    """
-
-    def __init__(
-        self,
-        quant_backend: str = None,
-        quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
-        components_to_quantize: Optional[List[str]] = None,
-        quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
-    ):
-        self.quant_backend = quant_backend
-        # Initialize kwargs to be {} to set to the defaults.
-        self.quant_kwargs = quant_kwargs or {}
-        self.components_to_quantize = components_to_quantize
-        self.quant_mapping = quant_mapping
-
-        self.post_init()
-
-    def post_init(self):
-        quant_mapping = self.quant_mapping
-        self.is_granular = True if quant_mapping is not None else False
-
-        self._validate_init_args()
-
-    def _validate_init_args(self):
-        if self.quant_backend and self.quant_mapping:
-            raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
-
-        if not self.quant_mapping and not self.quant_backend:
-            raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
-
-        if not self.quant_kwargs and not self.quant_mapping:
-            raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
-
-        if self.quant_backend is not None:
-            self._validate_init_kwargs_in_backends()
-
-        if self.quant_mapping is not None:
-            self._validate_quant_mapping_args()
-
-    def _validate_init_kwargs_in_backends(self):
-        quant_backend = self.quant_backend
-
-        self._check_backend_availability(quant_backend)
-
-        quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
-
-        if quant_config_mapping_transformers is not None:
-            init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
-            init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
-        else:
-            init_kwargs_transformers = None
-
-        init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
-        init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
-
-        if init_kwargs_transformers != init_kwargs_diffusers:
-            raise ValueError(
-                "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
-                f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
-                "this mapping would look like."
-            )
-
-    def _validate_quant_mapping_args(self):
-        quant_mapping = self.quant_mapping
-        transformers_map, diffusers_map = self._get_quant_config_list()
-
-        available_transformers = list(transformers_map.values()) if transformers_map else None
-        available_diffusers = list(diffusers_map.values())
-
-        for module_name, config in quant_mapping.items():
-            if any(isinstance(config, cfg) for cfg in available_diffusers):
-                continue
-
-            if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
-                continue
-
-            if available_transformers:
-                raise ValueError(
-                    f"Provided config for module_name={module_name} could not be found. "
-                    f"Available diffusers configs: {available_diffusers}; "
-                    f"Available transformers configs: {available_transformers}."
-                )
-            else:
-                raise ValueError(
-                    f"Provided config for module_name={module_name} could not be found. "
-                    f"Available diffusers configs: {available_diffusers}."
-                )
-
-    def _check_backend_availability(self, quant_backend: str):
-        quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
-
-        available_backends_transformers = (
-            list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
-        )
-        available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
-
-        if (
-            available_backends_transformers and quant_backend not in available_backends_transformers
-        ) or quant_backend not in quant_config_mapping_diffusers:
-            error_message = f"Provided quant_backend={quant_backend} was not found."
-            if available_backends_transformers:
-                error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
-            error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
-            raise ValueError(error_message)
-
-    def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
-        quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
-
-        quant_mapping = self.quant_mapping
-        components_to_quantize = self.components_to_quantize
-
-        # Granular case
-        if self.is_granular and module_name in quant_mapping:
-            logger.debug(f"Initializing quantization config class for {module_name}.")
-            config = quant_mapping[module_name]
-            return config
-
-        # Global config case
-        else:
-            should_quantize = False
-            # Only quantize the modules requested for.
-            if components_to_quantize and module_name in components_to_quantize:
-                should_quantize = True
-            # No specification for `components_to_quantize` means all modules should be quantized.
-            elif not self.is_granular and not components_to_quantize:
-                should_quantize = True
-
-            if should_quantize:
-                logger.debug(f"Initializing quantization config class for {module_name}.")
-                mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
-                quant_config_cls = mapping_to_use[self.quant_backend]
-                quant_kwargs = self.quant_kwargs
-                return quant_config_cls(**quant_kwargs)
-
-        # Fallback: no applicable configuration found.
-        return None
-
-    def _get_quant_config_list(self):
-        if is_transformers_available():
-            from transformers.quantizers.auto import (
-                AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
-            )
-        else:
-            quant_config_mapping_transformers = None
-
-        from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
-
-        return quant_config_mapping_transformers, quant_config_mapping_diffusers
+from .pipe_quant_config import PipelineQuantizationConfig
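The public entry point is unchanged: `PipelineQuantizationConfig` is still importable from `diffusers.quantizers`, it just lives in the new `pipe_quant_config` module. A minimal usage sketch (not part of the diff; the model id, backend name, and component names are illustrative, and the bitsandbytes backend must be installed):

import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Global case: one backend config applied to every listed pipeline component.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["transformer", "text_encoder_2"],
)
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
)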
diffusers/quantizers/base.py
CHANGED
@@ -209,6 +209,17 @@ class DiffusersQuantizer(ABC):
 
         return model
 
+    def get_cuda_warm_up_factor(self):
+        """
+        The factor to be used in `caching_allocator_warmup` to get the number of bytes to pre-allocate to warm up cuda.
+        A factor of 2 means we allocate all bytes in the empty model (since we allocate in fp16), a factor of 4 means
+        we allocate half the memory of the weights residing in the empty model, etc...
+        """
+        # By default we return 4, i.e. half the model size (this corresponds to the case where the model is not
+        # really pre-processed, i.e. we do not have the info that weights are going to be 8 bits before actual
+        # weight loading)
+        return 4
+
     def _dequantize(self, model):
         raise NotImplementedError(
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
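Quantizer subclasses that know the target bit-width before weights load can override the hook to warm up less memory. A sketch with a hypothetical 4-bit quantizer (not from the diff; only the override is shown):

from diffusers.quantizers.base import DiffusersQuantizer

class HypotheticalFourBitQuantizer(DiffusersQuantizer):
    def get_cuda_warm_up_factor(self):
        # 4-bit weights occupy a quarter of the fp16 footprint, so pre-allocate
        # a quarter of what factor 2 (the full fp16 footprint) would.
        return 8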
diffusers/quantizers/gguf/utils.py
CHANGED
@@ -12,15 +12,15 @@
 # # See the License for the specific language governing permissions and
 # # limitations under the License.
 
-
 import inspect
+import os
 from contextlib import nullcontext
 
 import gguf
 import torch
 import torch.nn as nn
 
-from ...utils import is_accelerate_available
+from ...utils import is_accelerate_available, is_kernels_available
 
 
 if is_accelerate_available():
@@ -29,6 +29,82 @@ if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module, remove_hook_from_module
 
 
+can_use_cuda_kernels = (
+    os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]
+    and torch.cuda.is_available()
+    and torch.cuda.get_device_capability()[0] >= 7
+)
+if can_use_cuda_kernels and is_kernels_available():
+    from kernels import get_kernel
+
+    ops = get_kernel("Isotr0py/ggml")
+else:
+    ops = None
+
+UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
+STANDARD_QUANT_TYPES = {
+    gguf.GGMLQuantizationType.Q4_0,
+    gguf.GGMLQuantizationType.Q4_1,
+    gguf.GGMLQuantizationType.Q5_0,
+    gguf.GGMLQuantizationType.Q5_1,
+    gguf.GGMLQuantizationType.Q8_0,
+    gguf.GGMLQuantizationType.Q8_1,
+}
+KQUANT_TYPES = {
+    gguf.GGMLQuantizationType.Q2_K,
+    gguf.GGMLQuantizationType.Q3_K,
+    gguf.GGMLQuantizationType.Q4_K,
+    gguf.GGMLQuantizationType.Q5_K,
+    gguf.GGMLQuantizationType.Q6_K,
+}
+IMATRIX_QUANT_TYPES = {
+    gguf.GGMLQuantizationType.IQ1_M,
+    gguf.GGMLQuantizationType.IQ1_S,
+    gguf.GGMLQuantizationType.IQ2_XXS,
+    gguf.GGMLQuantizationType.IQ2_XS,
+    gguf.GGMLQuantizationType.IQ2_S,
+    gguf.GGMLQuantizationType.IQ3_XXS,
+    gguf.GGMLQuantizationType.IQ3_S,
+    gguf.GGMLQuantizationType.IQ4_XS,
+    gguf.GGMLQuantizationType.IQ4_NL,
+}
+# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
+# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+# MMQ kernel for I-Matrix quantization.
+DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+
+
+def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
+    # there is no need to call any kernel for fp16/bf16
+    if qweight_type in UNQUANTIZED_TYPES:
+        return x @ qweight.T
+
+    # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
+    # contiguous batching and inefficient with diffusers' batching,
+    # so we disabled it now.
+
+    # elif qweight_type in MMVQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+    # elif qweight_type in MMQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+
+    # If there is no available MMQ kernel, fallback to dequantize
+    if qweight_type in DEQUANT_TYPES:
+        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
+        y = x @ weight.to(x.dtype).T
+    else:
+        # Raise an error if the quantization type is not supported.
+        # Might be useful if llama.cpp adds a new quantization type.
+        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
+        qweight_type = gguf.GGMLQuantizationType(qweight_type)
+        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
+    return y.as_tensor()
+
+
 # Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
 def _create_accelerate_new_hook(old_hook):
     r"""
@@ -451,11 +527,24 @@ class GGUFLinear(nn.Linear):
     ) -> None:
        super().__init__(in_features, out_features, bias, device)
        self.compute_dtype = compute_dtype
+        self.device = device
+
+    def forward(self, inputs: torch.Tensor):
+        if ops is not None and self.weight.is_cuda and inputs.is_cuda:
+            return self.forward_cuda(inputs)
+        return self.forward_native(inputs)
 
-    def forward(self, inputs):
+    def forward_native(self, inputs: torch.Tensor):
         weight = dequantize_gguf_tensor(self.weight)
         weight = weight.to(self.compute_dtype)
         bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
 
         output = torch.nn.functional.linear(inputs, weight, bias)
         return output
+
+    def forward_cuda(self, inputs: torch.Tensor):
+        quant_type = self.weight.quant_type
+        output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+        if self.bias is not None:
+            output += self.bias.to(self.compute_dtype)
+        return output
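The fused path is opt-in: it requires the DIFFUSERS_GGUF_CUDA_KERNELS environment variable, the `kernels` package, and a GPU with compute capability >= 7.0, and otherwise `forward` falls back to `forward_native`. A usage sketch (not part of the diff; the checkpoint URL is illustrative, and the variable must be set before the module is imported because `can_use_cuda_kernels` is evaluated at import time):

import os

os.environ["DIFFUSERS_GGUF_CUDA_KERNELS"] = "1"  # must precede the diffusers import

import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

# Any dequant-supported type (Q4_K, Q8_0, ...) should hit the fused CUDA path.
transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q4_K_S.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
)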
diffusers/quantizers/pipe_quant_config.py
ADDED
@@ -0,0 +1,202 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Dict, List, Optional, Union
+
+from ..utils import is_transformers_available, logging
+from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
+
+
+try:
+    from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
+except ImportError:
+
+    class TransformersQuantConfigMixin:
+        pass
+
+
+logger = logging.get_logger(__name__)
+
+
+class PipelineQuantizationConfig:
+    """
+    Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
+
+    Args:
+        quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
+            is available to both `diffusers` and `transformers`.
+        quant_kwargs (`dict`): Params to initialize the quantization backend class.
+        components_to_quantize (`list`): Components of a pipeline to be quantized.
+        quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
+            components. When using this argument, users are not expected to provide `quant_backend`, `quant_kwargs`,
+            and `components_to_quantize`.
+    """
+
+    def __init__(
+        self,
+        quant_backend: str = None,
+        quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
+        components_to_quantize: Optional[List[str]] = None,
+        quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
+    ):
+        self.quant_backend = quant_backend
+        # Initialize kwargs to be {} to set to the defaults.
+        self.quant_kwargs = quant_kwargs or {}
+        self.components_to_quantize = components_to_quantize
+        self.quant_mapping = quant_mapping
+        self.config_mapping = {}  # book-keeping Example: `{module_name: quant_config}`
+        self.post_init()
+
+    def post_init(self):
+        quant_mapping = self.quant_mapping
+        self.is_granular = True if quant_mapping is not None else False
+
+        self._validate_init_args()
+
+    def _validate_init_args(self):
+        if self.quant_backend and self.quant_mapping:
+            raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
+
+        if not self.quant_mapping and not self.quant_backend:
+            raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
+
+        if not self.quant_kwargs and not self.quant_mapping:
+            raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
+
+        if self.quant_backend is not None:
+            self._validate_init_kwargs_in_backends()
+
+        if self.quant_mapping is not None:
+            self._validate_quant_mapping_args()
+
+    def _validate_init_kwargs_in_backends(self):
+        quant_backend = self.quant_backend
+
+        self._check_backend_availability(quant_backend)
+
+        quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+        if quant_config_mapping_transformers is not None:
+            init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
+            init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
+        else:
+            init_kwargs_transformers = None
+
+        init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
+        init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
+
+        if init_kwargs_transformers != init_kwargs_diffusers:
+            raise ValueError(
+                "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
+                f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
+                "this mapping would look like."
+            )
+
+    def _validate_quant_mapping_args(self):
+        quant_mapping = self.quant_mapping
+        transformers_map, diffusers_map = self._get_quant_config_list()
+
+        available_transformers = list(transformers_map.values()) if transformers_map else None
+        available_diffusers = list(diffusers_map.values())
+
+        for module_name, config in quant_mapping.items():
+            if any(isinstance(config, cfg) for cfg in available_diffusers):
+                continue
+
+            if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
+                continue
+
+            if available_transformers:
+                raise ValueError(
+                    f"Provided config for module_name={module_name} could not be found. "
+                    f"Available diffusers configs: {available_diffusers}; "
+                    f"Available transformers configs: {available_transformers}."
+                )
+            else:
+                raise ValueError(
+                    f"Provided config for module_name={module_name} could not be found. "
+                    f"Available diffusers configs: {available_diffusers}."
+                )
+
+    def _check_backend_availability(self, quant_backend: str):
+        quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+        available_backends_transformers = (
+            list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
+        )
+        available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
+
+        if (
+            available_backends_transformers and quant_backend not in available_backends_transformers
+        ) or quant_backend not in quant_config_mapping_diffusers:
+            error_message = f"Provided quant_backend={quant_backend} was not found."
+            if available_backends_transformers:
+                error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
+            error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
+            raise ValueError(error_message)
+
+    def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
+        quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+        quant_mapping = self.quant_mapping
+        components_to_quantize = self.components_to_quantize
+
+        # Granular case
+        if self.is_granular and module_name in quant_mapping:
+            logger.debug(f"Initializing quantization config class for {module_name}.")
+            config = quant_mapping[module_name]
+            self.config_mapping.update({module_name: config})
+            return config
+
+        # Global config case
+        else:
+            should_quantize = False
+            # Only quantize the modules requested for.
+            if components_to_quantize and module_name in components_to_quantize:
+                should_quantize = True
+            # No specification for `components_to_quantize` means all modules should be quantized.
+            elif not self.is_granular and not components_to_quantize:
+                should_quantize = True
+
+            if should_quantize:
+                logger.debug(f"Initializing quantization config class for {module_name}.")
+                mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
+                quant_config_cls = mapping_to_use[self.quant_backend]
+                quant_kwargs = self.quant_kwargs
+                quant_obj = quant_config_cls(**quant_kwargs)
+                self.config_mapping.update({module_name: quant_obj})
+                return quant_obj
+
+        # Fallback: no applicable configuration found.
+        return None
+
+    def _get_quant_config_list(self):
+        if is_transformers_available():
+            from transformers.quantizers.auto import (
+                AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
+            )
+        else:
+            quant_config_mapping_transformers = None
+
+        from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
+
+        return quant_config_mapping_transformers, quant_config_mapping_diffusers
+
+    def __repr__(self):
+        out = ""
+        config_mapping = dict(sorted(self.config_mapping.copy().items()))
+        for module_name, config in config_mapping.items():
+            out += f"{module_name} {config}"
+        return out
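Relative to the copy removed from `__init__.py`, the class gains the `config_mapping` book-keeping in `_resolve_quant_config` and a `__repr__` that reports it. A granular `quant_mapping` sketch (not part of the diff; the component names are illustrative):

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

# Per-component configs; `quant_backend`/`quant_kwargs` must be omitted in this mode.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        "transformer": DiffusersBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
        "text_encoder_2": TransformersBitsAndBytesConfig(load_in_8bit=True),
    }
)
# Once a pipeline resolves its components during loading, config_mapping records
# which config was applied where, and printing the object shows it via __repr__.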
diffusers/quantizers/torchao/torchao_quantizer.py
CHANGED
@@ -19,6 +19,7 @@ https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac17
 
 import importlib
 import types
+from fnmatch import fnmatch
 from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 from packaging import version
@@ -278,6 +279,31 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
             module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
             quantize_(module, self.quantization_config.get_apply_tensor_subclass())
 
+    def get_cuda_warm_up_factor(self):
+        """
+        This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
+        - A factor of 2 means we pre-allocate the full memory footprint of the model.
+        - A factor of 4 means we pre-allocate half of that, and so on
+
+        However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
+        the correct size for quantized weights (like int4 or int8). That's because TorchAO internally represents
+        quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
+        torch_dtype not the actual bit-width of the quantized data.
+
+        To correct for this:
+        - Use a division factor of 8 for int4 weights
+        - Use a division factor of 4 for int8 weights
+        """
+        # Original mapping for non-AOBaseConfig types
+        # For the uint types, this is a best guess. Once these types become more used
+        # we can look into their nuances.
+        map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
+        quant_type = self.quantization_config.quant_type
+        for pattern, target_dtype in map_to_target_dtype.items():
+            if fnmatch(quant_type, pattern):
+                return target_dtype
+        raise ValueError(f"Unsupported quant_type: {quant_type!r}")
+
     def _process_model_before_weight_loading(
         self,
         model: "ModelMixin",
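A standalone sketch of how the pattern table resolves (the same `fnmatch` logic as above, pulled out of the class; the quant_type strings follow the usual TorchAO config naming):

from fnmatch import fnmatch

map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}

def warm_up_factor(quant_type: str) -> int:
    # First matching glob pattern wins; unknown types are rejected loudly.
    for pattern, factor in map_to_target_dtype.items():
        if fnmatch(quant_type, pattern):
            return factor
    raise ValueError(f"Unsupported quant_type: {quant_type!r}")

assert warm_up_factor("int4_weight_only") == 8   # 4-bit: warm up a quarter of the fp16 footprint
assert warm_up_factor("int8_weight_only") == 4   # 8-bit: warm up half of it
assert warm_up_factor("float8_weight_only") == 4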
diffusers/schedulers/scheduling_deis_multistep.py
CHANGED
@@ -153,6 +153,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         flow_shift: Optional[float] = 1.0,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
+        use_dynamic_shifting: bool = False,
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -232,7 +234,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
@@ -242,6 +246,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+            self.config.flow_shift = np.exp(mu)
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
         if self.config.timestep_spacing == "linspace":
             timesteps = (
diffusers/schedulers/scheduling_dpmsolver_multistep.py
CHANGED
@@ -230,6 +230,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
+        use_dynamic_shifting: bool = False,
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -330,6 +332,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self,
         num_inference_steps: int = None,
         device: Union[str, torch.device] = None,
+        mu: Optional[float] = None,
         timesteps: Optional[List[int]] = None,
     ):
         """
@@ -345,6 +348,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                 based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
                 must be `None`, and `timestep_spacing` attribute will be ignored.
         """
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+            self.config.flow_shift = np.exp(mu)
         if num_inference_steps is None and timesteps is None:
             raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
         if num_inference_steps is not None and timesteps is not None:
diffusers/schedulers/scheduling_dpmsolver_singlestep.py
CHANGED
@@ -169,6 +169,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
+        use_dynamic_shifting: bool = False,
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -301,6 +303,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         self,
         num_inference_steps: int = None,
         device: Union[str, torch.device] = None,
+        mu: Optional[float] = None,
         timesteps: Optional[List[int]] = None,
     ):
         """
@@ -316,6 +319,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
                 passed, `num_inference_steps` must be `None`.
         """
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+            self.config.flow_shift = np.exp(mu)
         if num_inference_steps is None and timesteps is None:
             raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
         if num_inference_steps is not None and timesteps is not None:
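DEIS, DPM-Solver multistep, and DPM-Solver singlestep all gain the same pathway: when `set_timesteps` receives `mu`, the config must have dynamic shifting enabled with the exponential shift type, and `flow_shift` is overwritten with exp(mu). A sketch (not part of the diff; the `mu` value is a placeholder that flow-matching pipelines normally derive from the image resolution, and flow sigmas are assumed since `flow_shift` only applies there):

import numpy as np
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    use_flow_sigmas=True,          # flow_shift takes effect with flow sigmas
    use_dynamic_shifting=True,     # new config arg from this diff
    time_shift_type="exponential", # new config arg from this diff
)
# flow_shift becomes np.exp(0.5), roughly 1.65, before the schedule is built.
scheduler.set_timesteps(num_inference_steps=28, mu=0.5)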
diffusers/schedulers/scheduling_scm.py
CHANGED
@@ -168,7 +168,6 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
         else:
             # max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
             self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
-            print(f"Set timesteps: {self.timesteps}")
 
         self._step_index = None
         self._begin_index = None