cache-dit 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/__init__.py +5 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +2 -0
- cache_dit/cache_factory/cache_adapters.py +375 -26
- cache_dit/cache_factory/cache_blocks/__init__.py +20 -0
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +270 -0
- cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +17 -18
- cache_dit/cache_factory/cache_blocks/utils.py +19 -0
- cache_dit/cache_factory/cache_context.py +32 -25
- cache_dit/cache_factory/cache_interface.py +8 -3
- cache_dit/cache_factory/forward_pattern.py +45 -24
- cache_dit/cache_factory/patch_functors/__init__.py +5 -0
- cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +273 -0
- cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +45 -31
- cache_dit/compile/utils.py +1 -1
- cache_dit/quantize/__init__.py +1 -0
- cache_dit/quantize/quantize_ao.py +196 -0
- cache_dit/quantize/quantize_interface.py +46 -0
- cache_dit/utils.py +49 -17
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/METADATA +43 -18
- cache_dit-0.2.26.dist-info/RECORD +42 -0
- cache_dit-0.2.24.dist-info/RECORD +0 -32
- /cache_dit/{cache_factory/patch/__init__.py → quantize/quantize_svdq.py} +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/patch_functors/functor_chroma.py
ADDED

@@ -0,0 +1,273 @@
+import inspect
+
+import torch
+import numpy as np
+from typing import Tuple, Optional, Dict, Any, Union
+from diffusers import ChromaTransformer2DModel
+from diffusers.models.transformers.transformer_chroma import (
+    ChromaSingleTransformerBlock,
+    Transformer2DModelOutput,
+)
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class ChromaPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: ChromaTransformer2DModel,
+        blocks: torch.nn.ModuleList = None,
+        **kwargs,
+    ) -> ChromaTransformer2DModel:
+        if blocks is None:
+            blocks = transformer.single_transformer_blocks
+
+        is_patched = False
+        for block in blocks:
+            if isinstance(block, ChromaSingleTransformerBlock):
+                forward_parameters = inspect.signature(
+                    block.forward
+                ).parameters.keys()
+                if "encoder_hidden_states" not in forward_parameters:
+                    block.forward = __patch_single_forward__.__get__(block)
+                    is_patched = True
+
+        if is_patched:
+            logger.warning("Patched Chroma for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+            transformer._is_patched = True
+
+        cls_name = transformer.__class__.__name__
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
+# adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
+def __patch_single_forward__(
+    self: ChromaSingleTransformerBlock,  # Almost same as FluxSingleTransformerBlock
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    temb: torch.Tensor,
+    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    text_seq_len = encoder_hidden_states.shape[1]
+    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+    residual = hidden_states
+    norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+    mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+    joint_attention_kwargs = joint_attention_kwargs or {}
+    attn_output = self.attn(
+        hidden_states=norm_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+        **joint_attention_kwargs,
+    )
+
+    hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+    gate = gate.unsqueeze(1)
+    hidden_states = gate * self.proj_out(hidden_states)
+    hidden_states = residual + hidden_states
+    if hidden_states.dtype == torch.float16:
+        hidden_states = hidden_states.clip(-65504, 65504)
+
+    encoder_hidden_states, hidden_states = (
+        hidden_states[:, :text_seq_len],
+        hidden_states[:, text_seq_len:],
+    )
+    return encoder_hidden_states, hidden_states
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
+def __patch_transformer_forward__(
+    self: ChromaTransformer2DModel,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
+    timestep: torch.LongTensor = None,
+    img_ids: torch.Tensor = None,
+    txt_ids: torch.Tensor = None,
+    attention_mask: torch.Tensor = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    controlnet_block_samples=None,
+    controlnet_single_block_samples=None,
+    return_dict: bool = True,
+    controlnet_blocks_repeat: bool = False,
+) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    if joint_attention_kwargs is not None:
+        joint_attention_kwargs = joint_attention_kwargs.copy()
+        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+    else:
+        lora_scale = 1.0
+
+    if USE_PEFT_BACKEND:
+        # weight the lora layers by setting `lora_scale` for each PEFT layer
+        scale_lora_layers(self, lora_scale)
+    else:
+        if (
+            joint_attention_kwargs is not None
+            and joint_attention_kwargs.get("scale", None) is not None
+        ):
+            logger.warning(
+                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+            )
+
+    hidden_states = self.x_embedder(hidden_states)
+
+    timestep = timestep.to(hidden_states.dtype) * 1000
+
+    input_vec = self.time_text_embed(timestep)
+    pooled_temb = self.distilled_guidance_layer(input_vec)
+
+    encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+    if txt_ids.ndim == 3:
+        logger.warning(
+            "Passing `txt_ids` 3d torch.Tensor is deprecated."
+            "Please remove the batch dimension and pass it as a 2d torch Tensor"
+        )
+        txt_ids = txt_ids[0]
+    if img_ids.ndim == 3:
+        logger.warning(
+            "Passing `img_ids` 3d torch.Tensor is deprecated."
+            "Please remove the batch dimension and pass it as a 2d torch Tensor"
+        )
+        img_ids = img_ids[0]
+
+    ids = torch.cat((txt_ids, img_ids), dim=0)
+    image_rotary_emb = self.pos_embed(ids)
+
+    if (
+        joint_attention_kwargs is not None
+        and "ip_adapter_image_embeds" in joint_attention_kwargs
+    ):
+        ip_adapter_image_embeds = joint_attention_kwargs.pop(
+            "ip_adapter_image_embeds"
+        )
+        ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+        joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+    for index_block, block in enumerate(self.transformer_blocks):
+        img_offset = 3 * len(self.single_transformer_blocks)
+        txt_offset = img_offset + 6 * len(self.transformer_blocks)
+        img_modulation = img_offset + 6 * index_block
+        text_modulation = txt_offset + 6 * index_block
+        temb = torch.cat(
+            (
+                pooled_temb[:, img_modulation : img_modulation + 6],
+                pooled_temb[:, text_modulation : text_modulation + 6],
+            ),
+            dim=1,
+        )
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            encoder_hidden_states, hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    attention_mask,
+                )
+            )
+
+        else:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                attention_mask=attention_mask,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+
+        # controlnet residual
+        if controlnet_block_samples is not None:
+            interval_control = len(self.transformer_blocks) / len(
+                controlnet_block_samples
+            )
+            interval_control = int(np.ceil(interval_control))
+            # For Xlabs ControlNet.
+            if controlnet_blocks_repeat:
+                hidden_states = (
+                    hidden_states
+                    + controlnet_block_samples[
+                        index_block % len(controlnet_block_samples)
+                    ]
+                )
+            else:
+                hidden_states = (
+                    hidden_states
+                    + controlnet_block_samples[index_block // interval_control]
+                )
+
+    for index_block, block in enumerate(self.single_transformer_blocks):
+        start_idx = 3 * index_block
+        temb = pooled_temb[:, start_idx : start_idx + 3]
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            encoder_hidden_states, hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                )
+            )
+
+        else:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                attention_mask=attention_mask,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+
+        # controlnet residual
+        if controlnet_single_block_samples is not None:
+            interval_control = len(self.single_transformer_blocks) / len(
+                controlnet_single_block_samples
+            )
+            interval_control = int(np.ceil(interval_control))
+            hidden_states = (
+                hidden_states
+                + controlnet_single_block_samples[
+                    index_block // interval_control
+                ]
+            )
+
+    temb = pooled_temb[:, -2:]
+    hidden_states = self.norm_out(hidden_states, temb)
+    output = self.proj_out(hidden_states)
+
+    if USE_PEFT_BACKEND:
+        # remove `lora_scale` from each PEFT layer
+        unscale_lora_layers(self, lora_scale)
+
+    if not return_dict:
+        return (output,)
+
+    return Transformer2DModelOutput(sample=output)
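For orientation (not part of the published diff), a minimal usage sketch of the new functor: it assumes `ChromaPatchFunctor` is imported straight from the `functor_chroma.py` module shown above and that a Chroma transformer checkpoint is available; the checkpoint path is hypothetical.

import torch
from diffusers import ChromaTransformer2DModel
# Import path follows the new module layout shown in this diff.
from cache_dit.cache_factory.patch_functors.functor_chroma import ChromaPatchFunctor

# Hypothetical local checkpoint path; any ChromaTransformer2DModel instance works.
transformer = ChromaTransformer2DModel.from_pretrained(
    "path/to/chroma-checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Rebinds ChromaSingleTransformerBlock.forward (and the transformer forward) so that
# single blocks accept and return encoder_hidden_states separately; per the assert
# above, this must happen before any parallelization is applied.
transformer = ChromaPatchFunctor().apply(transformer)
assert getattr(transformer, "_is_patched", False)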
cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py}
RENAMED

@@ -14,12 +14,56 @@ from diffusers.utils import (
     unscale_lora_layers,
 )
 
-
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
+class FluxPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: FluxTransformer2DModel,
+        blocks: torch.nn.ModuleList = None,
+        **kwargs,
+    ) -> FluxTransformer2DModel:
+        if blocks is None:
+            blocks = transformer.single_transformer_blocks
+
+        is_patched = False
+        for block in blocks:
+            if isinstance(block, FluxSingleTransformerBlock):
+                forward_parameters = inspect.signature(
+                    block.forward
+                ).parameters.keys()
+                if "encoder_hidden_states" not in forward_parameters:
+                    block.forward = __patch_single_forward__.__get__(block)
+                    is_patched = True
+
+        if is_patched:
+            logger.warning("Patched Flux for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+            transformer._is_patched = True
+
+        cls_name = transformer.__class__.__name__
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
 # copy from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
 def __patch_single_forward__(
     self: FluxSingleTransformerBlock,

@@ -217,33 +261,3 @@ def __patch_transformer_forward__(
         return (output,)
 
     return Transformer2DModelOutput(sample=output)
-
-
-def maybe_patch_flux_transformer(
-    transformer: FluxTransformer2DModel,
-    blocks: torch.nn.ModuleList = None,
-) -> FluxTransformer2DModel:
-    if blocks is None:
-        blocks = transformer.single_transformer_blocks
-
-    is_patched = False
-    for block in blocks:
-        if isinstance(block, FluxSingleTransformerBlock):
-            forward_parameters = inspect.signature(
-                block.forward
-            ).parameters.keys()
-            if "encoder_hidden_states" not in forward_parameters:
-                block.forward = __patch_single_forward__.__get__(block)
-                is_patched = True
-
-    if is_patched:
-        logger.warning("Patched Flux for cache-dit.")
-        assert not getattr(transformer, "_is_parallelized", False), (
-            "Please call apply_cache_on_pipe before Parallelize, "
-            "the __patch_transformer_forward__ will overwrite the "
-            "parallized forward and cause a downgrade of performance."
-        )
-        transformer.forward = __patch_transformer_forward__.__get__(transformer)
-        transformer._is_patched = True
-
-    return transformer
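The module-level helper `maybe_patch_flux_transformer` is folded into a `FluxPatchFunctor` class, and the assert message now points at `cache_dit.enable_cache` instead of `apply_cache_on_pipe`. The new `functor_base.py` (+18 lines) is listed in this release but its contents are not rendered above; judging only from the two subclasses, its `PatchFunctor` base class plausibly looks like the sketch below (an assumption, not the actual file):

from abc import ABC, abstractmethod

import torch


class PatchFunctor(ABC):
    # Subclasses (FluxPatchFunctor, ChromaPatchFunctor) rebind block/transformer
    # forwards in place and return the patched transformer.
    @abstractmethod
    def apply(
        self,
        transformer: torch.nn.Module,
        blocks: torch.nn.ModuleList = None,
        **kwargs,
    ) -> torch.nn.Module:
        raise NotImplementedError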
cache_dit/compile/utils.py
CHANGED

@@ -24,7 +24,7 @@ def epilogue_prologue_fusion_enabled(**kwargs) -> bool:
 
 
 def set_compile_configs(
-    descent_tuning: bool =
+    descent_tuning: bool = False,
     cuda_graphs: bool = False,
     force_disable_compile_caches: bool = False,
     use_fast_math: bool = False,
cache_dit/quantize/__init__.py
ADDED

@@ -0,0 +1 @@
+from cache_dit.quantize.quantize_interface import quantize
cache_dit/quantize/quantize_ao.py
ADDED

@@ -0,0 +1,196 @@
+import gc
+import time
+import torch
+from typing import Callable, Optional, List
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def quantize_ao(
+    module: torch.nn.Module,
+    quant_type: str = "fp8_w8a8_dq",
+    exclude_layers: List[str] = [
+        "embedder",
+        "embed",
+    ],
+    filter_fn: Optional[Callable] = None,
+    # paramters for fp8 quantization
+    per_row: bool = True,
+    **kwargs,
+) -> torch.nn.Module:
+    # Apply FP8 DQ for module and skip any `embed` modules
+    # by default to avoid non-trivial precision downgrade. Please
+    # set `exclude_layers` as `[]` if you don't want this behavior.
+    assert isinstance(module, torch.nn.Module)
+
+    quant_type = quant_type.lower()
+    assert quant_type in (
+        "fp8_w8a8_dq",
+        "fp8_w8a16_wo",
+        "int8_w8a8_dq",
+        "int8_w8a16_wo",
+        "int4_w4a8_dq",
+        "int4_w4a4_dq",
+        "int4_w4a16_wo",
+    ), f"{quant_type} is not supported for torchao backend now!"
+
+    if "fp8" in quant_type:
+        assert torch.cuda.get_device_capability() >= (
+            8,
+            9,
+        ), "FP8 is not supported for current device."
+
+    num_quant_linear = 0
+    num_skip_linear = 0
+    num_linear_layers = 0
+    num_layers = 0
+
+    # Ensure bfloat16 for per_row
+    def _filter_fn(m: torch.nn.Module, name: str) -> bool:
+        nonlocal num_quant_linear, num_skip_linear, num_linear_layers, num_layers
+        num_layers += 1
+        if isinstance(m, torch.nn.Linear):
+            num_linear_layers += 1
+            for exclude_name in exclude_layers:
+                if exclude_name in name:
+                    logger.info(
+                        f"Skip Quantization: {name} -> "
+                        f"pattern<{exclude_name}>"
+                    )
+
+                    num_skip_linear += 1
+                    return False
+
+            if (
+                per_row
+                and m.weight.dtype != torch.bfloat16
+                and quant_type == "fp8_w8a8_dq"
+            ):
+                logger.info(
+                    f"Skip Quantization: {name} -> "
+                    f"pattern<dtype({m.weight.dtype})!=bfloat16>"
+                )
+
+                num_skip_linear += 1
+                return False
+
+            num_quant_linear += 1
+            return True
+
+        return False
+
+    def _quantization_fn():
+        try:
+            if quant_type == "fp8_w8a8_dq":
+                from torchao.quantization import (
+                    float8_dynamic_activation_float8_weight,
+                    PerTensor,
+                    PerRow,
+                )
+
+                quantization_fn = float8_dynamic_activation_float8_weight(
+                    weight_dtype=kwargs.get(
+                        "weight_dtype",
+                        torch.float8_e4m3fn,
+                    ),
+                    activation_dtype=kwargs.get(
+                        "activation_dtype",
+                        torch.float8_e4m3fn,
+                    ),
+                    granularity=(
+                        ((PerRow(), PerRow()))
+                        if per_row
+                        else ((PerTensor(), PerTensor()))
+                    ),
+                )
+
+            elif quant_type == "fp8_w8a16_wo":
+                from torchao.quantization import float8_weight_only
+
+                quantization_fn = float8_weight_only(
+                    weight_dtype=kwargs.get(
+                        "weight_dtype",
+                        torch.float8_e4m3fn,
+                    ),
+                )
+
+            elif quant_type == "int8_w8a8_dq":
+                from torchao.quantization import (
+                    int8_dynamic_activation_int8_weight,
+                )
+
+                quantization_fn = int8_dynamic_activation_int8_weight()
+
+            elif quant_type == "int8_w8a16_wo":
+                from torchao.quantization import int8_weight_only
+
+                quantization_fn = int8_weight_only(
+                    # group_size is None -> per_channel, else per group
+                    group_size=kwargs.get("group_size", None),
+                )
+
+            elif quant_type == "int4_w4a8_dq":
+                from torchao.quantization import (
+                    int8_dynamic_activation_int4_weight,
+                )
+
+                quantization_fn = int8_dynamic_activation_int4_weight(
+                    group_size=kwargs.get("group_size", 32),
+                )
+
+            elif quant_type == "int4_w4a4_dq":
+                from torchao.quantization import (
+                    int4_dynamic_activation_int4_weight,
+                )
+
+                quantization_fn = int4_dynamic_activation_int4_weight()
+
+            elif quant_type == "int4_w4a16_wo":
+                from torchao.quantization import int4_weight_only
+
+                quantization_fn = int4_weight_only(
+                    group_size=kwargs.get("group_size", 32),
+                )
+
+            else:
+                raise ValueError(
+                    f"quant_type: {quant_type} is not supported now!"
+                )
+
+        except ImportError as e:
+            e.msg += (
+                f"{quant_type} is not supported in torchao backend now! "
+                "Please upgrade the torchao library."
+            )
+            raise e
+
+        return quantization_fn
+
+    from torchao.quantization import quantize_
+
+    quantize_(
+        module,
+        _quantization_fn(),
+        filter_fn=_filter_fn if filter_fn is None else filter_fn,
+        device=kwargs.get("device", None),
+    )
+
+    force_empty_cache()
+
+    logger.info(
+        f"Quantized Linear Layers: {num_quant_linear:>5}\n"
+        f"Skipped Linear Layers: {num_skip_linear:>5}\n"
+        f"Total Linear Layers: {num_linear_layers:>5}\n"
+        f"Total (all) Layers: {num_layers:>5}"
+    )
+    return module
+
+
+def force_empty_cache():
+    time.sleep(1)
+    gc.collect()
+    torch.cuda.empty_cache()
+    time.sleep(1)
+    gc.collect()
+    torch.cuda.empty_cache()
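A minimal sketch of calling `quantize_ao` directly (not from the diff), assuming torchao is installed and the GPU supports FP8 (compute capability 8.9+); a toy module stands in for a real transformer. Note that with the default per-row `fp8_w8a8_dq` path, the built-in filter skips Linear layers whose weights are not bfloat16.

import torch
from cache_dit.quantize.quantize_ao import quantize_ao

# Toy stand-in for a transformer; bfloat16 weights so the per-row FP8 filter
# does not skip the Linear layers.
mlp = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 4096),
).to(device="cuda", dtype=torch.bfloat16)

# In-place torchao quantization; the same module is returned for convenience.
mlp = quantize_ao(mlp, quant_type="fp8_w8a8_dq", per_row=True)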
cache_dit/quantize/quantize_interface.py
ADDED

@@ -0,0 +1,46 @@
+import torch
+from typing import Callable, Optional, List
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def quantize(
+    module: torch.nn.Module,
+    quant_type: str = "fp8_w8a8_dq",
+    backend: str = "ao",
+    exclude_layers: List[str] = [
+        "embedder",
+        "embed",
+    ],
+    filter_fn: Optional[Callable] = None,
+    # only for fp8_w8a8_dq
+    per_row: bool = True,
+    **kwargs,
+) -> torch.nn.Module:
+    assert isinstance(module, torch.nn.Module)
+
+    if backend.lower() in ("ao", "torchao"):
+        from cache_dit.quantize.quantize_ao import quantize_ao
+
+        quant_type = quant_type.lower()
+        assert quant_type in (
+            "fp8_w8a8_dq",
+            "fp8_w8a16_wo",
+            "int8_w8a8_dq",
+            "int8_w8a16_wo",
+            "int4_w4a8_dq",
+            "int4_w4a4_dq",
+            "int4_w4a16_wo",
+        ), f"{quant_type} is not supported for torchao backend now!"
+
+        return quantize_ao(
+            module,
+            quant_type=quant_type,
+            per_row=per_row,
+            exclude_layers=exclude_layers,
+            filter_fn=filter_fn,
+            **kwargs,
+        )
+    else:
+        raise ValueError(f"backend: {backend} is not supported now!")
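A usage sketch for the public entry point (not from the diff): `quantize` is re-exported from `cache_dit.quantize` by the new `__init__.py` above; the Flux pipeline and the int8 weight-only choice here are illustrative assumptions.

import torch
from diffusers import FluxPipeline
from cache_dit.quantize import quantize

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# int8 weight-only via the torchao backend; embedding-like Linears are skipped
# by the default exclude_layers=["embedder", "embed"].
quantize(pipe.transformer, quant_type="int8_w8a16_wo", backend="ao")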