cache-dit 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/__init__.py +9 -4
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +16 -3
- cache_dit/cache_factory/block_adapters/__init__.py +538 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
- cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
- cache_dit/cache_factory/cache_adapters.py +121 -563
- cache_dit/cache_factory/cache_blocks/__init__.py +18 -0
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +275 -0
- cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +100 -82
- cache_dit/cache_factory/cache_blocks/utils.py +23 -0
- cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
- cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +94 -56
- cache_dit/cache_factory/cache_interface.py +24 -16
- cache_dit/cache_factory/forward_pattern.py +45 -24
- cache_dit/cache_factory/patch_functors/__init__.py +5 -0
- cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +276 -0
- cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +49 -31
- cache_dit/quantize/quantize_ao.py +19 -4
- cache_dit/quantize/quantize_interface.py +2 -2
- cache_dit/utils.py +19 -15
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/METADATA +76 -19
- cache_dit-0.2.27.dist-info/RECORD +47 -0
- cache_dit-0.2.25.dist-info/RECORD +0 -36
- /cache_dit/cache_factory/{patch/__init__.py → cache_contexts/cache_manager.py} +0 -0
- /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
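The main structural change in 0.2.27 is the split of the old `cache_adapters.py` into `block_adapters/`, `cache_blocks/`, `cache_contexts/` and `patch_functors/` subpackages, with `cache_interface.py` carrying the `cache_dit.enable_cache` entry point that the patch functors below reference. A minimal usage sketch, assuming `enable_cache` is re-exported at the top level and accepts a Diffusers pipeline (its exact signature is not part of this diff):

```python
# Sketch under the assumptions stated above; not taken verbatim from the diff.
import torch
import cache_dit
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Per the patch functors' assert message, caching must be enabled before
# any parallelization of the transformer.
cache_dit.enable_cache(pipe)

image = pipe(
    "a photo of a cat playing piano",
    num_inference_steps=28,
).images[0]
```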
cache_dit/cache_factory/patch_functors/functor_chroma.py
ADDED
@@ -0,0 +1,276 @@
+import inspect
+
+import torch
+import numpy as np
+from typing import Tuple, Optional, Dict, Any, Union
+from diffusers import ChromaTransformer2DModel
+from diffusers.models.transformers.transformer_chroma import (
+    ChromaSingleTransformerBlock,
+    Transformer2DModelOutput,
+)
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class ChromaPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: ChromaTransformer2DModel,
+        blocks: torch.nn.ModuleList = None,
+        **kwargs,
+    ) -> ChromaTransformer2DModel:
+        if getattr(transformer, "_is_patched", False):
+            return transformer
+
+        if blocks is None:
+            blocks = transformer.single_transformer_blocks
+
+        is_patched = False
+        for block in blocks:
+            if isinstance(block, ChromaSingleTransformerBlock):
+                forward_parameters = inspect.signature(
+                    block.forward
+                ).parameters.keys()
+                if "encoder_hidden_states" not in forward_parameters:
+                    block.forward = __patch_single_forward__.__get__(block)
+                    is_patched = True
+
+        if is_patched:
+            logger.warning("Patched Chroma for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+            transformer._is_patched = True
+
+        cls_name = transformer.__class__.__name__
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
+# adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
+def __patch_single_forward__(
+    self: ChromaSingleTransformerBlock,  # Almost same as FluxSingleTransformerBlock
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    temb: torch.Tensor,
+    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    text_seq_len = encoder_hidden_states.shape[1]
+    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+    residual = hidden_states
+    norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+    mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+    joint_attention_kwargs = joint_attention_kwargs or {}
+    attn_output = self.attn(
+        hidden_states=norm_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+        **joint_attention_kwargs,
+    )
+
+    hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+    gate = gate.unsqueeze(1)
+    hidden_states = gate * self.proj_out(hidden_states)
+    hidden_states = residual + hidden_states
+    if hidden_states.dtype == torch.float16:
+        hidden_states = hidden_states.clip(-65504, 65504)
+
+    encoder_hidden_states, hidden_states = (
+        hidden_states[:, :text_seq_len],
+        hidden_states[:, text_seq_len:],
+    )
+    return encoder_hidden_states, hidden_states
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
+def __patch_transformer_forward__(
+    self: ChromaTransformer2DModel,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
+    timestep: torch.LongTensor = None,
+    img_ids: torch.Tensor = None,
+    txt_ids: torch.Tensor = None,
+    attention_mask: torch.Tensor = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    controlnet_block_samples=None,
+    controlnet_single_block_samples=None,
+    return_dict: bool = True,
+    controlnet_blocks_repeat: bool = False,
+) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    if joint_attention_kwargs is not None:
+        joint_attention_kwargs = joint_attention_kwargs.copy()
+        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+    else:
+        lora_scale = 1.0
+
+    if USE_PEFT_BACKEND:
+        # weight the lora layers by setting `lora_scale` for each PEFT layer
+        scale_lora_layers(self, lora_scale)
+    else:
+        if (
+            joint_attention_kwargs is not None
+            and joint_attention_kwargs.get("scale", None) is not None
+        ):
+            logger.warning(
+                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+            )
+
+    hidden_states = self.x_embedder(hidden_states)
+
+    timestep = timestep.to(hidden_states.dtype) * 1000
+
+    input_vec = self.time_text_embed(timestep)
+    pooled_temb = self.distilled_guidance_layer(input_vec)
+
+    encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+    if txt_ids.ndim == 3:
+        logger.warning(
+            "Passing `txt_ids` 3d torch.Tensor is deprecated."
+            "Please remove the batch dimension and pass it as a 2d torch Tensor"
+        )
+        txt_ids = txt_ids[0]
+    if img_ids.ndim == 3:
+        logger.warning(
+            "Passing `img_ids` 3d torch.Tensor is deprecated."
+            "Please remove the batch dimension and pass it as a 2d torch Tensor"
+        )
+        img_ids = img_ids[0]
+
+    ids = torch.cat((txt_ids, img_ids), dim=0)
+    image_rotary_emb = self.pos_embed(ids)
+
+    if (
+        joint_attention_kwargs is not None
+        and "ip_adapter_image_embeds" in joint_attention_kwargs
+    ):
+        ip_adapter_image_embeds = joint_attention_kwargs.pop(
+            "ip_adapter_image_embeds"
+        )
+        ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+        joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+    for index_block, block in enumerate(self.transformer_blocks):
+        img_offset = 3 * len(self.single_transformer_blocks)
+        txt_offset = img_offset + 6 * len(self.transformer_blocks)
+        img_modulation = img_offset + 6 * index_block
+        text_modulation = txt_offset + 6 * index_block
+        temb = torch.cat(
+            (
+                pooled_temb[:, img_modulation : img_modulation + 6],
+                pooled_temb[:, text_modulation : text_modulation + 6],
+            ),
+            dim=1,
+        )
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            encoder_hidden_states, hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    attention_mask,
+                )
+            )
+
+        else:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                attention_mask=attention_mask,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+
+        # controlnet residual
+        if controlnet_block_samples is not None:
+            interval_control = len(self.transformer_blocks) / len(
+                controlnet_block_samples
+            )
+            interval_control = int(np.ceil(interval_control))
+            # For Xlabs ControlNet.
+            if controlnet_blocks_repeat:
+                hidden_states = (
+                    hidden_states
+                    + controlnet_block_samples[
+                        index_block % len(controlnet_block_samples)
+                    ]
+                )
+            else:
+                hidden_states = (
+                    hidden_states
+                    + controlnet_block_samples[index_block // interval_control]
+                )
+
+    for index_block, block in enumerate(self.single_transformer_blocks):
+        start_idx = 3 * index_block
+        temb = pooled_temb[:, start_idx : start_idx + 3]
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            encoder_hidden_states, hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                )
+            )
+
+        else:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                attention_mask=attention_mask,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+
+        # controlnet residual
+        if controlnet_single_block_samples is not None:
+            interval_control = len(self.single_transformer_blocks) / len(
+                controlnet_single_block_samples
+            )
+            interval_control = int(np.ceil(interval_control))
+            hidden_states = (
+                hidden_states
+                + controlnet_single_block_samples[
+                    index_block // interval_control
+                ]
+            )
+
+    temb = pooled_temb[:, -2:]
+    hidden_states = self.norm_out(hidden_states, temb)
+    output = self.proj_out(hidden_states)
+
+    if USE_PEFT_BACKEND:
+        # remove `lora_scale` from each PEFT layer
+        unscale_lora_layers(self, lora_scale)
+
+    if not return_dict:
+        return (output,)
+
+    return Transformer2DModelOutput(sample=output)
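The new `ChromaPatchFunctor` rewrites `ChromaSingleTransformerBlock.forward` so it takes and returns `encoder_hidden_states` explicitly, matching the forward patterns the cache blocks expect. A hand-applied sketch, assuming the functor is re-exported from `cache_dit.cache_factory.patch_functors` (normally `enable_cache` is expected to apply it for you):

```python
# Sketch only; import path assumed from the new patch_functors package.
from diffusers import ChromaTransformer2DModel
from cache_dit.cache_factory.patch_functors import ChromaPatchFunctor


def patch_chroma(transformer: ChromaTransformer2DModel) -> ChromaTransformer2DModel:
    functor = ChromaPatchFunctor()
    # apply() patches every ChromaSingleTransformerBlock whose forward()
    # lacks an `encoder_hidden_states` parameter, installs the patched
    # transformer-level forward, and marks the model with `_is_patched`.
    transformer = functor.apply(transformer)
    assert getattr(transformer, "_is_patched", False)
    return transformer
```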
cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py}
CHANGED
@@ -14,12 +14,60 @@ from diffusers.utils import (
     unscale_lora_layers,
 )
 
-
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
+class FluxPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: FluxTransformer2DModel,
+        blocks: torch.nn.ModuleList = None,
+        **kwargs,
+    ) -> FluxTransformer2DModel:
+
+        if getattr(transformer, "_is_patched", False):
+            return transformer
+
+        if blocks is None:
+            blocks = transformer.single_transformer_blocks
+
+        is_patched = False
+        for block in blocks:
+            if isinstance(block, FluxSingleTransformerBlock):
+                forward_parameters = inspect.signature(
+                    block.forward
+                ).parameters.keys()
+                if "encoder_hidden_states" not in forward_parameters:
+                    block.forward = __patch_single_forward__.__get__(block)
+                    is_patched = True
+
+        if is_patched:
+            logger.warning("Patched Flux for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+            transformer._is_patched = True
+
+        cls_name = transformer.__class__.__name__
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
 # copy from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
 def __patch_single_forward__(
     self: FluxSingleTransformerBlock,
@@ -217,33 +265,3 @@ def __patch_transformer_forward__(
         return (output,)
 
     return Transformer2DModelOutput(sample=output)
-
-
-def maybe_patch_flux_transformer(
-    transformer: FluxTransformer2DModel,
-    blocks: torch.nn.ModuleList = None,
-) -> FluxTransformer2DModel:
-    if blocks is None:
-        blocks = transformer.single_transformer_blocks
-
-    is_patched = False
-    for block in blocks:
-        if isinstance(block, FluxSingleTransformerBlock):
-            forward_parameters = inspect.signature(
-                block.forward
-            ).parameters.keys()
-            if "encoder_hidden_states" not in forward_parameters:
-                block.forward = __patch_single_forward__.__get__(block)
-                is_patched = True
-
-    if is_patched:
-        logger.warning("Patched Flux for cache-dit.")
-        assert not getattr(transformer, "_is_parallelized", False), (
-            "Please call apply_cache_on_pipe before Parallelize, "
-            "the __patch_transformer_forward__ will overwrite the "
-            "parallized forward and cause a downgrade of performance."
-        )
-        transformer.forward = __patch_transformer_forward__.__get__(transformer)
-        transformer._is_patched = True
-
-    return transformer
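The Flux patch moves from the module-level `maybe_patch_flux_transformer` helper (removed in the second hunk) to the same `PatchFunctor` interface used for Chroma, and the assert message now points at `cache_dit.enable_cache` instead of the old `apply_cache_on_pipe`. A rough migration sketch, with import paths assumed from the renamed module; in practice this is internal plumbing driven by `enable_cache`:

```python
# 0.2.25: helper lived in cache_dit/cache_factory/patch/flux.py
#   from cache_dit.cache_factory.patch.flux import maybe_patch_flux_transformer
#   transformer = maybe_patch_flux_transformer(transformer)
#
# 0.2.27: class-based functor (import path assumed):
from diffusers import FluxTransformer2DModel
from cache_dit.cache_factory.patch_functors import FluxPatchFunctor


def patch_flux(transformer: FluxTransformer2DModel) -> FluxTransformer2DModel:
    # Same block-level patching as before, now behind PatchFunctor.apply().
    return FluxPatchFunctor().apply(transformer)
```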
cache_dit/quantize/quantize_ao.py
CHANGED
@@ -10,12 +10,13 @@ logger = init_logger(__name__)
 def quantize_ao(
     module: torch.nn.Module,
     quant_type: str = "fp8_w8a8_dq",
-    per_row: bool = True,
     exclude_layers: List[str] = [
         "embedder",
         "embed",
     ],
     filter_fn: Optional[Callable] = None,
+    # paramters for fp8 quantization
+    per_row: bool = True,
     **kwargs,
 ) -> torch.nn.Module:
     # Apply FP8 DQ for module and skip any `embed` modules
@@ -89,17 +90,30 @@ def quantize_ao(
         )
 
         quantization_fn = float8_dynamic_activation_float8_weight(
+            weight_dtype=kwargs.get(
+                "weight_dtype",
+                torch.float8_e4m3fn,
+            ),
+            activation_dtype=kwargs.get(
+                "activation_dtype",
+                torch.float8_e4m3fn,
+            ),
             granularity=(
                 ((PerRow(), PerRow()))
                 if per_row
                 else ((PerTensor(), PerTensor()))
-            )
+            ),
         )
 
     elif quant_type == "fp8_w8a16_wo":
         from torchao.quantization import float8_weight_only
 
-        quantization_fn = float8_weight_only()
+        quantization_fn = float8_weight_only(
+            weight_dtype=kwargs.get(
+                "weight_dtype",
+                torch.float8_e4m3fn,
+            ),
+        )
 
     elif quant_type == "int8_w8a8_dq":
         from torchao.quantization import (
@@ -159,12 +173,13 @@ def quantize_ao(
         module,
         _quantization_fn(),
         filter_fn=_filter_fn if filter_fn is None else filter_fn,
-
+        device=kwargs.get("device", None),
     )
 
     force_empty_cache()
 
     logger.info(
+        f"Quantized Method: {quant_type:>5}\n"
         f"Quantized Linear Layers: {num_quant_linear:>5}\n"
         f"Skipped Linear Layers: {num_skip_linear:>5}\n"
        f"Total Linear Layers: {num_linear_layers:>5}\n"
cache_dit/quantize/quantize_interface.py
CHANGED
@@ -9,13 +9,13 @@ def quantize(
     module: torch.nn.Module,
     quant_type: str = "fp8_w8a8_dq",
     backend: str = "ao",
-    # only for fp8_w8a8_dq
-    per_row: bool = True,
     exclude_layers: List[str] = [
         "embedder",
         "embed",
     ],
     filter_fn: Optional[Callable] = None,
+    # only for fp8_w8a8_dq
+    per_row: bool = True,
     **kwargs,
 ) -> torch.nn.Module:
     assert isinstance(module, torch.nn.Module)
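Both quantize entry points now take `per_row` after `filter_fn`, and the torchao backend reads `weight_dtype`, `activation_dtype` and `device` out of `**kwargs`. A hedged sketch of the fp8 path, assuming `quantize()` forwards its extra keyword arguments through to `quantize_ao`:

```python
import torch

# Import path taken from the file layout in this diff; the top-level
# package may also re-export it (not shown here).
from cache_dit.quantize.quantize_interface import quantize


def quantize_fp8(transformer: torch.nn.Module) -> torch.nn.Module:
    return quantize(
        transformer,
        quant_type="fp8_w8a8_dq",              # fp8 dynamic weights + activations
        backend="ao",                          # torchao backend
        exclude_layers=["embedder", "embed"],  # defaults shown in the diff
        per_row=True,                          # PerRow() vs PerTensor() granularity
        # New in 0.2.27, read via kwargs.get(...) in quantize_ao
        # (assumption: the interface passes **kwargs straight through):
        weight_dtype=torch.float8_e4m3fn,
        activation_dtype=torch.float8_e4m3fn,
        device=None,
    )
```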
cache_dit/utils.py
CHANGED
@@ -30,26 +30,32 @@ class CacheStats:
 
 
 def summary(
-
+    pipe_or_module: DiffusionPipeline | torch.nn.Module | Any,
     details: bool = False,
     logging: bool = True,
 ) -> CacheStats:
     cache_stats = CacheStats()
-
-    if isinstance(
-
+
+    if not isinstance(pipe_or_module, torch.nn.Module):
+        assert hasattr(pipe_or_module, "transformer")
+        module = pipe_or_module.transformer
+        cls_name = module.__class__.__name__
     else:
-
+        module = pipe_or_module
+
+    cls_name = module.__class__.__name__
+    if isinstance(module, torch.nn.ModuleList):
+        cls_name = module[0].__class__.__name__
 
-    if hasattr(
-        cache_options =
+    if hasattr(module, "_cache_context_kwargs"):
+        cache_options = module._cache_context_kwargs
         cache_stats.cache_options = cache_options
         if logging:
             print(f"\n🤗Cache Options: {cls_name}\n\n{cache_options}")
 
-    if hasattr(
-        cached_steps: list[int] =
-        residual_diffs: dict[str, float] = dict(
+    if hasattr(module, "_cached_steps"):
+        cached_steps: list[int] = module._cached_steps
+        residual_diffs: dict[str, float] = dict(module._residual_diffs)
         cache_stats.cached_steps = cached_steps
         cache_stats.residual_diffs = residual_diffs
 
@@ -90,11 +96,9 @@ def summary(
             compact=True,
         )
 
-    if hasattr(
-        cfg_cached_steps: list[int] =
-        cfg_residual_diffs: dict[str, float] = dict(
-            transformer._cfg_residual_diffs
-        )
+    if hasattr(module, "_cfg_cached_steps"):
+        cfg_cached_steps: list[int] = module._cfg_cached_steps
+        cfg_residual_diffs: dict[str, float] = dict(module._cfg_residual_diffs)
         cache_stats.cfg_cached_steps = cfg_cached_steps
         cache_stats.cfg_residual_diffs = cfg_residual_diffs
 
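With the `summary()` rework, statistics can be read either from a pipeline (via its `.transformer`) or from a bare `torch.nn.Module`, and the cached-step attributes are looked up on that module rather than on a hard-coded `transformer` variable. A short sketch, assuming `summary` is re-exported from the top-level `cache_dit` package (only `cache_dit/utils.py` appears in this diff):

```python
from typing import Any

import cache_dit  # assumption: summary() is re-exported at the top level


def report_cache_stats(pipe: Any) -> None:
    """Print cache statistics for a cache-enabled Diffusers pipeline."""
    # summary() also accepts the bare module: cache_dit.summary(pipe.transformer)
    stats = cache_dit.summary(pipe, details=True, logging=True)

    # Fields populated from the module attributes shown in this hunk:
    print(stats.cache_options)      # module._cache_context_kwargs
    print(len(stats.cached_steps))  # steps served from the residual cache
    print(stats.residual_diffs)     # per-step residual differences
```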