cache-dit 0.2.32__py3-none-any.whl → 0.2.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,9 @@
- import inspect
-
  import torch
  import numpy as np
  from typing import Tuple, Optional, Dict, Any, Union
  from diffusers import ChromaTransformer2DModel
  from diffusers.models.transformers.transformer_chroma import (
+     ChromaTransformerBlock,
      ChromaSingleTransformerBlock,
      Transformer2DModelOutput,
  )
@@ -27,24 +26,31 @@ class ChromaPatchFunctor(PatchFunctor):
      def apply(
          self,
          transformer: ChromaTransformer2DModel,
-         blocks: torch.nn.ModuleList = None,
          **kwargs,
      ) -> ChromaTransformer2DModel:
          if hasattr(transformer, "_is_patched"):
              return transformer

-         if blocks is None:
-             blocks = transformer.single_transformer_blocks
-
          is_patched = False
-         for block in blocks:
-             if isinstance(block, ChromaSingleTransformerBlock):
-                 forward_parameters = inspect.signature(
-                     block.forward
-                 ).parameters.keys()
-                 if "encoder_hidden_states" not in forward_parameters:
-                     block.forward = __patch_single_forward__.__get__(block)
-                     is_patched = True
+         for index_block, block in enumerate(transformer.transformer_blocks):
+             assert isinstance(block, ChromaTransformerBlock)
+             img_offset = 3 * len(transformer.single_transformer_blocks)
+             txt_offset = img_offset + 6 * len(transformer.transformer_blocks)
+             img_modulation = img_offset + 6 * index_block
+             text_modulation = txt_offset + 6 * index_block
+             block._img_modulation = img_modulation
+             block._text_modulation = text_modulation
+             block.forward = __patch_double_forward__.__get__(block)
+
+         for index_block, block in enumerate(
+             transformer.single_transformer_blocks
+         ):
+             assert isinstance(block, ChromaSingleTransformerBlock)
+             start_idx = 3 * index_block
+             block._start_idx = start_idx
+             block.forward = __patch_single_forward__.__get__(block)
+
+         is_patched = True

          cls_name = transformer.__class__.__name__

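Note: the offsets cached on each block above flatten Chroma's pooled modulation table so the patched block forwards can slice it locally instead of receiving a pre-sliced temb. A minimal sketch of the index math, using hypothetical block counts (the real values come from the loaded ChromaTransformer2DModel):

    # Hypothetical sizes, for illustration only.
    num_single_blocks = 38
    num_double_blocks = 19

    img_offset = 3 * num_single_blocks               # 3 modulation rows per single block
    txt_offset = img_offset + 6 * num_double_blocks  # then 6 image rows per double block

    for index_block in range(num_double_blocks):
        img_modulation = img_offset + 6 * index_block   # image-stream rows for this block
        text_modulation = txt_offset + 6 * index_block  # text-stream rows for this block
        # __patch_double_forward__ slices pooled_temb[:, img_modulation : img_modulation + 6]
        # and pooled_temb[:, text_modulation : text_modulation + 6].

    for index_block in range(num_single_blocks):
        start_idx = 3 * index_block
        # __patch_single_forward__ slices pooled_temb[:, start_idx : start_idx + 3].

    # __patch_transformer_forward__ keeps the final rows, pooled_temb[:, -2:], for norm_out.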
@@ -69,25 +75,123 @@ class ChromaPatchFunctor(PatchFunctor):
          return transformer


+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
+ def __patch_double_forward__(
+     self: ChromaTransformerBlock,
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor,
+     pooled_temb: torch.Tensor,
+     image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     # TODO: Fuse controlnet into block forward
+     img_modulation = self._img_modulation
+     text_modulation = self._text_modulation
+     temb = torch.cat(
+         (
+             pooled_temb[:, img_modulation : img_modulation + 6],
+             pooled_temb[:, text_modulation : text_modulation + 6],
+         ),
+         dim=1,
+     )
+
+     temb_img, temb_txt = temb[:, :6], temb[:, 6:]
+     norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+         hidden_states, emb=temb_img
+     )
+
+     (
+         norm_encoder_hidden_states,
+         c_gate_msa,
+         c_shift_mlp,
+         c_scale_mlp,
+         c_gate_mlp,
+     ) = self.norm1_context(encoder_hidden_states, emb=temb_txt)
+     joint_attention_kwargs = joint_attention_kwargs or {}
+     if attention_mask is not None:
+         attention_mask = (
+             attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+         )
+
+     # Attention.
+     attention_outputs = self.attn(
+         hidden_states=norm_hidden_states,
+         encoder_hidden_states=norm_encoder_hidden_states,
+         image_rotary_emb=image_rotary_emb,
+         attention_mask=attention_mask,
+         **joint_attention_kwargs,
+     )
+
+     if len(attention_outputs) == 2:
+         attn_output, context_attn_output = attention_outputs
+     elif len(attention_outputs) == 3:
+         attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+     # Process attention outputs for the `hidden_states`.
+     attn_output = gate_msa.unsqueeze(1) * attn_output
+     hidden_states = hidden_states + attn_output
+
+     norm_hidden_states = self.norm2(hidden_states)
+     norm_hidden_states = (
+         norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+     )
+
+     ff_output = self.ff(norm_hidden_states)
+     ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+     hidden_states = hidden_states + ff_output
+     if len(attention_outputs) == 3:
+         hidden_states = hidden_states + ip_attn_output
+
+     # Process attention outputs for the `encoder_hidden_states`.
+
+     context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+     encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+     norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+     norm_encoder_hidden_states = (
+         norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
+         + c_shift_mlp[:, None]
+     )
+
+     context_ff_output = self.ff_context(norm_encoder_hidden_states)
+     encoder_hidden_states = (
+         encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+     )
+     if encoder_hidden_states.dtype == torch.float16:
+         encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+     return encoder_hidden_states, hidden_states
+
+
  # adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
  def __patch_single_forward__(
      self: ChromaSingleTransformerBlock,  # Almost same as FluxSingleTransformerBlock
      hidden_states: torch.Tensor,
-     encoder_hidden_states: torch.Tensor,
-     temb: torch.Tensor,
+     pooled_temb: torch.Tensor,
      image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+     attention_mask: Optional[torch.Tensor] = None,
      joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-     text_seq_len = encoder_hidden_states.shape[1]
-     hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ ) -> torch.Tensor:
+     # TODO: Fuse controlnet into block forward
+     start_idx = self._start_idx
+     temb = pooled_temb[:, start_idx : start_idx + 3]

      residual = hidden_states
      norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
      mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
      joint_attention_kwargs = joint_attention_kwargs or {}
+
+     if attention_mask is not None:
+         attention_mask = (
+             attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+         )
+
      attn_output = self.attn(
          hidden_states=norm_hidden_states,
          image_rotary_emb=image_rotary_emb,
+         attention_mask=attention_mask,
          **joint_attention_kwargs,
      )

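Both patched block forwards expand a 1-D padding mask into a pairwise mask the same way before handing it to self.attn. A standalone sketch of that broadcast (dummy sizes, assuming a [batch, seq_len] mask of 0/1 values):

    import torch

    batch, seq_len = 2, 5
    attention_mask = torch.ones(batch, seq_len)
    attention_mask[:, -2:] = 0  # pretend the last two tokens are padding

    # [B, 1, 1, S] * [B, 1, S, 1] broadcasts to [B, 1, S, S]; entry (i, j) is 1
    # only when both token i and token j are real (non-padded) tokens.
    pairwise = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
    print(pairwise.shape)  # torch.Size([2, 1, 5, 5])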
@@ -98,11 +202,7 @@ def __patch_single_forward__(
      if hidden_states.dtype == torch.float16:
          hidden_states = hidden_states.clip(-65504, 65504)

-     encoder_hidden_states, hidden_states = (
-         hidden_states[:, :text_seq_len],
-         hidden_states[:, text_seq_len:],
-     )
-     return encoder_hidden_states, hidden_states
+     return hidden_states


  # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
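The block.forward = __patch_single_forward__.__get__(block) assignments in apply() rely on Python's descriptor protocol: calling __get__ on a plain function yields a method bound to that one instance, so only the patched blocks see the new signature while the class and any unpatched instances keep the original forward. A self-contained sketch with a hypothetical class:

    class Block:  # hypothetical stand-in for a transformer block
        def forward(self, x):
            return x

    def patched_forward(self, x):
        return x + 1

    block = Block()
    block.forward = patched_forward.__get__(block)  # bind to this instance only

    assert block.forward(1) == 2
    assert Block().forward(1) == 1  # other instances are unaffected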
@@ -174,24 +274,13 @@ def __patch_transformer_forward__(
          joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

      for index_block, block in enumerate(self.transformer_blocks):
-         img_offset = 3 * len(self.single_transformer_blocks)
-         txt_offset = img_offset + 6 * len(self.transformer_blocks)
-         img_modulation = img_offset + 6 * index_block
-         text_modulation = txt_offset + 6 * index_block
-         temb = torch.cat(
-             (
-                 pooled_temb[:, img_modulation : img_modulation + 6],
-                 pooled_temb[:, text_modulation : text_modulation + 6],
-             ),
-             dim=1,
-         )
          if torch.is_grad_enabled() and self.gradient_checkpointing:
              encoder_hidden_states, hidden_states = (
                  self._gradient_checkpointing_func(
                      block,
                      hidden_states,
                      encoder_hidden_states,
-                     temb,
+                     pooled_temb,
                      image_rotary_emb,
                      attention_mask,
                  )
@@ -201,12 +290,13 @@ def __patch_transformer_forward__(
              encoder_hidden_states, hidden_states = block(
                  hidden_states=hidden_states,
                  encoder_hidden_states=encoder_hidden_states,
-                 temb=temb,
+                 pooled_temb=pooled_temb,
                  image_rotary_emb=image_rotary_emb,
                  attention_mask=attention_mask,
                  joint_attention_kwargs=joint_attention_kwargs,
              )

+         # TODO: Fuse controlnet into block forward
          # controlnet residual
          if controlnet_block_samples is not None:
              interval_control = len(self.transformer_blocks) / len(
@@ -227,43 +317,43 @@ def __patch_transformer_forward__(
                  + controlnet_block_samples[index_block // interval_control]
              )

+     hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
      for index_block, block in enumerate(self.single_transformer_blocks):
-         start_idx = 3 * index_block
-         temb = pooled_temb[:, start_idx : start_idx + 3]
          if torch.is_grad_enabled() and self.gradient_checkpointing:
-             encoder_hidden_states, hidden_states = (
-                 self._gradient_checkpointing_func(
-                     block,
-                     hidden_states,
-                     encoder_hidden_states,
-                     temb,
-                     image_rotary_emb,
-                 )
+             hidden_states = self._gradient_checkpointing_func(
+                 block,
+                 hidden_states,
+                 pooled_temb,
+                 image_rotary_emb,
+                 attention_mask,
+                 joint_attention_kwargs,
              )

          else:
-             encoder_hidden_states, hidden_states = block(
+             hidden_states = block(
                  hidden_states=hidden_states,
-                 encoder_hidden_states=encoder_hidden_states,
-                 temb=temb,
+                 pooled_temb=pooled_temb,
                  image_rotary_emb=image_rotary_emb,
                  attention_mask=attention_mask,
                  joint_attention_kwargs=joint_attention_kwargs,
              )

+         # TODO: Fuse controlnet into block forward
          # controlnet residual
          if controlnet_single_block_samples is not None:
              interval_control = len(self.single_transformer_blocks) / len(
                  controlnet_single_block_samples
              )
              interval_control = int(np.ceil(interval_control))
-             hidden_states = (
-                 hidden_states
+             hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                 hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                  + controlnet_single_block_samples[
                      index_block // interval_control
                  ]
              )

+     hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
      temb = pooled_temb[:, -2:]
      hidden_states = self.norm_out(hidden_states, temb)
      output = self.proj_out(hidden_states)
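In the reworked forward, text and image tokens are concatenated once before the single-block loop, ControlNet residuals are added only to the image slice of the joint sequence, and the text tokens are dropped again before norm_out. A shape-level sketch with dummy tensors (sizes are illustrative):

    import torch

    text_len, img_len, dim = 4, 16, 8
    encoder_hidden_states = torch.randn(1, text_len, dim)
    hidden_states = torch.randn(1, img_len, dim)

    # Joint [text | image] sequence fed to every single transformer block.
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

    # A ControlNet residual touches only the image part of the joint sequence.
    residual = torch.randn(1, img_len, dim)
    hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
        hidden_states[:, encoder_hidden_states.shape[1] :, ...] + residual
    )

    # Strip the text tokens before the final norm_out / proj_out.
    hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
    print(hidden_states.shape)  # torch.Size([1, 16, 8])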
@@ -0,0 +1,130 @@
+ import torch
+ import torch.nn.functional as F
+
+ from typing import Optional, Dict, Any
+ from diffusers.models.transformers.dit_transformer_2d import (
+     DiTTransformer2DModel,
+     Transformer2DModelOutput,
+ )
+ from cache_dit.cache_factory.patch_functors.functor_base import (
+     PatchFunctor,
+ )
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class DiTPatchFunctor(PatchFunctor):
+
+     def apply(
+         self,
+         transformer: DiTTransformer2DModel,
+         **kwargs,
+     ) -> DiTTransformer2DModel:
+         if hasattr(transformer, "_is_patched"):
+             return transformer
+
+         is_patched = False
+
+         transformer._norm1_emb = transformer.transformer_blocks[0].norm1.emb
+
+         is_patched = True
+
+         cls_name = transformer.__class__.__name__
+
+         if is_patched:
+             logger.warning(f"Patched {cls_name} for cache-dit.")
+             assert not getattr(transformer, "_is_parallelized", False), (
+                 "Please call `cache_dit.enable_cache` before Parallelize, "
+                 "the __patch_transformer_forward__ will overwrite the "
+                 "parallized forward and cause a downgrade of performance."
+             )
+             transformer.forward = __patch_transformer_forward__.__get__(
+                 transformer
+             )
+
+         transformer._is_patched = is_patched  # True or False
+
+         logger.info(
+             f"Applied {self.__class__.__name__} for {cls_name}, "
+             f"Patch: {is_patched}."
+         )
+
+         return transformer
+
+
+ def __patch_transformer_forward__(
+     self: DiTTransformer2DModel,
+     hidden_states: torch.Tensor,
+     timestep: Optional[torch.LongTensor] = None,
+     class_labels: Optional[torch.LongTensor] = None,
+     cross_attention_kwargs: Dict[str, Any] = None,
+     return_dict: bool = True,
+ ):
+     height, width = (
+         hidden_states.shape[-2] // self.patch_size,
+         hidden_states.shape[-1] // self.patch_size,
+     )
+     hidden_states = self.pos_embed(hidden_states)
+
+     # 2. Blocks
+     for block in self.transformer_blocks:
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+             hidden_states = self._gradient_checkpointing_func(
+                 block,
+                 hidden_states,
+                 None,
+                 None,
+                 None,
+                 timestep,
+                 cross_attention_kwargs,
+                 class_labels,
+             )
+         else:
+             hidden_states = block(
+                 hidden_states,
+                 attention_mask=None,
+                 encoder_hidden_states=None,
+                 encoder_attention_mask=None,
+                 timestep=timestep,
+                 cross_attention_kwargs=cross_attention_kwargs,
+                 class_labels=class_labels,
+             )
+
+     # 3. Output
+     # conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype)
+     conditioning = self._norm1_emb(
+         timestep, class_labels, hidden_dtype=hidden_states.dtype
+     )
+     shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+     hidden_states = (
+         self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+     )
+     hidden_states = self.proj_out_2(hidden_states)
+
+     # unpatchify
+     height = width = int(hidden_states.shape[1] ** 0.5)
+     hidden_states = hidden_states.reshape(
+         shape=(
+             -1,
+             height,
+             width,
+             self.patch_size,
+             self.patch_size,
+             self.out_channels,
+         )
+     )
+     hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+     output = hidden_states.reshape(
+         shape=(
+             -1,
+             self.out_channels,
+             height * self.patch_size,
+             width * self.patch_size,
+         )
+     )
+
+     if not return_dict:
+         return (output,)
+
+     return Transformer2DModelOutput(sample=output)
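The new DiT functor's only structural change is stashing transformer_blocks[0].norm1.emb on the transformer as _norm1_emb, which appears to let the patched forward compute the final AdaLN conditioning without reaching back into the first block at output time; the rest mirrors the upstream diffusers forward. For reference, a shape walk-through of the unpatchify tail with dummy sizes (values are illustrative):

    import torch

    batch, patch, channels = 1, 2, 4
    height = width = 3  # int(num_tokens ** 0.5)
    tokens = torch.randn(batch, height * width, patch * patch * channels)

    x = tokens.reshape(-1, height, width, patch, patch, channels)
    x = torch.einsum("nhwpqc->nchpwq", x)  # channels first, interleave patch dims
    image = x.reshape(-1, channels, height * patch, width * patch)
    print(image.shape)  # torch.Size([1, 4, 6, 6])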
@@ -89,6 +89,9 @@ def quantize_ao(
          PerRow,
      )

+     if per_row:  # Ensure bfloat16
+         module.to(torch.bfloat16)
+
      quantization_fn = float8_dynamic_activation_float8_weight(
          weight_dtype=kwargs.get(
              "weight_dtype",