diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +74 -28
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
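
The headline additions in 0.32.0 are new video and image pipelines (Allegro, HunyuanVideo, LTX, Mochi, Sana, Flux Control/Fill/Redux, ControlNet Union) and two new quantization backends (GGUF and torchao). The file reproduced below is the new `diffusers/quantizers/gguf/utils.py` (entry 176 above), which holds the GGUF dequantization kernels. A minimal loading sketch, assuming the GGUF API surfaced by this release (`GGUFQuantizationConfig` exported from the top-level `diffusers` namespace and accepted by `from_single_file`); the checkpoint path is a placeholder:

import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

# Placeholder path to a GGUF-quantized Flux transformer checkpoint.
ckpt_path = "path/to/flux1-dev-Q4_K_S.gguf"

# Weights stay packed in their GGUF blocks; GGUFLinear (defined in the file
# below) dequantizes them on the fly to `compute_dtype` at forward time.
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
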
--- /dev/null
+++ diffusers/quantizers/gguf/utils.py
@@ -0,0 +1,456 @@
+ # Copyright 2024 The HuggingFace Team and City96. All rights reserved.
+ # #
+ # # Licensed under the Apache License, Version 2.0 (the "License");
+ # # you may not use this file except in compliance with the License.
+ # # You may obtain a copy of the License at
+ # #
+ # #     http://www.apache.org/licenses/LICENSE-2.0
+ # #
+ # # Unless required by applicable law or agreed to in writing, software
+ # # distributed under the License is distributed on an "AS IS" BASIS,
+ # # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # # See the License for the specific language governing permissions and
+ # # limitations under the License.
+
+
+ import inspect
+ from contextlib import nullcontext
+
+ import gguf
+ import torch
+ import torch.nn as nn
+
+ from ...utils import is_accelerate_available
+
+
+ if is_accelerate_available():
+     import accelerate
+     from accelerate import init_empty_weights
+     from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+
+
+ # Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
+ def _create_accelerate_new_hook(old_hook):
+     r"""
+     Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of:
+     https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with
+     some changes
+     """
+     old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
+     old_hook_attr = old_hook.__dict__
+     filtered_old_hook_attr = {}
+     old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
+     for k in old_hook_attr.keys():
+         if k in old_hook_init_signature.parameters:
+             filtered_old_hook_attr[k] = old_hook_attr[k]
+     new_hook = old_hook_cls(**filtered_old_hook_attr)
+     return new_hook
+
+
+ def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[]):
+     def _should_convert_to_gguf(state_dict, prefix):
+         weight_key = prefix + "weight"
+         return weight_key in state_dict and isinstance(state_dict[weight_key], GGUFParameter)
+
+     has_children = list(model.children())
+     if not has_children:
+         return
+
+     for name, module in model.named_children():
+         module_prefix = prefix + name + "."
+         _replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix, modules_to_not_convert)
+
+         if (
+             isinstance(module, nn.Linear)
+             and _should_convert_to_gguf(state_dict, module_prefix)
+             and name not in modules_to_not_convert
+         ):
+             ctx = init_empty_weights if is_accelerate_available() else nullcontext
+             with ctx():
+                 model._modules[name] = GGUFLinear(
+                     module.in_features,
+                     module.out_features,
+                     module.bias is not None,
+                     compute_dtype=compute_dtype,
+                 )
+             model._modules[name].source_cls = type(module)
+             # Force requires_grad to False to avoid unexpected errors
+             model._modules[name].requires_grad_(False)
+
+     return model
+
+
+ def _dequantize_gguf_and_restore_linear(model, modules_to_not_convert=[]):
+     for name, module in model.named_children():
+         if isinstance(module, GGUFLinear) and name not in modules_to_not_convert:
+             device = module.weight.device
+             bias = getattr(module, "bias", None)
+
+             ctx = init_empty_weights if is_accelerate_available() else nullcontext
+             with ctx():
+                 new_module = nn.Linear(
+                     module.in_features,
+                     module.out_features,
+                     module.bias is not None,
+                     device=device,
+                 )
+             new_module.weight = nn.Parameter(dequantize_gguf_tensor(module.weight))
+             if bias is not None:
+                 new_module.bias = bias
+
+             # Create a new hook and attach it in case we use accelerate
+             if hasattr(module, "_hf_hook"):
+                 old_hook = module._hf_hook
+                 new_hook = _create_accelerate_new_hook(old_hook)
+
+                 remove_hook_from_module(module)
+                 add_hook_to_module(new_module, new_hook)
+
+             new_module.to(device)
+             model._modules[name] = new_module
+
+         has_children = list(module.children())
+         if has_children:
+             _dequantize_gguf_and_restore_linear(module, modules_to_not_convert)
+
+     return model
+
+
+ # dequantize operations based on torch ports of GGUF dequantize_functions
+ # from City96
+ # more info: https://github.com/city96/ComfyUI-GGUF/blob/main/dequant.py
+
+
+ QK_K = 256
+ K_SCALE_SIZE = 12
+
+
+ def to_uint32(x):
+     x = x.view(torch.uint8).to(torch.int32)
+     return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
+
+
+ def split_block_dims(blocks, *args):
+     n_max = blocks.shape[1]
+     dims = list(args) + [n_max - sum(args)]
+     return torch.split(blocks, dims, dim=1)
+
+
+ def get_scale_min(scales):
+     n_blocks = scales.shape[0]
+     scales = scales.view(torch.uint8)
+     scales = scales.reshape((n_blocks, 3, 4))
+
+     d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)
+
+     sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
+     min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)
+
+     return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
+
+
+ def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None):
+     d, x = split_block_dims(blocks, 2)
+     d = d.view(torch.float16).to(dtype)
+     x = x.view(torch.int8)
+     return d * x
+
+
+ def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None):
+     n_blocks = blocks.shape[0]
+
+     d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
+     d = d.view(torch.float16).to(dtype)
+     m = m.view(torch.float16).to(dtype)
+     qh = to_uint32(qh)
+
+     qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
+     ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
+         [0, 4], device=d.device, dtype=torch.uint8
+     ).reshape(1, 1, 2, 1)
+     qh = (qh & 1).to(torch.uint8)
+     ql = (ql & 0x0F).reshape((n_blocks, -1))
+
+     qs = ql | (qh << 4)
+     return (d * qs) + m
+
+
+ def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None):
+     n_blocks = blocks.shape[0]
+
+     d, qh, qs = split_block_dims(blocks, 2, 4)
+     d = d.view(torch.float16).to(dtype)
+     qh = to_uint32(qh)
+
+     qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
+     ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor(
+         [0, 4], device=d.device, dtype=torch.uint8
+     ).reshape(1, 1, 2, 1)
+
+     qh = (qh & 1).to(torch.uint8)
+     ql = (ql & 0x0F).reshape(n_blocks, -1)
+
+     qs = (ql | (qh << 4)).to(torch.int8) - 16
+     return d * qs
+
+
+ def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None):
+     n_blocks = blocks.shape[0]
+
+     d, m, qs = split_block_dims(blocks, 2, 2)
+     d = d.view(torch.float16).to(dtype)
+     m = m.view(torch.float16).to(dtype)
+
+     qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
+         [0, 4], device=d.device, dtype=torch.uint8
+     ).reshape(1, 1, 2, 1)
+     qs = (qs & 0x0F).reshape(n_blocks, -1)
+
+     return (d * qs) + m
+
+
+ def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None):
+     n_blocks = blocks.shape[0]
+
+     d, qs = split_block_dims(blocks, 2)
+     d = d.view(torch.float16).to(dtype)
+
+     qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
+         [0, 4], device=d.device, dtype=torch.uint8
+     ).reshape((1, 1, 2, 1))
+     qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
+     return d * qs
+
+
+ def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None):
+     n_blocks = blocks.shape[0]
+
+     (
+         ql,
+         qh,
+         scales,
+         d,
+     ) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)
+
+     scales = scales.view(torch.int8).to(dtype)
+     d = d.view(torch.float16).to(dtype)
+     d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
+
+     ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
+         (1, 1, 2, 1)
+     )
+     ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
+     qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
+         (1, 1, 4, 1)
+     )
+     qh = (qh & 0x03).reshape((n_blocks, -1, 32))
+     q = (ql | (qh << 4)).to(torch.int8) - 32
+     q = q.reshape((n_blocks, QK_K // 16, -1))
+
+     return (d * q).reshape((n_blocks, QK_K))
+
+
+ def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None):
+     n_blocks = blocks.shape[0]
+
+     d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)
+
+     d = d.view(torch.float16).to(dtype)
+     dmin = dmin.view(torch.float16).to(dtype)
+
+     sc, m = get_scale_min(scales)
+
+     d = (d * sc).reshape((n_blocks, -1, 1))
+     dm = (dmin * m).reshape((n_blocks, -1, 1))
+
+     ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
+         (1, 1, 2, 1)
+     )
+     qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape(
+         (1, 1, 8, 1)
+     )
+     ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
+     qh = (qh & 0x01).reshape((n_blocks, -1, 32))
+     q = ql | (qh << 4)
+
+     return (d * q - dm).reshape((n_blocks, QK_K))
+
+
+ def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None):
+     n_blocks = blocks.shape[0]
+
+     d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
+     d = d.view(torch.float16).to(dtype)
+     dmin = dmin.view(torch.float16).to(dtype)
+
+     sc, m = get_scale_min(scales)
+
+     d = (d * sc).reshape((n_blocks, -1, 1))
+     dm = (dmin * m).reshape((n_blocks, -1, 1))
+
+     qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
+         (1, 1, 2, 1)
+     )
+     qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
+
+     return (d * qs - dm).reshape((n_blocks, QK_K))
+
+
+ def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None):
+     n_blocks = blocks.shape[0]
+
+     hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
+     d = d.view(torch.float16).to(dtype)
+
+     lscales, hscales = scales[:, :8], scales[:, 8:]
+     lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
+         (1, 2, 1)
+     )
+     lscales = lscales.reshape((n_blocks, 16))
+     hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor(
+         [0, 2, 4, 6], device=d.device, dtype=torch.uint8
+     ).reshape((1, 4, 1))
+     hscales = hscales.reshape((n_blocks, 16))
+     scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
+     scales = scales.to(torch.int8) - 32
+
+     dl = (d * scales).reshape((n_blocks, 16, 1))
+
+     ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
+         (1, 1, 4, 1)
+     )
+     qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape(
+         (1, 1, 8, 1)
+     )
+     ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
+     qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
+     q = ql.to(torch.int8) - (qh << 2).to(torch.int8)
+
+     return (dl * q).reshape((n_blocks, QK_K))
+
+
+ def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None):
+     n_blocks = blocks.shape[0]
+
+     scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
+     d = d.view(torch.float16).to(dtype)
+     dmin = dmin.view(torch.float16).to(dtype)
+
+     # (n_blocks, 16, 1)
+     dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
+     ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
+
+     shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
+
+     qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
+     qs = qs.reshape((n_blocks, QK_K // 16, 16))
+     qs = dl * qs - ml
+
+     return qs.reshape((n_blocks, -1))
+
+
+ def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
+     return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
+
+
+ GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES
+ dequantize_functions = {
+     gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
+     gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
+     gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
+     gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
+     gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
+     gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
+     gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
+     gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
+     gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
+     gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
+     gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
+ }
+ SUPPORTED_GGUF_QUANT_TYPES = list(dequantize_functions.keys())
+
+
+ def _quant_shape_from_byte_shape(shape, type_size, block_size):
+     return (*shape[:-1], shape[-1] // type_size * block_size)
+
+
+ def dequantize_gguf_tensor(tensor):
+     if not hasattr(tensor, "quant_type"):
+         return tensor
+
+     quant_type = tensor.quant_type
+     dequant_fn = dequantize_functions[quant_type]
+
+     block_size, type_size = GGML_QUANT_SIZES[quant_type]
+
+     tensor = tensor.view(torch.uint8)
+     shape = _quant_shape_from_byte_shape(tensor.shape, type_size, block_size)
+
+     n_blocks = tensor.numel() // type_size
+     blocks = tensor.reshape((n_blocks, type_size))
+
+     dequant = dequant_fn(blocks, block_size, type_size)
+     dequant = dequant.reshape(shape)
+
+     return dequant.as_tensor()
+
+
+ class GGUFParameter(torch.nn.Parameter):
+     def __new__(cls, data, requires_grad=False, quant_type=None):
+         data = data if data is not None else torch.empty(0)
+         self = torch.Tensor._make_subclass(cls, data, requires_grad)
+         self.quant_type = quant_type
+
+         return self
+
+     def as_tensor(self):
+         return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
+
+     @classmethod
+     def __torch_function__(cls, func, types, args=(), kwargs=None):
+         if kwargs is None:
+             kwargs = {}
+
+         result = super().__torch_function__(func, types, args, kwargs)
+
+         # When converting from original format checkpoints we often use splits, cats etc on tensors
+         # this method ensures that the returned tensor type from those operations remains GGUFParameter
+         # so that we preserve quant_type information
+         quant_type = None
+         for arg in args:
+             if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
+                 quant_type = arg[0].quant_type
+                 break
+             if isinstance(arg, GGUFParameter):
+                 quant_type = arg.quant_type
+                 break
+         if isinstance(result, torch.Tensor):
+             return cls(result, quant_type=quant_type)
+         # Handle tuples and lists
+         elif isinstance(result, (tuple, list)):
+             # Preserve the original type (tuple or list)
+             wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
+             return type(result)(wrapped)
+         else:
+             return result
+
+
+ class GGUFLinear(nn.Linear):
+     def __init__(
+         self,
+         in_features,
+         out_features,
+         bias=False,
+         compute_dtype=None,
+         device=None,
+     ) -> None:
+         super().__init__(in_features, out_features, bias, device)
+         self.compute_dtype = compute_dtype
+
+     def forward(self, inputs):
+         weight = dequantize_gguf_tensor(self.weight)
+         weight = weight.to(self.compute_dtype)
+         bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
+
+         output = torch.nn.functional.linear(inputs, weight, bias)
+         return output
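
To make the block format these kernels decode concrete, here is a small self-contained sketch (illustrative only; it assumes the optional `gguf` dependency is installed so that `diffusers.quantizers.gguf.utils` imports, and it builds a single Q8_0 block by hand instead of reading a real checkpoint):

import torch
from diffusers.quantizers.gguf.utils import dequantize_blocks_Q8_0

# A Q8_0 block is 34 bytes: one float16 scale followed by 32 int8 quants.
scale = torch.tensor([0.5], dtype=torch.float16).view(torch.uint8)   # 2 bytes
quants = torch.arange(-16, 16, dtype=torch.int8).view(torch.uint8)   # 32 bytes
blocks = torch.cat([scale, quants]).reshape(1, 34)

out = dequantize_blocks_Q8_0(blocks, block_size=32, type_size=34, dtype=torch.float32)
# Each quant is rescaled as d * q, so the block dequantizes to 0.5 * [-16, ..., 15].
assert torch.equal(out, 0.5 * torch.arange(-16, 16, dtype=torch.float32).reshape(1, 32))

`dequantize_gguf_tensor` applies the matching `dequantize_blocks_*` kernel to every block of a `GGUFParameter` and reshapes the result back to the tensor's logical shape, which is what `GGUFLinear.forward` and `_dequantize_gguf_and_restore_linear` rely on.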