cache-dit 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +95 -61
- cache_dit/cache_factory/block_adapters/block_adapters.py +27 -6
- cache_dit/cache_factory/block_adapters/block_registers.py +10 -7
- cache_dit/cache_factory/cache_adapters.py +177 -66
- cache_dit/cache_factory/cache_blocks/__init__.py +3 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +70 -67
- cache_dit/cache_factory/cache_blocks/pattern_base.py +13 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +8 -10
- cache_dit/cache_factory/cache_interface.py +19 -77
- cache_dit/cache_factory/cache_types.py +5 -5
- cache_dit/cache_factory/patch_functors/__init__.py +6 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +5 -3
- cache_dit/cache_factory/patch_functors/functor_flux.py +5 -3
- cache_dit/cache_factory/patch_functors/functor_hidream.py +412 -0
- cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py +213 -0
- cache_dit/utils.py +5 -1
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/METADATA +14 -48
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/RECORD +23 -21
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/top_level.txt +0 -0
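The two new modules, functor_hidream.py and functor_hunyuan_dit.py, register patch functors that rebind the forward methods of HiDreamImageTransformer2DModel and HunyuanDiT2DModel so cache-dit's block adapters can cache their transformer blocks. A minimal usage sketch follows; only cache_dit.enable_cache and cache_dit.summary appear in this diff, so the pipeline construction and call arguments below are illustrative assumptions, not part of the release.

# Hedged sketch: how the new patch functors are reached in practice.
# Assumes `pipe` is an already-constructed HiDreamImagePipeline (or
# HunyuanDiTPipeline); the prompt and attribute access are illustrative.
import cache_dit

# pipe = HiDreamImagePipeline.from_pretrained(...)  # construction omitted

# enable_cache() resolves a BlockAdapter for pipe.transformer; for
# HiDream / HunyuanDiT that adapter runs the new PatchFunctor.apply(),
# which rebinds the block and transformer forwards before caching.
cache_dit.enable_cache(pipe)

image = pipe("a cabin in the snow").images[0]  # normal inference

# summary() reports cache statistics; with this release it returns
# [CacheStats()] early when the transformer was never cached.
stats = cache_dit.summary(pipe)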
cache_dit/cache_factory/patch_functors/functor_hidream.py ADDED
@@ -0,0 +1,412 @@
import torch
from typing import Tuple, Optional, Dict, Any, Union, List
from diffusers import HiDreamImageTransformer2DModel
from diffusers.models.transformers.transformer_hidream_image import (
    HiDreamBlock,
    HiDreamImageTransformerBlock,
    HiDreamImageSingleTransformerBlock,
    Transformer2DModelOutput,
)
from diffusers.utils import (
    deprecate,
    USE_PEFT_BACKEND,
    scale_lora_layers,
    unscale_lora_layers,
)
from cache_dit.cache_factory.patch_functors.functor_base import (
    PatchFunctor,
)
from cache_dit.logger import init_logger

logger = init_logger(__name__)


class HiDreamPatchFunctor(PatchFunctor):

    def apply(
        self,
        transformer: HiDreamImageTransformer2DModel,
        **kwargs,
    ) -> HiDreamImageTransformer2DModel:
        if hasattr(transformer, "_is_patched"):
            return transformer

        is_patched = False

        _block_id = 0
        for block in transformer.double_stream_blocks:
            assert isinstance(block, HiDreamBlock)
            block.forward = __patch_block_forward__.__get__(block)
            # NOTE: Patch Inner block and block_id
            _block = block.block
            assert isinstance(_block, HiDreamImageTransformerBlock)
            _block._block_id = _block_id
            _block.forward = __patch_double_forward__.__get__(_block)
            _block_id += 1

        for block in transformer.single_stream_blocks:
            assert isinstance(block, HiDreamBlock)
            block.forward = __patch_block_forward__.__get__(block)
            # NOTE: Patch Inner block and block_id
            _block = block.block
            assert isinstance(_block, HiDreamImageSingleTransformerBlock)
            _block._block_id = _block_id
            _block.forward = __patch_single_forward__.__get__(_block)
            _block_id += 1

        is_patched = True
        cls_name = transformer.__class__.__name__

        if is_patched:
            logger.warning(f"Patched {cls_name} for cache-dit.")
            assert not getattr(transformer, "_is_parallelized", False), (
                "Please call `cache_dit.enable_cache` before Parallelize, "
                "the __patch_transformer_forward__ will overwrite the "
                "parallized forward and cause a downgrade of performance."
            )
            transformer.forward = __patch_transformer_forward__.__get__(
                transformer
            )

        transformer._is_patched = is_patched  # True or False

        logger.info(
            f"Applied {self.__class__.__name__} for {cls_name}, "
            f"Patch: {is_patched}."
        )

        return transformer


# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
def __patch_double_forward__(
    self: HiDreamImageTransformerBlock,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,  # initial_encoder_hidden_states
    hidden_states_masks: Optional[torch.Tensor] = None,
    temb: Optional[torch.Tensor] = None,
    image_rotary_emb: torch.Tensor = None,
    llama31_encoder_hidden_states: List[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Assume block_id was patched in transformer forward:
    # for i, block in enumerate(blocks): block._block_id = i;
    block_id = self._block_id
    initial_encoder_hidden_states_seq_len = encoder_hidden_states.shape[1]
    cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[block_id]
    cur_encoder_hidden_states = torch.cat(
        [encoder_hidden_states, cur_llama31_encoder_hidden_states],
        dim=1,
    )
    encoder_hidden_states = cur_encoder_hidden_states

    wtype = hidden_states.dtype
    (
        shift_msa_i,
        scale_msa_i,
        gate_msa_i,
        shift_mlp_i,
        scale_mlp_i,
        gate_mlp_i,
        shift_msa_t,
        scale_msa_t,
        gate_msa_t,
        shift_mlp_t,
        scale_mlp_t,
        gate_mlp_t,
    ) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1)

    # 1. MM-Attention
    norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
    norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
    norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(
        dtype=wtype
    )
    norm_encoder_hidden_states = (
        norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t
    )

    attn_output_i, attn_output_t = self.attn1(
        norm_hidden_states,
        hidden_states_masks,
        norm_encoder_hidden_states,
        image_rotary_emb=image_rotary_emb,
    )

    hidden_states = gate_msa_i * attn_output_i + hidden_states
    encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states

    # 2. Feed-forward
    norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
    norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
    norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(
        dtype=wtype
    )
    norm_encoder_hidden_states = (
        norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
    )

    ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
    ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
    hidden_states = ff_output_i + hidden_states
    encoder_hidden_states = ff_output_t + encoder_hidden_states

    initial_encoder_hidden_states = encoder_hidden_states[
        :, :initial_encoder_hidden_states_seq_len
    ]

    return hidden_states, initial_encoder_hidden_states


# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
def __patch_single_forward__(
    self: HiDreamImageSingleTransformerBlock,
    hidden_states: torch.Tensor,
    hidden_states_masks: Optional[torch.Tensor] = None,
    temb: Optional[torch.Tensor] = None,
    image_rotary_emb: torch.Tensor = None,
    llama31_encoder_hidden_states: List[torch.Tensor] = None,
) -> torch.Tensor:
    # Assume block_id was patched in transformer forward:
    # for i, block in enumerate(blocks): block._block_id = i;
    block_id = self._block_id
    hidden_states_seq_len = hidden_states.shape[1]
    cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[block_id]
    hidden_states = torch.cat(
        [hidden_states, cur_llama31_encoder_hidden_states], dim=1
    )

    wtype = hidden_states.dtype
    (
        shift_msa_i,
        scale_msa_i,
        gate_msa_i,
        shift_mlp_i,
        scale_mlp_i,
        gate_mlp_i,
    ) = self.adaLN_modulation(temb)[:, None].chunk(6, dim=-1)

    # 1. MM-Attention
    norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
    norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
    attn_output_i = self.attn1(
        norm_hidden_states,
        hidden_states_masks,
        image_rotary_emb=image_rotary_emb,
    )
    hidden_states = gate_msa_i * attn_output_i + hidden_states

    # 2. Feed-forward
    norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
    norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
    ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
    hidden_states = ff_output_i + hidden_states

    hidden_states = hidden_states[:, :hidden_states_seq_len]

    return hidden_states


# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
def __patch_block_forward__(
    self: HiDreamBlock,
    hidden_states: torch.Tensor,
    *args,
    **kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    return self.block(hidden_states, *args, **kwargs)


# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
def __patch_transformer_forward__(
    self: HiDreamImageTransformer2DModel,
    hidden_states: torch.Tensor,
    timesteps: torch.LongTensor = None,
    encoder_hidden_states_t5: torch.Tensor = None,
    encoder_hidden_states_llama3: torch.Tensor = None,
    pooled_embeds: torch.Tensor = None,
    img_ids: Optional[torch.Tensor] = None,
    img_sizes: Optional[List[Tuple[int, int]]] = None,
    hidden_states_masks: Optional[torch.Tensor] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
    **kwargs,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)

    if encoder_hidden_states is not None:
        deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
        deprecate("encoder_hidden_states", "0.35.0", deprecation_message)
        encoder_hidden_states_t5 = encoder_hidden_states[0]
        encoder_hidden_states_llama3 = encoder_hidden_states[1]

    if (
        img_ids is not None
        and img_sizes is not None
        and hidden_states_masks is None
    ):
        deprecation_message = "Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
        deprecate("img_ids", "0.35.0", deprecation_message)

    if hidden_states_masks is not None and (
        img_ids is None or img_sizes is None
    ):
        raise ValueError(
            "if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed."
        )
    elif hidden_states_masks is not None and hidden_states.ndim != 3:
        raise ValueError(
            "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
        )

    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if (
            attention_kwargs is not None
            and attention_kwargs.get("scale", None) is not None
        ):
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

    # spatial forward
    batch_size = hidden_states.shape[0]
    hidden_states_type = hidden_states.dtype

    # Patchify the input
    if hidden_states_masks is None:
        hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(
            hidden_states
        )

    # Embed the hidden states
    hidden_states = self.x_embedder(hidden_states)

    # 0. time
    timesteps = self.t_embedder(timesteps, hidden_states_type)
    p_embedder = self.p_embedder(pooled_embeds)
    temb = timesteps + p_embedder

    encoder_hidden_states = [
        encoder_hidden_states_llama3[k] for k in self.config.llama_layers
    ]

    if self.caption_projection is not None:
        new_encoder_hidden_states = []
        for i, enc_hidden_state in enumerate(encoder_hidden_states):
            enc_hidden_state = self.caption_projection[i](enc_hidden_state)
            enc_hidden_state = enc_hidden_state.view(
                batch_size, -1, hidden_states.shape[-1]
            )
            new_encoder_hidden_states.append(enc_hidden_state)
        encoder_hidden_states = new_encoder_hidden_states
        encoder_hidden_states_t5 = self.caption_projection[-1](
            encoder_hidden_states_t5
        )
        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(
            batch_size, -1, hidden_states.shape[-1]
        )
        encoder_hidden_states.append(encoder_hidden_states_t5)

    txt_ids = torch.zeros(
        batch_size,
        encoder_hidden_states[-1].shape[1]
        + encoder_hidden_states[-2].shape[1]
        + encoder_hidden_states[0].shape[1],
        3,
        device=img_ids.device,
        dtype=img_ids.dtype,
    )
    ids = torch.cat((img_ids, txt_ids), dim=1)
    image_rotary_emb = self.pe_embedder(ids)

    # 2. Blocks
    # NOTE: block_id is no-need anymore.
    initial_encoder_hidden_states = torch.cat(
        [encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1
    )
    llama31_encoder_hidden_states = encoder_hidden_states
    for bid, block in enumerate(self.double_stream_blocks):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            hidden_states, initial_encoder_hidden_states = (
                self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    initial_encoder_hidden_states,
                    hidden_states_masks,
                    temb,
                    image_rotary_emb,
                    llama31_encoder_hidden_states,
                )
            )
        else:
            hidden_states, initial_encoder_hidden_states = block(
                hidden_states,
                initial_encoder_hidden_states,  # encoder_hidden_states
                hidden_states_masks=hidden_states_masks,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                llama31_encoder_hidden_states=llama31_encoder_hidden_states,
            )

    image_tokens_seq_len = hidden_states.shape[1]
    hidden_states = torch.cat(
        [hidden_states, initial_encoder_hidden_states], dim=1
    )
    if hidden_states_masks is not None:
        # NOTE: Patched
        cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[
            self.double_stream_blocks[-1].block._block_id
        ]
        encoder_attention_mask_ones = torch.ones(
            (
                batch_size,
                initial_encoder_hidden_states.shape[1]
                + cur_llama31_encoder_hidden_states.shape[1],
            ),
            device=hidden_states_masks.device,
            dtype=hidden_states_masks.dtype,
        )
        hidden_states_masks = torch.cat(
            [hidden_states_masks, encoder_attention_mask_ones], dim=1
        )

    for bid, block in enumerate(self.single_stream_blocks):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            hidden_states = self._gradient_checkpointing_func(
                block,
                hidden_states,
                hidden_states_masks,
                temb,
                image_rotary_emb,
                llama31_encoder_hidden_states,
            )
        else:
            hidden_states = block(
                hidden_states,
                hidden_states_masks=hidden_states_masks,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                llama31_encoder_hidden_states=llama31_encoder_hidden_states,
            )

    hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
    output = self.final_layer(hidden_states, temb)
    output = self.unpatchify(output, img_sizes, self.training)
    if hidden_states_masks is not None:
        hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
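The functor relies on Python's descriptor protocol to swap in the patched forwards: calling a plain function's __get__ on a module instance produces a bound method, so assigning it to block.forward overrides only that instance while other instances and the class keep the original implementation. A generic sketch of the same trick (the names here are illustrative, not taken from this diff):

import torch


class Block(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


def patched_forward(self: "Block", x: torch.Tensor) -> torch.Tensor:
    # Per-instance state injected beforehand, mirroring how the functor
    # attaches _block_id before rebinding __patch_double_forward__.
    return self._scale * (x + 1)


block = Block()
block._scale = 2.0                               # per-instance attribute
block.forward = patched_forward.__get__(block)   # bind to this instance only

print(block(torch.zeros(1)))    # tensor([2.]) -> patched path
print(Block()(torch.zeros(1)))  # tensor([1.]) -> other instances unaffected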
cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py ADDED
@@ -0,0 +1,213 @@
import torch
from typing import Optional, Union, List
from diffusers import HunyuanDiT2DModel
from diffusers.models.transformers.hunyuan_transformer_2d import (
    HunyuanDiTBlock,
    Transformer2DModelOutput,
)
from cache_dit.cache_factory.patch_functors.functor_base import (
    PatchFunctor,
)
from cache_dit.logger import init_logger

logger = init_logger(__name__)


class HunyuanDiTPatchFunctor(PatchFunctor):

    def apply(
        self,
        transformer: HunyuanDiT2DModel,
        **kwargs,
    ) -> HunyuanDiT2DModel:
        if hasattr(transformer, "_is_patched"):
            return transformer

        is_patched = False

        num_layers = transformer.config.num_layers
        layer_id = 0
        for block in transformer.blocks:
            assert isinstance(block, HunyuanDiTBlock)
            block._num_layers = num_layers
            block._layer_id = layer_id
            block.forward = __patch_block_forward__.__get__(block)
            layer_id += 1

        is_patched = True

        cls_name = transformer.__class__.__name__

        if is_patched:
            logger.warning(f"Patched {cls_name} for cache-dit.")
            assert not getattr(transformer, "_is_parallelized", False), (
                "Please call `cache_dit.enable_cache` before Parallelize, "
                "the __patch_transformer_forward__ will overwrite the "
                "parallized forward and cause a downgrade of performance."
            )
            transformer.forward = __patch_transformer_forward__.__get__(
                transformer
            )

        transformer._is_patched = is_patched  # True or False

        logger.info(
            f"Applied {self.__class__.__name__} for {cls_name}, "
            f"Patch: {is_patched}."
        )

        return transformer


def __patch_block_forward__(
    self: HunyuanDiTBlock,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    temb: Optional[torch.Tensor] = None,
    image_rotary_emb: torch.Tensor = None,
    controlnet_block_samples: torch.Tensor = None,
    skips: List[torch.Tensor] = [],
) -> torch.Tensor:
    # Notice that normalization is always applied before the real computation in the following blocks.
    # 0. Long Skip Connection
    num_layers = self._num_layers
    layer_id = self._layer_id

    if layer_id > num_layers // 2:
        if controlnet_block_samples is not None:
            skip = skips.pop() + controlnet_block_samples.pop()
        else:
            skip = skips.pop()
    else:
        skip = None

    if self.skip_linear is not None:
        cat = torch.cat([hidden_states, skip], dim=-1)
        cat = self.skip_norm(cat)
        hidden_states = self.skip_linear(cat)

    # 1. Self-Attention
    norm_hidden_states = self.norm1(
        hidden_states, temb
    )  # checked: self.norm1 is correct
    attn_output = self.attn1(
        norm_hidden_states,
        image_rotary_emb=image_rotary_emb,
    )
    hidden_states = hidden_states + attn_output

    # 2. Cross-Attention
    hidden_states = hidden_states + self.attn2(
        self.norm2(hidden_states),
        encoder_hidden_states=encoder_hidden_states,
        image_rotary_emb=image_rotary_emb,
    )

    # FFN Layer
    mlp_inputs = self.norm3(hidden_states)
    hidden_states = hidden_states + self.ff(mlp_inputs)

    if layer_id < (num_layers // 2 - 1):
        skips.append(hidden_states)

    return hidden_states


def __patch_transformer_forward__(
    self: HunyuanDiT2DModel,
    hidden_states,
    timestep,
    encoder_hidden_states=None,
    text_embedding_mask=None,
    encoder_hidden_states_t5=None,
    text_embedding_mask_t5=None,
    image_meta_size=None,
    style=None,
    image_rotary_emb=None,
    controlnet_block_samples=None,
    return_dict=True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
    height, width = hidden_states.shape[-2:]

    hidden_states = self.pos_embed(hidden_states)

    temb = self.time_extra_emb(
        timestep,
        encoder_hidden_states_t5,
        image_meta_size,
        style,
        hidden_dtype=timestep.dtype,
    )  # [B, D]

    # text projection
    batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
    encoder_hidden_states_t5 = self.text_embedder(
        encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
    )
    encoder_hidden_states_t5 = encoder_hidden_states_t5.view(
        batch_size, sequence_length, -1
    )

    encoder_hidden_states = torch.cat(
        [encoder_hidden_states, encoder_hidden_states_t5], dim=1
    )
    text_embedding_mask = torch.cat(
        [text_embedding_mask, text_embedding_mask_t5], dim=-1
    )
    text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()

    encoder_hidden_states = torch.where(
        text_embedding_mask, encoder_hidden_states, self.text_embedding_padding
    )

    skips = []
    for layer, block in enumerate(self.blocks):
        hidden_states = block(
            hidden_states,
            temb=temb,
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            controlnet_block_samples=controlnet_block_samples,
            skips=skips,
        )  # (N, L, D)

    if (
        controlnet_block_samples is not None
        and len(controlnet_block_samples) != 0
    ):
        raise ValueError(
            "The number of controls is not equal to the number of skip connections."
        )

    # final layer
    hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
    hidden_states = self.proj_out(hidden_states)
    # (N, L, patch_size ** 2 * out_channels)

    # unpatchify: (N, out_channels, H, W)
    patch_size = self.pos_embed.patch_size
    height = height // patch_size
    width = width // patch_size

    hidden_states = hidden_states.reshape(
        shape=(
            hidden_states.shape[0],
            height,
            width,
            patch_size,
            patch_size,
            self.out_channels,
        )
    )
    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
    output = hidden_states.reshape(
        shape=(
            hidden_states.shape[0],
            self.out_channels,
            height * patch_size,
            width * patch_size,
        )
    )
    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
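The HunyuanDiT patch moves the long-skip bookkeeping into each block: blocks in the first half append their outputs to the caller-supplied skips list and blocks in the second half pop them (adding controlnet residuals when present), driven by the _layer_id / _num_layers attributes set in apply(). A toy illustration of that stack discipline, assuming num_layers = 8 purely as an example:

# Pure-Python toy of the U-shaped skip schedule above; no diffusers types.
num_layers = 8
skips: list = []

for layer_id in range(num_layers):
    if layer_id > num_layers // 2:
        skip = skips.pop()        # second half consumes skips (LIFO)
    else:
        skip = None

    hidden = layer_id if skip is None else layer_id + skip

    if layer_id < (num_layers // 2 - 1):
        skips.append(hidden)      # first few layers produce skips

print(skips)  # [] -> every pushed skip was consumed exactly once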
cache_dit/utils.py CHANGED
@@ -52,6 +52,9 @@ def summary(
     if hasattr(adapter_or_others, "transformer_2"):
         transformer_2 = adapter_or_others.transformer_2

+    if not BlockAdapter.is_cached(transformer):
+        return [CacheStats()]
+
     blocks_stats: List[CacheStats] = []
     for blocks in BlockAdapter.find_blocks(transformer):
         blocks_stats.append(
@@ -212,7 +215,8 @@ def _summary(
         if logging:
             print(f"\n🤗Cache Options: {cls_name}\n\n{cache_options}")
     else:
-
+        if logging:
+            logger.warning(f"Can't find Cache Options for: {cls_name}")

     if hasattr(module, "_cached_steps"):
         cached_steps: list[int] = module._cached_steps