cache-dit 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl

This diff shows the content changes between publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of cache-dit might be problematic.

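For orientation, the two new files below add cache-dit patch functors for the HiDream and HunyuanDiT transformers. A minimal usage sketch follows, assuming the public `cache_dit.enable_cache` / `cache_dit.summary` entry points referenced in this diff; the pipeline class, checkpoint id, dtype, and prompt are illustrative and not taken from the diff:

import torch
import cache_dit
from diffusers import HiDreamImagePipeline

# Illustrative checkpoint id and simplified loading; real HiDream pipelines may
# require extra components (e.g. the Llama text encoder) to be passed explicitly.
pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Enable caching first; the patch functors below assert that caching is applied
# before the transformer is parallelized.
cache_dit.enable_cache(pipe)

image = pipe("a watercolor fox in the snow").images[0]
stats = cache_dit.summary(pipe)  # per-block cache statistics (see cache_dit/utils.py below)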
@@ -0,0 +1,412 @@
1
+ import torch
2
+ from typing import Tuple, Optional, Dict, Any, Union, List
3
+ from diffusers import HiDreamImageTransformer2DModel
4
+ from diffusers.models.transformers.transformer_hidream_image import (
5
+ HiDreamBlock,
6
+ HiDreamImageTransformerBlock,
7
+ HiDreamImageSingleTransformerBlock,
8
+ Transformer2DModelOutput,
9
+ )
10
+ from diffusers.utils import (
11
+ deprecate,
12
+ USE_PEFT_BACKEND,
13
+ scale_lora_layers,
14
+ unscale_lora_layers,
15
+ )
16
+ from cache_dit.cache_factory.patch_functors.functor_base import (
17
+ PatchFunctor,
18
+ )
19
+ from cache_dit.logger import init_logger
20
+
21
+ logger = init_logger(__name__)
22
+
23
+
24
+ class HiDreamPatchFunctor(PatchFunctor):
25
+
26
+ def apply(
27
+ self,
28
+ transformer: HiDreamImageTransformer2DModel,
29
+ **kwargs,
30
+ ) -> HiDreamImageTransformer2DModel:
31
+ if hasattr(transformer, "_is_patched"):
32
+ return transformer
33
+
34
+ is_patched = False
35
+
36
+ _block_id = 0
37
+ for block in transformer.double_stream_blocks:
38
+ assert isinstance(block, HiDreamBlock)
39
+ block.forward = __patch_block_forward__.__get__(block)
40
+ # NOTE: Patch the inner block and assign its _block_id
41
+ _block = block.block
42
+ assert isinstance(_block, HiDreamImageTransformerBlock)
43
+ _block._block_id = _block_id
44
+ _block.forward = __patch_double_forward__.__get__(_block)
45
+ _block_id += 1
46
+
47
+ for block in transformer.single_stream_blocks:
48
+ assert isinstance(block, HiDreamBlock)
49
+ block.forward = __patch_block_forward__.__get__(block)
50
+ # NOTE: Patch the inner block and assign its _block_id
51
+ _block = block.block
52
+ assert isinstance(_block, HiDreamImageSingleTransformerBlock)
53
+ _block._block_id = _block_id
54
+ _block.forward = __patch_single_forward__.__get__(_block)
55
+ _block_id += 1
56
+
57
+ is_patched = True
58
+ cls_name = transformer.__class__.__name__
59
+
60
+ if is_patched:
61
+ logger.warning(f"Patched {cls_name} for cache-dit.")
62
+ assert not getattr(transformer, "_is_parallelized", False), (
63
+ "Please call `cache_dit.enable_cache` before parallelization; "
64
+ "the __patch_transformer_forward__ will overwrite the "
65
+ "parallelized forward and degrade performance."
66
+ )
67
+ transformer.forward = __patch_transformer_forward__.__get__(
68
+ transformer
69
+ )
70
+
71
+ transformer._is_patched = is_patched # True or False
72
+
73
+ logger.info(
74
+ f"Applied {self.__class__.__name__} for {cls_name}, "
75
+ f"Patch: {is_patched}."
76
+ )
77
+
78
+ return transformer
79
+
80
+
81
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
82
+ def __patch_double_forward__(
83
+ self: HiDreamImageTransformerBlock,
84
+ hidden_states: torch.Tensor,
85
+ encoder_hidden_states: torch.Tensor, # initial_encoder_hidden_states
86
+ hidden_states_masks: Optional[torch.Tensor] = None,
87
+ temb: Optional[torch.Tensor] = None,
88
+ image_rotary_emb: torch.Tensor = None,
89
+ llama31_encoder_hidden_states: List[torch.Tensor] = None,
90
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
91
+ # NOTE: _block_id is assigned by HiDreamPatchFunctor.apply() when the
92
+ # inner blocks are patched (block._block_id = i).
93
+ block_id = self._block_id
94
+ initial_encoder_hidden_states_seq_len = encoder_hidden_states.shape[1]
95
+ cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[block_id]
96
+ cur_encoder_hidden_states = torch.cat(
97
+ [encoder_hidden_states, cur_llama31_encoder_hidden_states],
98
+ dim=1,
99
+ )
100
+ encoder_hidden_states = cur_encoder_hidden_states
101
+
102
+ wtype = hidden_states.dtype
103
+ (
104
+ shift_msa_i,
105
+ scale_msa_i,
106
+ gate_msa_i,
107
+ shift_mlp_i,
108
+ scale_mlp_i,
109
+ gate_mlp_i,
110
+ shift_msa_t,
111
+ scale_msa_t,
112
+ gate_msa_t,
113
+ shift_mlp_t,
114
+ scale_mlp_t,
115
+ gate_mlp_t,
116
+ ) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1)
117
+
118
+ # 1. MM-Attention
119
+ norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
120
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
121
+ norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(
122
+ dtype=wtype
123
+ )
124
+ norm_encoder_hidden_states = (
125
+ norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t
126
+ )
127
+
128
+ attn_output_i, attn_output_t = self.attn1(
129
+ norm_hidden_states,
130
+ hidden_states_masks,
131
+ norm_encoder_hidden_states,
132
+ image_rotary_emb=image_rotary_emb,
133
+ )
134
+
135
+ hidden_states = gate_msa_i * attn_output_i + hidden_states
136
+ encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states
137
+
138
+ # 2. Feed-forward
139
+ norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
140
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
141
+ norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(
142
+ dtype=wtype
143
+ )
144
+ norm_encoder_hidden_states = (
145
+ norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
146
+ )
147
+
148
+ ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
149
+ ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
150
+ hidden_states = ff_output_i + hidden_states
151
+ encoder_hidden_states = ff_output_t + encoder_hidden_states
152
+
153
+ initial_encoder_hidden_states = encoder_hidden_states[
154
+ :, :initial_encoder_hidden_states_seq_len
155
+ ]
156
+
157
+ return hidden_states, initial_encoder_hidden_states
158
+
159
+
160
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
161
+ def __patch_single_forward__(
162
+ self: HiDreamImageSingleTransformerBlock,
163
+ hidden_states: torch.Tensor,
164
+ hidden_states_masks: Optional[torch.Tensor] = None,
165
+ temb: Optional[torch.Tensor] = None,
166
+ image_rotary_emb: torch.Tensor = None,
167
+ llama31_encoder_hidden_states: List[torch.Tensor] = None,
168
+ ) -> torch.Tensor:
169
+ # NOTE: _block_id is assigned by HiDreamPatchFunctor.apply() when the
170
+ # inner blocks are patched (block._block_id = i).
171
+ block_id = self._block_id
172
+ hidden_states_seq_len = hidden_states.shape[1]
173
+ cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[block_id]
174
+ hidden_states = torch.cat(
175
+ [hidden_states, cur_llama31_encoder_hidden_states], dim=1
176
+ )
177
+
178
+ wtype = hidden_states.dtype
179
+ (
180
+ shift_msa_i,
181
+ scale_msa_i,
182
+ gate_msa_i,
183
+ shift_mlp_i,
184
+ scale_mlp_i,
185
+ gate_mlp_i,
186
+ ) = self.adaLN_modulation(temb)[:, None].chunk(6, dim=-1)
187
+
188
+ # 1. MM-Attention
189
+ norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
190
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
191
+ attn_output_i = self.attn1(
192
+ norm_hidden_states,
193
+ hidden_states_masks,
194
+ image_rotary_emb=image_rotary_emb,
195
+ )
196
+ hidden_states = gate_msa_i * attn_output_i + hidden_states
197
+
198
+ # 2. Feed-forward
199
+ norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
200
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
201
+ ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
202
+ hidden_states = ff_output_i + hidden_states
203
+
204
+ hidden_states = hidden_states[:, :hidden_states_seq_len]
205
+
206
+ return hidden_states
207
+
208
+
209
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
210
+ def __patch_block_forward__(
211
+ self: HiDreamBlock,
212
+ hidden_states: torch.Tensor,
213
+ *args,
214
+ **kwargs,
215
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
216
+ return self.block(hidden_states, *args, **kwargs)
217
+
218
+
219
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
220
+ def __patch_transformer_forward__(
221
+ self: HiDreamImageTransformer2DModel,
222
+ hidden_states: torch.Tensor,
223
+ timesteps: torch.LongTensor = None,
224
+ encoder_hidden_states_t5: torch.Tensor = None,
225
+ encoder_hidden_states_llama3: torch.Tensor = None,
226
+ pooled_embeds: torch.Tensor = None,
227
+ img_ids: Optional[torch.Tensor] = None,
228
+ img_sizes: Optional[List[Tuple[int, int]]] = None,
229
+ hidden_states_masks: Optional[torch.Tensor] = None,
230
+ attention_kwargs: Optional[Dict[str, Any]] = None,
231
+ return_dict: bool = True,
232
+ **kwargs,
233
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
234
+ encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
235
+
236
+ if encoder_hidden_states is not None:
237
+ deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
238
+ deprecate("encoder_hidden_states", "0.35.0", deprecation_message)
239
+ encoder_hidden_states_t5 = encoder_hidden_states[0]
240
+ encoder_hidden_states_llama3 = encoder_hidden_states[1]
241
+
242
+ if (
243
+ img_ids is not None
244
+ and img_sizes is not None
245
+ and hidden_states_masks is None
246
+ ):
247
+ deprecation_message = "Passing `img_ids` and `img_sizes` with unpatchified `hidden_states` is deprecated and will be ignored."
248
+ deprecate("img_ids", "0.35.0", deprecation_message)
249
+
250
+ if hidden_states_masks is not None and (
251
+ img_ids is None or img_sizes is None
252
+ ):
253
+ raise ValueError(
254
+ "if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed."
255
+ )
256
+ elif hidden_states_masks is not None and hidden_states.ndim != 3:
257
+ raise ValueError(
258
+ "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
259
+ )
260
+
261
+ if attention_kwargs is not None:
262
+ attention_kwargs = attention_kwargs.copy()
263
+ lora_scale = attention_kwargs.pop("scale", 1.0)
264
+ else:
265
+ lora_scale = 1.0
266
+
267
+ if USE_PEFT_BACKEND:
268
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
269
+ scale_lora_layers(self, lora_scale)
270
+ else:
271
+ if (
272
+ attention_kwargs is not None
273
+ and attention_kwargs.get("scale", None) is not None
274
+ ):
275
+ logger.warning(
276
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
277
+ )
278
+
279
+ # spatial forward
280
+ batch_size = hidden_states.shape[0]
281
+ hidden_states_type = hidden_states.dtype
282
+
283
+ # Patchify the input
284
+ if hidden_states_masks is None:
285
+ hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(
286
+ hidden_states
287
+ )
288
+
289
+ # Embed the hidden states
290
+ hidden_states = self.x_embedder(hidden_states)
291
+
292
+ # 0. time
293
+ timesteps = self.t_embedder(timesteps, hidden_states_type)
294
+ p_embedder = self.p_embedder(pooled_embeds)
295
+ temb = timesteps + p_embedder
296
+
297
+ encoder_hidden_states = [
298
+ encoder_hidden_states_llama3[k] for k in self.config.llama_layers
299
+ ]
300
+
301
+ if self.caption_projection is not None:
302
+ new_encoder_hidden_states = []
303
+ for i, enc_hidden_state in enumerate(encoder_hidden_states):
304
+ enc_hidden_state = self.caption_projection[i](enc_hidden_state)
305
+ enc_hidden_state = enc_hidden_state.view(
306
+ batch_size, -1, hidden_states.shape[-1]
307
+ )
308
+ new_encoder_hidden_states.append(enc_hidden_state)
309
+ encoder_hidden_states = new_encoder_hidden_states
310
+ encoder_hidden_states_t5 = self.caption_projection[-1](
311
+ encoder_hidden_states_t5
312
+ )
313
+ encoder_hidden_states_t5 = encoder_hidden_states_t5.view(
314
+ batch_size, -1, hidden_states.shape[-1]
315
+ )
316
+ encoder_hidden_states.append(encoder_hidden_states_t5)
317
+
318
+ txt_ids = torch.zeros(
319
+ batch_size,
320
+ encoder_hidden_states[-1].shape[1]
321
+ + encoder_hidden_states[-2].shape[1]
322
+ + encoder_hidden_states[0].shape[1],
323
+ 3,
324
+ device=img_ids.device,
325
+ dtype=img_ids.dtype,
326
+ )
327
+ ids = torch.cat((img_ids, txt_ids), dim=1)
328
+ image_rotary_emb = self.pe_embedder(ids)
329
+
330
+ # 2. Blocks
331
+ # NOTE: assigning block_id here is no longer needed; apply() already set _block_id on each block.
332
+ initial_encoder_hidden_states = torch.cat(
333
+ [encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1
334
+ )
335
+ llama31_encoder_hidden_states = encoder_hidden_states
336
+ for bid, block in enumerate(self.double_stream_blocks):
337
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
338
+ hidden_states, initial_encoder_hidden_states = (
339
+ self._gradient_checkpointing_func(
340
+ block,
341
+ hidden_states,
342
+ initial_encoder_hidden_states,
343
+ hidden_states_masks,
344
+ temb,
345
+ image_rotary_emb,
346
+ llama31_encoder_hidden_states,
347
+ )
348
+ )
349
+ else:
350
+ hidden_states, initial_encoder_hidden_states = block(
351
+ hidden_states,
352
+ initial_encoder_hidden_states, # encoder_hidden_states
353
+ hidden_states_masks=hidden_states_masks,
354
+ temb=temb,
355
+ image_rotary_emb=image_rotary_emb,
356
+ llama31_encoder_hidden_states=llama31_encoder_hidden_states,
357
+ )
358
+
359
+ image_tokens_seq_len = hidden_states.shape[1]
360
+ hidden_states = torch.cat(
361
+ [hidden_states, initial_encoder_hidden_states], dim=1
362
+ )
363
+ if hidden_states_masks is not None:
364
+ # NOTE: Patched: the mask must also cover the llama3.1 tokens appended inside each single-stream block.
365
+ cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[
366
+ self.double_stream_blocks[-1].block._block_id
367
+ ]
368
+ encoder_attention_mask_ones = torch.ones(
369
+ (
370
+ batch_size,
371
+ initial_encoder_hidden_states.shape[1]
372
+ + cur_llama31_encoder_hidden_states.shape[1],
373
+ ),
374
+ device=hidden_states_masks.device,
375
+ dtype=hidden_states_masks.dtype,
376
+ )
377
+ hidden_states_masks = torch.cat(
378
+ [hidden_states_masks, encoder_attention_mask_ones], dim=1
379
+ )
380
+
381
+ for bid, block in enumerate(self.single_stream_blocks):
382
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
383
+ hidden_states = self._gradient_checkpointing_func(
384
+ block,
385
+ hidden_states,
386
+ hidden_states_masks,
387
+ temb,
388
+ image_rotary_emb,
389
+ llama31_encoder_hidden_states,
390
+ )
391
+ else:
392
+ hidden_states = block(
393
+ hidden_states,
394
+ hidden_states_masks=hidden_states_masks,
395
+ temb=temb,
396
+ image_rotary_emb=image_rotary_emb,
397
+ llama31_encoder_hidden_states=llama31_encoder_hidden_states,
398
+ )
399
+
400
+ hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
401
+ output = self.final_layer(hidden_states, temb)
402
+ output = self.unpatchify(output, img_sizes, self.training)
403
+ if hidden_states_masks is not None:
404
+ hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
405
+
406
+ if USE_PEFT_BACKEND:
407
+ # remove `lora_scale` from each PEFT layer
408
+ unscale_lora_layers(self, lora_scale)
409
+
410
+ if not return_dict:
411
+ return (output,)
412
+ return Transformer2DModelOutput(sample=output)
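Both patch functors swap in their replacement forwards by binding module-level functions to existing instances through the descriptor protocol (`forward = func.__get__(obj)`). A minimal standalone sketch of that technique, with hypothetical class and function names:

import torch

class Block(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0

def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Same signature as Block.forward; `self` is the Block instance because
    # functions are descriptors, so __get__(instance) returns a bound method.
    return x * 2.0 + 1.0

block = Block()
block.forward = patched_forward.__get__(block)  # rebind on this instance only
assert torch.allclose(block.forward(torch.ones(2)), torch.full((2,), 3.0))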
@@ -0,0 +1,213 @@
1
+ import torch
2
+ from typing import Optional, Union, List
3
+ from diffusers import HunyuanDiT2DModel
4
+ from diffusers.models.transformers.hunyuan_transformer_2d import (
5
+ HunyuanDiTBlock,
6
+ Transformer2DModelOutput,
7
+ )
8
+ from cache_dit.cache_factory.patch_functors.functor_base import (
9
+ PatchFunctor,
10
+ )
11
+ from cache_dit.logger import init_logger
12
+
13
+ logger = init_logger(__name__)
14
+
15
+
16
+ class HunyuanDiTPatchFunctor(PatchFunctor):
17
+
18
+ def apply(
19
+ self,
20
+ transformer: HunyuanDiT2DModel,
21
+ **kwargs,
22
+ ) -> HunyuanDiT2DModel:
23
+ if hasattr(transformer, "_is_patched"):
24
+ return transformer
25
+
26
+ is_patched = False
27
+
28
+ num_layers = transformer.config.num_layers
29
+ layer_id = 0
30
+ for block in transformer.blocks:
31
+ assert isinstance(block, HunyuanDiTBlock)
32
+ block._num_layers = num_layers
33
+ block._layer_id = layer_id
34
+ block.forward = __patch_block_forward__.__get__(block)
35
+ layer_id += 1
36
+
37
+ is_patched = True
38
+
39
+ cls_name = transformer.__class__.__name__
40
+
41
+ if is_patched:
42
+ logger.warning(f"Patched {cls_name} for cache-dit.")
43
+ assert not getattr(transformer, "_is_parallelized", False), (
44
+ "Please call `cache_dit.enable_cache` before Parallelize, "
45
+ "the __patch_transformer_forward__ will overwrite the "
46
+ "parallized forward and cause a downgrade of performance."
47
+ )
48
+ transformer.forward = __patch_transformer_forward__.__get__(
49
+ transformer
50
+ )
51
+
52
+ transformer._is_patched = is_patched # True or False
53
+
54
+ logger.info(
55
+ f"Applied {self.__class__.__name__} for {cls_name}, "
56
+ f"Patch: {is_patched}."
57
+ )
58
+
59
+ return transformer
60
+
61
+
62
+ def __patch_block_forward__(
63
+ self: HunyuanDiTBlock,
64
+ hidden_states: torch.Tensor,
65
+ encoder_hidden_states: Optional[torch.Tensor] = None,
66
+ temb: Optional[torch.Tensor] = None,
67
+ image_rotary_emb: torch.Tensor = None,
68
+ controlnet_block_samples: torch.Tensor = None,
69
+ skips: List[torch.Tensor] = [],
70
+ ) -> torch.Tensor:
71
+ # Notice that normalization is always applied before the real computation in the following blocks.
72
+ # 0. Long Skip Connection
73
+ num_layers = self._num_layers
74
+ layer_id = self._layer_id
75
+
76
+ if layer_id > num_layers // 2:
77
+ if controlnet_block_samples is not None:
78
+ skip = skips.pop() + controlnet_block_samples.pop()
79
+ else:
80
+ skip = skips.pop()
81
+ else:
82
+ skip = None
83
+
84
+ if self.skip_linear is not None:
85
+ cat = torch.cat([hidden_states, skip], dim=-1)
86
+ cat = self.skip_norm(cat)
87
+ hidden_states = self.skip_linear(cat)
88
+
89
+ # 1. Self-Attention
90
+ norm_hidden_states = self.norm1(
91
+ hidden_states, temb
92
+ ) # checked: self.norm1 is correct
93
+ attn_output = self.attn1(
94
+ norm_hidden_states,
95
+ image_rotary_emb=image_rotary_emb,
96
+ )
97
+ hidden_states = hidden_states + attn_output
98
+
99
+ # 2. Cross-Attention
100
+ hidden_states = hidden_states + self.attn2(
101
+ self.norm2(hidden_states),
102
+ encoder_hidden_states=encoder_hidden_states,
103
+ image_rotary_emb=image_rotary_emb,
104
+ )
105
+
106
+ # FFN Layer
107
+ mlp_inputs = self.norm3(hidden_states)
108
+ hidden_states = hidden_states + self.ff(mlp_inputs)
109
+
110
+ if layer_id < (num_layers // 2 - 1):
111
+ skips.append(hidden_states)
112
+
113
+ return hidden_states
114
+
115
+
116
+ def __patch_transformer_forward__(
117
+ self: HunyuanDiT2DModel,
118
+ hidden_states,
119
+ timestep,
120
+ encoder_hidden_states=None,
121
+ text_embedding_mask=None,
122
+ encoder_hidden_states_t5=None,
123
+ text_embedding_mask_t5=None,
124
+ image_meta_size=None,
125
+ style=None,
126
+ image_rotary_emb=None,
127
+ controlnet_block_samples=None,
128
+ return_dict=True,
129
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
130
+ height, width = hidden_states.shape[-2:]
131
+
132
+ hidden_states = self.pos_embed(hidden_states)
133
+
134
+ temb = self.time_extra_emb(
135
+ timestep,
136
+ encoder_hidden_states_t5,
137
+ image_meta_size,
138
+ style,
139
+ hidden_dtype=timestep.dtype,
140
+ ) # [B, D]
141
+
142
+ # text projection
143
+ batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
144
+ encoder_hidden_states_t5 = self.text_embedder(
145
+ encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
146
+ )
147
+ encoder_hidden_states_t5 = encoder_hidden_states_t5.view(
148
+ batch_size, sequence_length, -1
149
+ )
150
+
151
+ encoder_hidden_states = torch.cat(
152
+ [encoder_hidden_states, encoder_hidden_states_t5], dim=1
153
+ )
154
+ text_embedding_mask = torch.cat(
155
+ [text_embedding_mask, text_embedding_mask_t5], dim=-1
156
+ )
157
+ text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
158
+
159
+ encoder_hidden_states = torch.where(
160
+ text_embedding_mask, encoder_hidden_states, self.text_embedding_padding
161
+ )
162
+
163
+ skips = []
164
+ for layer, block in enumerate(self.blocks):
165
+ hidden_states = block(
166
+ hidden_states,
167
+ temb=temb,
168
+ encoder_hidden_states=encoder_hidden_states,
169
+ image_rotary_emb=image_rotary_emb,
170
+ controlnet_block_samples=controlnet_block_samples,
171
+ skips=skips,
172
+ ) # (N, L, D)
173
+
174
+ if (
175
+ controlnet_block_samples is not None
176
+ and len(controlnet_block_samples) != 0
177
+ ):
178
+ raise ValueError(
179
+ "The number of controls is not equal to the number of skip connections."
180
+ )
181
+
182
+ # final layer
183
+ hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
184
+ hidden_states = self.proj_out(hidden_states)
185
+ # (N, L, patch_size ** 2 * out_channels)
186
+
187
+ # unpatchify: (N, out_channels, H, W)
188
+ patch_size = self.pos_embed.patch_size
189
+ height = height // patch_size
190
+ width = width // patch_size
191
+
192
+ hidden_states = hidden_states.reshape(
193
+ shape=(
194
+ hidden_states.shape[0],
195
+ height,
196
+ width,
197
+ patch_size,
198
+ patch_size,
199
+ self.out_channels,
200
+ )
201
+ )
202
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
203
+ output = hidden_states.reshape(
204
+ shape=(
205
+ hidden_states.shape[0],
206
+ self.out_channels,
207
+ height * patch_size,
208
+ width * patch_size,
209
+ )
210
+ )
211
+ if not return_dict:
212
+ return (output,)
213
+ return Transformer2DModelOutput(sample=output)
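The HunyuanDiT patch threads a shared `skips` list through the block forwards so that the model's long, U-Net-style skip connections are handled per block rather than in the transformer loop. A standalone sketch of that bookkeeping; the `block(hidden_states, skip)` call signature is hypothetical:

from typing import Callable, List, Optional
import torch

def run_blocks(
    blocks: List[Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]],
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    num_layers = len(blocks)
    skips: List[torch.Tensor] = []
    for layer_id, block in enumerate(blocks):
        # The second half of the stack consumes skips pushed by the first half.
        skip = skips.pop() if layer_id > num_layers // 2 else None
        hidden_states = block(hidden_states, skip)
        if layer_id < (num_layers // 2 - 1):
            skips.append(hidden_states)
    return hidden_states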
cache_dit/utils.py CHANGED
@@ -52,6 +52,9 @@ def summary(
52
52
  if hasattr(adapter_or_others, "transformer_2"):
53
53
  transformer_2 = adapter_or_others.transformer_2
54
54
 
55
+ if not BlockAdapter.is_cached(transformer):
56
+ return [CacheStats()]
57
+
55
58
  blocks_stats: List[CacheStats] = []
56
59
  for blocks in BlockAdapter.find_blocks(transformer):
57
60
  blocks_stats.append(
@@ -212,7 +215,8 @@ def _summary(
212
215
  if logging:
213
216
  print(f"\n🤗Cache Options: {cls_name}\n\n{cache_options}")
214
217
  else:
215
- logger.warning(f"Can't find Cache Options for: {cls_name}")
218
+ if logging:
219
+ logger.warning(f"Can't find Cache Options for: {cls_name}")
216
220
 
217
221
  if hasattr(module, "_cached_steps"):
218
222
  cached_steps: list[int] = module._cached_steps
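The `cache_dit/utils.py` hunks above make `summary()` degrade gracefully: it now returns a single empty `CacheStats` when the transformer has not been cached, and the fallback warning respects the `logging` flag. A usage sketch, assuming `summary()` accepts the pipeline object directly; the pipeline class and checkpoint id are illustrative:

import cache_dit
from diffusers import HunyuanDiTPipeline

pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers")

# Before enable_cache(): the transformer is not cached, so summary() now
# short-circuits and returns [CacheStats()] instead of probing uncached blocks.
stats = cache_dit.summary(pipe)

cache_dit.enable_cache(pipe)
# ... run inference ...
stats = cache_dit.summary(pipe)  # real per-block statistics once caching is active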