cache-dit 0.2.25-py3-none-any.whl → 0.2.27-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit has been flagged as potentially problematic by the registry.

Files changed (32)
  1. cache_dit/__init__.py +9 -4
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +16 -3
  4. cache_dit/cache_factory/block_adapters/__init__.py +538 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +121 -563
  8. cache_dit/cache_factory/cache_blocks/__init__.py +18 -0
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +275 -0
  11. cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +100 -82
  12. cache_dit/cache_factory/cache_blocks/utils.py +23 -0
  13. cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
  14. cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +94 -56
  15. cache_dit/cache_factory/cache_interface.py +24 -16
  16. cache_dit/cache_factory/forward_pattern.py +45 -24
  17. cache_dit/cache_factory/patch_functors/__init__.py +5 -0
  18. cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
  19. cache_dit/cache_factory/patch_functors/functor_chroma.py +276 -0
  20. cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +49 -31
  21. cache_dit/quantize/quantize_ao.py +19 -4
  22. cache_dit/quantize/quantize_interface.py +2 -2
  23. cache_dit/utils.py +19 -15
  24. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/METADATA +76 -19
  25. cache_dit-0.2.27.dist-info/RECORD +47 -0
  26. cache_dit-0.2.25.dist-info/RECORD +0 -36
  27. /cache_dit/cache_factory/{patch/__init__.py → cache_contexts/cache_manager.py} +0 -0
  28. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  29. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
  30. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
  31. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
  32. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/patch_functors/functor_chroma.py ADDED
@@ -0,0 +1,276 @@
+ import inspect
+
+ import torch
+ import numpy as np
+ from typing import Tuple, Optional, Dict, Any, Union
+ from diffusers import ChromaTransformer2DModel
+ from diffusers.models.transformers.transformer_chroma import (
+     ChromaSingleTransformerBlock,
+     Transformer2DModelOutput,
+ )
+ from diffusers.utils import (
+     USE_PEFT_BACKEND,
+     scale_lora_layers,
+     unscale_lora_layers,
+ )
+
+ from cache_dit.cache_factory.patch_functors.functor_base import (
+     PatchFunctor,
+ )
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class ChromaPatchFunctor(PatchFunctor):
+
+     def apply(
+         self,
+         transformer: ChromaTransformer2DModel,
+         blocks: torch.nn.ModuleList = None,
+         **kwargs,
+     ) -> ChromaTransformer2DModel:
+         if getattr(transformer, "_is_patched", False):
+             return transformer
+
+         if blocks is None:
+             blocks = transformer.single_transformer_blocks
+
+         is_patched = False
+         for block in blocks:
+             if isinstance(block, ChromaSingleTransformerBlock):
+                 forward_parameters = inspect.signature(
+                     block.forward
+                 ).parameters.keys()
+                 if "encoder_hidden_states" not in forward_parameters:
+                     block.forward = __patch_single_forward__.__get__(block)
+                     is_patched = True
+
+         if is_patched:
+             logger.warning("Patched Chroma for cache-dit.")
+             assert not getattr(transformer, "_is_parallelized", False), (
+                 "Please call `cache_dit.enable_cache` before Parallelize, "
+                 "the __patch_transformer_forward__ will overwrite the "
+                 "parallized forward and cause a downgrade of performance."
+             )
+             transformer.forward = __patch_transformer_forward__.__get__(
+                 transformer
+             )
+             transformer._is_patched = True
+
+         cls_name = transformer.__class__.__name__
+         logger.info(
+             f"Applied {self.__class__.__name__} for {cls_name}, "
+             f"Patch: {is_patched}."
+         )
+
+         return transformer
+
+
+ # adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
+ def __patch_single_forward__(
+     self: ChromaSingleTransformerBlock,  # Almost same as FluxSingleTransformerBlock
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor,
+     temb: torch.Tensor,
+     image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     text_seq_len = encoder_hidden_states.shape[1]
+     hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+     residual = hidden_states
+     norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+     mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+     joint_attention_kwargs = joint_attention_kwargs or {}
+     attn_output = self.attn(
+         hidden_states=norm_hidden_states,
+         image_rotary_emb=image_rotary_emb,
+         **joint_attention_kwargs,
+     )
+
+     hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+     gate = gate.unsqueeze(1)
+     hidden_states = gate * self.proj_out(hidden_states)
+     hidden_states = residual + hidden_states
+     if hidden_states.dtype == torch.float16:
+         hidden_states = hidden_states.clip(-65504, 65504)
+
+     encoder_hidden_states, hidden_states = (
+         hidden_states[:, :text_seq_len],
+         hidden_states[:, text_seq_len:],
+     )
+     return encoder_hidden_states, hidden_states
+
+
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
+ def __patch_transformer_forward__(
+     self: ChromaTransformer2DModel,
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor = None,
+     timestep: torch.LongTensor = None,
+     img_ids: torch.Tensor = None,
+     txt_ids: torch.Tensor = None,
+     attention_mask: torch.Tensor = None,
+     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     controlnet_block_samples=None,
+     controlnet_single_block_samples=None,
+     return_dict: bool = True,
+     controlnet_blocks_repeat: bool = False,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+     if joint_attention_kwargs is not None:
+         joint_attention_kwargs = joint_attention_kwargs.copy()
+         lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+     else:
+         lora_scale = 1.0
+
+     if USE_PEFT_BACKEND:
+         # weight the lora layers by setting `lora_scale` for each PEFT layer
+         scale_lora_layers(self, lora_scale)
+     else:
+         if (
+             joint_attention_kwargs is not None
+             and joint_attention_kwargs.get("scale", None) is not None
+         ):
+             logger.warning(
+                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+             )
+
+     hidden_states = self.x_embedder(hidden_states)
+
+     timestep = timestep.to(hidden_states.dtype) * 1000
+
+     input_vec = self.time_text_embed(timestep)
+     pooled_temb = self.distilled_guidance_layer(input_vec)
+
+     encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+     if txt_ids.ndim == 3:
+         logger.warning(
+             "Passing `txt_ids` 3d torch.Tensor is deprecated."
+             "Please remove the batch dimension and pass it as a 2d torch Tensor"
+         )
+         txt_ids = txt_ids[0]
+     if img_ids.ndim == 3:
+         logger.warning(
+             "Passing `img_ids` 3d torch.Tensor is deprecated."
+             "Please remove the batch dimension and pass it as a 2d torch Tensor"
+         )
+         img_ids = img_ids[0]
+
+     ids = torch.cat((txt_ids, img_ids), dim=0)
+     image_rotary_emb = self.pos_embed(ids)
+
+     if (
+         joint_attention_kwargs is not None
+         and "ip_adapter_image_embeds" in joint_attention_kwargs
+     ):
+         ip_adapter_image_embeds = joint_attention_kwargs.pop(
+             "ip_adapter_image_embeds"
+         )
+         ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+         joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+     for index_block, block in enumerate(self.transformer_blocks):
+         img_offset = 3 * len(self.single_transformer_blocks)
+         txt_offset = img_offset + 6 * len(self.transformer_blocks)
+         img_modulation = img_offset + 6 * index_block
+         text_modulation = txt_offset + 6 * index_block
+         temb = torch.cat(
+             (
+                 pooled_temb[:, img_modulation : img_modulation + 6],
+                 pooled_temb[:, text_modulation : text_modulation + 6],
+             ),
+             dim=1,
+         )
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+             encoder_hidden_states, hidden_states = (
+                 self._gradient_checkpointing_func(
+                     block,
+                     hidden_states,
+                     encoder_hidden_states,
+                     temb,
+                     image_rotary_emb,
+                     attention_mask,
+                 )
+             )
+
+         else:
+             encoder_hidden_states, hidden_states = block(
+                 hidden_states=hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 temb=temb,
+                 image_rotary_emb=image_rotary_emb,
+                 attention_mask=attention_mask,
+                 joint_attention_kwargs=joint_attention_kwargs,
+             )
+
+         # controlnet residual
+         if controlnet_block_samples is not None:
+             interval_control = len(self.transformer_blocks) / len(
+                 controlnet_block_samples
+             )
+             interval_control = int(np.ceil(interval_control))
+             # For Xlabs ControlNet.
+             if controlnet_blocks_repeat:
+                 hidden_states = (
+                     hidden_states
+                     + controlnet_block_samples[
+                         index_block % len(controlnet_block_samples)
+                     ]
+                 )
+             else:
+                 hidden_states = (
+                     hidden_states
+                     + controlnet_block_samples[index_block // interval_control]
+                 )
+
+     for index_block, block in enumerate(self.single_transformer_blocks):
+         start_idx = 3 * index_block
+         temb = pooled_temb[:, start_idx : start_idx + 3]
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+             encoder_hidden_states, hidden_states = (
+                 self._gradient_checkpointing_func(
+                     block,
+                     hidden_states,
+                     encoder_hidden_states,
+                     temb,
+                     image_rotary_emb,
+                 )
+             )
+
+         else:
+             encoder_hidden_states, hidden_states = block(
+                 hidden_states=hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 temb=temb,
+                 image_rotary_emb=image_rotary_emb,
+                 attention_mask=attention_mask,
+                 joint_attention_kwargs=joint_attention_kwargs,
+             )
+
+         # controlnet residual
+         if controlnet_single_block_samples is not None:
+             interval_control = len(self.single_transformer_blocks) / len(
+                 controlnet_single_block_samples
+             )
+             interval_control = int(np.ceil(interval_control))
+             hidden_states = (
+                 hidden_states
+                 + controlnet_single_block_samples[
+                     index_block // interval_control
+                 ]
+             )
+
+     temb = pooled_temb[:, -2:]
+     hidden_states = self.norm_out(hidden_states, temb)
+     output = self.proj_out(hidden_states)
+
+     if USE_PEFT_BACKEND:
+         # remove `lora_scale` from each PEFT layer
+         unscale_lora_layers(self, lora_scale)
+
+     if not return_dict:
+         return (output,)
+
+     return Transformer2DModelOutput(sample=output)
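
For orientation, the sketch below shows how this functor could be invoked by hand; in normal use the block adapters added in this release presumably apply it when `cache_dit.enable_cache` is called, so treat this as an illustration, not the package's documented workflow. The import path and the `apply()` signature come from the file above; the checkpoint id is a placeholder.

import torch
from diffusers import ChromaPipeline  # requires a diffusers version with Chroma support
from cache_dit.cache_factory.patch_functors.functor_chroma import ChromaPatchFunctor

pipe = ChromaPipeline.from_pretrained(
    "lodestones/Chroma",  # placeholder checkpoint id, not taken from this diff
    torch_dtype=torch.bfloat16,
)

# Rebind the single-block and transformer forwards so that
# encoder_hidden_states is threaded through explicitly.
ChromaPatchFunctor().apply(pipe.transformer)
assert getattr(pipe.transformer, "_is_patched", False)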
cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} CHANGED
@@ -14,12 +14,60 @@ from diffusers.utils import (
      unscale_lora_layers,
  )

-
+ from cache_dit.cache_factory.patch_functors.functor_base import (
+     PatchFunctor,
+ )
  from cache_dit.logger import init_logger

  logger = init_logger(__name__)


+ class FluxPatchFunctor(PatchFunctor):
+
+     def apply(
+         self,
+         transformer: FluxTransformer2DModel,
+         blocks: torch.nn.ModuleList = None,
+         **kwargs,
+     ) -> FluxTransformer2DModel:
+
+         if getattr(transformer, "_is_patched", False):
+             return transformer
+
+         if blocks is None:
+             blocks = transformer.single_transformer_blocks
+
+         is_patched = False
+         for block in blocks:
+             if isinstance(block, FluxSingleTransformerBlock):
+                 forward_parameters = inspect.signature(
+                     block.forward
+                 ).parameters.keys()
+                 if "encoder_hidden_states" not in forward_parameters:
+                     block.forward = __patch_single_forward__.__get__(block)
+                     is_patched = True
+
+         if is_patched:
+             logger.warning("Patched Flux for cache-dit.")
+             assert not getattr(transformer, "_is_parallelized", False), (
+                 "Please call `cache_dit.enable_cache` before Parallelize, "
+                 "the __patch_transformer_forward__ will overwrite the "
+                 "parallized forward and cause a downgrade of performance."
+             )
+             transformer.forward = __patch_transformer_forward__.__get__(
+                 transformer
+             )
+             transformer._is_patched = True
+
+         cls_name = transformer.__class__.__name__
+         logger.info(
+             f"Applied {self.__class__.__name__} for {cls_name}, "
+             f"Patch: {is_patched}."
+         )
+
+         return transformer
+
+
  # copy from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
  def __patch_single_forward__(
      self: FluxSingleTransformerBlock,
@@ -217,33 +265,3 @@ def __patch_transformer_forward__(
          return (output,)

      return Transformer2DModelOutput(sample=output)
-
-
- def maybe_patch_flux_transformer(
-     transformer: FluxTransformer2DModel,
-     blocks: torch.nn.ModuleList = None,
- ) -> FluxTransformer2DModel:
-     if blocks is None:
-         blocks = transformer.single_transformer_blocks
-
-     is_patched = False
-     for block in blocks:
-         if isinstance(block, FluxSingleTransformerBlock):
-             forward_parameters = inspect.signature(
-                 block.forward
-             ).parameters.keys()
-             if "encoder_hidden_states" not in forward_parameters:
-                 block.forward = __patch_single_forward__.__get__(block)
-                 is_patched = True
-
-     if is_patched:
-         logger.warning("Patched Flux for cache-dit.")
-         assert not getattr(transformer, "_is_parallelized", False), (
-             "Please call apply_cache_on_pipe before Parallelize, "
-             "the __patch_transformer_forward__ will overwrite the "
-             "parallized forward and cause a downgrade of performance."
-         )
-         transformer.forward = __patch_transformer_forward__.__get__(transformer)
-         transformer._is_patched = True
-
-     return transformer
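
The module-level helper `maybe_patch_flux_transformer` is gone: the same logic now lives in `FluxPatchFunctor.apply`, and the assertion message points at `cache_dit.enable_cache` instead of the old `apply_cache_on_pipe` entry point. A hedged sketch of direct use follows (normally unnecessary, since the functor is presumably driven by the new block adapters); the checkpoint id is just an example.

import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # example checkpoint; any Flux checkpoint works
    torch_dtype=torch.bfloat16,
)

# Idempotent: apply() returns early if the transformer is already patched,
# and refuses to run after the model has been parallelized.
FluxPatchFunctor().apply(pipe.transformer)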
cache_dit/quantize/quantize_ao.py CHANGED
@@ -10,12 +10,13 @@ logger = init_logger(__name__)
  def quantize_ao(
      module: torch.nn.Module,
      quant_type: str = "fp8_w8a8_dq",
-     per_row: bool = True,
      exclude_layers: List[str] = [
          "embedder",
          "embed",
      ],
      filter_fn: Optional[Callable] = None,
+     # paramters for fp8 quantization
+     per_row: bool = True,
      **kwargs,
  ) -> torch.nn.Module:
      # Apply FP8 DQ for module and skip any `embed` modules
@@ -89,17 +90,30 @@ def quantize_ao(
          )

          quantization_fn = float8_dynamic_activation_float8_weight(
+             weight_dtype=kwargs.get(
+                 "weight_dtype",
+                 torch.float8_e4m3fn,
+             ),
+             activation_dtype=kwargs.get(
+                 "activation_dtype",
+                 torch.float8_e4m3fn,
+             ),
              granularity=(
                  ((PerRow(), PerRow()))
                  if per_row
                  else ((PerTensor(), PerTensor()))
-             )
+             ),
          )

      elif quant_type == "fp8_w8a16_wo":
          from torchao.quantization import float8_weight_only

-         quantization_fn = float8_weight_only()
+         quantization_fn = float8_weight_only(
+             weight_dtype=kwargs.get(
+                 "weight_dtype",
+                 torch.float8_e4m3fn,
+             ),
+         )

      elif quant_type == "int8_w8a8_dq":
          from torchao.quantization import (
@@ -159,12 +173,13 @@ def quantize_ao(
          module,
          _quantization_fn(),
          filter_fn=_filter_fn if filter_fn is None else filter_fn,
-         **kwargs,
+         device=kwargs.get("device", None),
      )

      force_empty_cache()

      logger.info(
+         f"Quantized Method: {quant_type:>5}\n"
          f"Quantized Linear Layers: {num_quant_linear:>5}\n"
          f"Skipped Linear Layers: {num_skip_linear:>5}\n"
          f"Total Linear Layers: {num_linear_layers:>5}\n"
cache_dit/quantize/quantize_interface.py CHANGED
@@ -9,13 +9,13 @@ def quantize(
      module: torch.nn.Module,
      quant_type: str = "fp8_w8a8_dq",
      backend: str = "ao",
-     # only for fp8_w8a8_dq
-     per_row: bool = True,
      exclude_layers: List[str] = [
          "embedder",
          "embed",
      ],
      filter_fn: Optional[Callable] = None,
+     # only for fp8_w8a8_dq
+     per_row: bool = True,
      **kwargs,
  ) -> torch.nn.Module:
      assert isinstance(module, torch.nn.Module)
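
Because `per_row` moved from before `exclude_layers` to after `filter_fn` in both `quantize()` and `quantize_ao()`, callers that passed it positionally will now bind that value to `exclude_layers` instead; passing it by keyword, as in the hedged sketch below, works on either side of the change. The toy module is an assumption for illustration.

import torch
from cache_dit.quantize.quantize_interface import quantize

blk = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.SiLU())
# per_row is only meaningful for quant_type="fp8_w8a8_dq"; keep it keyword-only in calls.
blk = quantize(blk, quant_type="fp8_w8a8_dq", backend="ao", per_row=False)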
cache_dit/utils.py CHANGED
@@ -30,26 +30,32 @@ class CacheStats:


  def summary(
-     pipe_or_transformer: DiffusionPipeline | torch.nn.Module,
+     pipe_or_module: DiffusionPipeline | torch.nn.Module | Any,
      details: bool = False,
      logging: bool = True,
  ) -> CacheStats:
      cache_stats = CacheStats()
-     cls_name = pipe_or_transformer.__class__.__name__
-     if isinstance(pipe_or_transformer, DiffusionPipeline):
-         transformer = pipe_or_transformer.transformer
+
+     if not isinstance(pipe_or_module, torch.nn.Module):
+         assert hasattr(pipe_or_module, "transformer")
+         module = pipe_or_module.transformer
+         cls_name = module.__class__.__name__
      else:
-         transformer = pipe_or_transformer
+         module = pipe_or_module
+
+     cls_name = module.__class__.__name__
+     if isinstance(module, torch.nn.ModuleList):
+         cls_name = module[0].__class__.__name__

-     if hasattr(transformer, "_cache_context_kwargs"):
-         cache_options = transformer._cache_context_kwargs
+     if hasattr(module, "_cache_context_kwargs"):
+         cache_options = module._cache_context_kwargs
          cache_stats.cache_options = cache_options
          if logging:
              print(f"\n🤗Cache Options: {cls_name}\n\n{cache_options}")

-     if hasattr(transformer, "_cached_steps"):
-         cached_steps: list[int] = transformer._cached_steps
-         residual_diffs: dict[str, float] = dict(transformer._residual_diffs)
+     if hasattr(module, "_cached_steps"):
+         cached_steps: list[int] = module._cached_steps
+         residual_diffs: dict[str, float] = dict(module._residual_diffs)
          cache_stats.cached_steps = cached_steps
          cache_stats.residual_diffs = residual_diffs

@@ -90,11 +96,9 @@ def summary(
                  compact=True,
              )

-     if hasattr(transformer, "_cfg_cached_steps"):
-         cfg_cached_steps: list[int] = transformer._cfg_cached_steps
-         cfg_residual_diffs: dict[str, float] = dict(
-             transformer._cfg_residual_diffs
-         )
+     if hasattr(module, "_cfg_cached_steps"):
+         cfg_cached_steps: list[int] = module._cfg_cached_steps
+         cfg_residual_diffs: dict[str, float] = dict(module._cfg_residual_diffs)
          cache_stats.cfg_cached_steps = cfg_cached_steps
          cache_stats.cfg_residual_diffs = cfg_residual_diffs

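
Finally, a hedged sketch of the widened `summary()` entry point. It assumes `summary` is re-exported at the package root (suggested by the `cache_dit/__init__.py` change in this release but not shown in this diff) and that `pipe` is a pipeline that already ran with caching enabled, such as the Flux sketch above.

import cache_dit

# A DiffusionPipeline-like object (anything with a .transformer attribute),
# a bare torch.nn.Module, or a torch.nn.ModuleList can now be summarized.
stats = cache_dit.summary(pipe, details=False, logging=True)
stats = cache_dit.summary(pipe.transformer, logging=False)
print(stats.cached_steps, stats.residual_diffs)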