cache-dit 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.

Potentially problematic release.

Files changed (43)
  1. cache_dit/__init__.py +12 -0
  2. cache_dit/_version.py +16 -3
  3. cache_dit/cache_factory/.gitignore +2 -0
  4. cache_dit/cache_factory/__init__.py +52 -2
  5. cache_dit/cache_factory/cache_adapters.py +654 -0
  6. cache_dit/cache_factory/cache_blocks.py +487 -0
  7. cache_dit/cache_factory/{dual_block_cache/cache_context.py → cache_context.py} +11 -862
  8. cache_dit/cache_factory/patch/flux.py +249 -0
  9. cache_dit/cache_factory/utils.py +1 -1
  10. cache_dit/compile/__init__.py +1 -1
  11. cache_dit/compile/utils.py +1 -1
  12. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/METADATA +87 -204
  13. cache_dit-0.2.17.dist-info/RECORD +30 -0
  14. cache_dit/cache_factory/adapters.py +0 -169
  15. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -55
  16. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -87
  17. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -98
  18. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +0 -294
  19. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -87
  20. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/qwen_image.py +0 -88
  21. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +0 -97
  22. cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  23. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -51
  24. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -87
  25. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -98
  26. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +0 -294
  27. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -87
  28. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +0 -97
  29. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +0 -1005
  30. cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  31. cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
  32. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
  33. cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +0 -89
  34. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
  35. cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
  36. cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +0 -89
  37. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
  38. cache_dit-0.2.15.dist-info/RECORD +0 -50
  39. /cache_dit/cache_factory/{dual_block_cache → patch}/__init__.py +0 -0
  40. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/WHEEL +0 -0
  41. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/entry_points.txt +0 -0
  42. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/licenses/LICENSE +0 -0
  43. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/top_level.txt +0 -0
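
The bulk of this release is a restructuring: rows 5-7 add the consolidated cache_adapters.py, cache_blocks.py, and cache_context.py modules, while most of the removed lines come from the old adapters.py and the dual_block_cache, dynamic_block_prune, and first_block_cache subpackages with their per-model diffusers_adapters (row 39 repurposes the dual_block_cache package __init__.py as the new patch package). A minimal sketch of the import-path change implied by rename 7, assuming downstream code imported the moved module directly (the symbols it exports are not visible in this diff):

```python
# 0.2.15: the dual-block cache context lived in its own subpackage.
# from cache_dit.cache_factory.dual_block_cache import cache_context

# 0.2.17: the module now sits directly under cache_factory.
from cache_dit.cache_factory import cache_context
```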
cache_dit/cache_factory/patch/flux.py (new file)
@@ -0,0 +1,249 @@
+import inspect
+
+import torch
+import numpy as np
+from typing import Tuple, Optional, Dict, Any, Union
+from diffusers import FluxTransformer2DModel
+from diffusers.models.transformers.transformer_flux import (
+    FluxSingleTransformerBlock,
+    Transformer2DModelOutput,
+)
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+# copy from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
+def __patch_single_forward__(
+    self: FluxSingleTransformerBlock,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    temb: torch.Tensor,
+    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    text_seq_len = encoder_hidden_states.shape[1]
+    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+    residual = hidden_states
+    norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+    mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+    joint_attention_kwargs = joint_attention_kwargs or {}
+    attn_output = self.attn(
+        hidden_states=norm_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+        **joint_attention_kwargs,
+    )
+
+    hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+    gate = gate.unsqueeze(1)
+    hidden_states = gate * self.proj_out(hidden_states)
+    hidden_states = residual + hidden_states
+    if hidden_states.dtype == torch.float16:
+        hidden_states = hidden_states.clip(-65504, 65504)
+
+    encoder_hidden_states, hidden_states = (
+        hidden_states[:, :text_seq_len],
+        hidden_states[:, text_seq_len:],
+    )
+    return encoder_hidden_states, hidden_states
+
+
+# copy from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L631
+def __patch_transformer_forward__(
+    self: FluxTransformer2DModel,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
+    pooled_projections: torch.Tensor = None,
+    timestep: torch.LongTensor = None,
+    img_ids: torch.Tensor = None,
+    txt_ids: torch.Tensor = None,
+    guidance: torch.Tensor = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    controlnet_block_samples=None,
+    controlnet_single_block_samples=None,
+    return_dict: bool = True,
+    controlnet_blocks_repeat: bool = False,
+) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    if joint_attention_kwargs is not None:
+        joint_attention_kwargs = joint_attention_kwargs.copy()
+        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+    else:
+        lora_scale = 1.0
+
+    if USE_PEFT_BACKEND:
+        # weight the lora layers by setting `lora_scale` for each PEFT layer
+        scale_lora_layers(self, lora_scale)
+    else:
+        if (
+            joint_attention_kwargs is not None
+            and joint_attention_kwargs.get("scale", None) is not None
+        ):
+            logger.warning(
+                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+            )
+
+    hidden_states = self.x_embedder(hidden_states)
+
+    timestep = timestep.to(hidden_states.dtype) * 1000
+    if guidance is not None:
+        guidance = guidance.to(hidden_states.dtype) * 1000
+
+    temb = (
+        self.time_text_embed(timestep, pooled_projections)
+        if guidance is None
+        else self.time_text_embed(timestep, guidance, pooled_projections)
+    )
+    encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+    if txt_ids.ndim == 3:
+        logger.warning(
+            "Passing `txt_ids` 3d torch.Tensor is deprecated."
+            "Please remove the batch dimension and pass it as a 2d torch Tensor"
+        )
+        txt_ids = txt_ids[0]
+    if img_ids.ndim == 3:
+        logger.warning(
+            "Passing `img_ids` 3d torch.Tensor is deprecated."
+            "Please remove the batch dimension and pass it as a 2d torch Tensor"
+        )
+        img_ids = img_ids[0]
+
+    ids = torch.cat((txt_ids, img_ids), dim=0)
+    image_rotary_emb = self.pos_embed(ids)
+
+    if (
+        joint_attention_kwargs is not None
+        and "ip_adapter_image_embeds" in joint_attention_kwargs
+    ):
+        ip_adapter_image_embeds = joint_attention_kwargs.pop(
+            "ip_adapter_image_embeds"
+        )
+        ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+        joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+    for index_block, block in enumerate(self.transformer_blocks):
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            encoder_hidden_states, hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    joint_attention_kwargs,
+                )
+            )
+
+        else:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+
+        # controlnet residual
+        if controlnet_block_samples is not None:
+            interval_control = len(self.transformer_blocks) / len(
+                controlnet_block_samples
+            )
+            interval_control = int(np.ceil(interval_control))
+            # For Xlabs ControlNet.
+            if controlnet_blocks_repeat:
+                hidden_states = (
+                    hidden_states
+                    + controlnet_block_samples[
+                        index_block % len(controlnet_block_samples)
+                    ]
+                )
+            else:
+                hidden_states = (
+                    hidden_states
+                    + controlnet_block_samples[index_block // interval_control]
+                )
+
+    for index_block, block in enumerate(self.single_transformer_blocks):
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            encoder_hidden_states, hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    joint_attention_kwargs,
+                )
+            )
+
+        else:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+
+        # controlnet residual
+        if controlnet_single_block_samples is not None:
+            interval_control = len(self.single_transformer_blocks) / len(
+                controlnet_single_block_samples
+            )
+            interval_control = int(np.ceil(interval_control))
+            hidden_states = (
+                hidden_states
+                + controlnet_single_block_samples[
+                    index_block // interval_control
+                ]
+            )
+
+    hidden_states = self.norm_out(hidden_states, temb)
+    output = self.proj_out(hidden_states)
+
+    if USE_PEFT_BACKEND:
+        # remove `lora_scale` from each PEFT layer
+        unscale_lora_layers(self, lora_scale)
+
+    if not return_dict:
+        return (output,)
+
+    return Transformer2DModelOutput(sample=output)
+
+
+def maybe_patch_flux_transformer(
+    transformer: FluxTransformer2DModel,
+    blocks: torch.nn.ModuleList = None,
+) -> FluxTransformer2DModel:
+    if blocks is None:
+        blocks = transformer.single_transformer_blocks
+
+    is_patched = False
+    for block in blocks:
+        if isinstance(block, FluxSingleTransformerBlock):
+            forward_parameters = inspect.signature(
+                block.forward
+            ).parameters.keys()
+            if "encoder_hidden_states" not in forward_parameters:
+                block.forward = __patch_single_forward__.__get__(block)
+                is_patched = True
+
+    if is_patched:
+        logger.warning("Patched Flux for cache-dit.")
+        assert not getattr(transformer, "_is_parallelized", False), (
+            "Please call apply_cache_on_pipe before Parallelize, "
+            "the __patch_transformer_forward__ will overwrite the "
+            "parallelized forward and cause a downgrade of performance."
+        )
+        transformer.forward = __patch_transformer_forward__.__get__(transformer)
+        transformer._is_patched = True
+
+    return transformer
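
The new patch module rebinds FluxSingleTransformerBlock.forward and FluxTransformer2DModel.forward so that each block returns encoder_hidden_states and hidden_states as separate tensors. A minimal usage sketch, assuming the helper is called directly on a diffusers FluxPipeline (inside cache-dit it is presumably driven by the cache adapters rather than by end users):

```python
import torch
from diffusers import FluxPipeline

from cache_dit.cache_factory.patch.flux import maybe_patch_flux_transformer

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# Rebind the single-block forwards (and the transformer forward) so each block
# returns (encoder_hidden_states, hidden_states) separately.
pipe.transformer = maybe_patch_flux_transformer(pipe.transformer)
```

Because the helper checks each block's forward signature first, it is a no-op on diffusers builds whose FluxSingleTransformerBlock already accepts encoder_hidden_states, and the assert documents that patching must happen before any parallelization wrapper replaces transformer.forward.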
cache_dit/cache_factory/utils.py
@@ -1,5 +1,5 @@
 import yaml
-from cache_dit.cache_factory.adapters import CacheType
+from cache_dit.cache_factory.cache_adapters import CacheType


 def load_cache_options_from_yaml(yaml_file_path):
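
This hunk only repoints the CacheType import at the new cache_adapters module; user code that imported it from the deleted adapters module needs the same one-line change. A minimal sketch, where the YAML path is an illustrative placeholder rather than anything taken from this diff:

```python
# 0.2.15:
# from cache_dit.cache_factory.adapters import CacheType
# 0.2.17:
from cache_dit.cache_factory.cache_adapters import CacheType
from cache_dit.cache_factory.utils import load_cache_options_from_yaml

# "cache_options.yaml" is a hypothetical file path.
cache_options = load_cache_options_from_yaml("cache_options.yaml")
```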
cache_dit/compile/__init__.py
@@ -1 +1 @@
-from cache_dit.compile.utils import set_custom_compile_configs
+from cache_dit.compile.utils import set_compile_configs
cache_dit/compile/utils.py
@@ -23,7 +23,7 @@ def epilogue_prologue_fusion_enabled(**kwargs) -> bool:
     return CACHE_DIT_EPILOGUE_PROLOGUE_FUSION or mode


-def set_custom_compile_configs(
+def set_compile_configs(
     cuda_graphs: bool = False,
     force_disable_compile_caches: bool = False,
     use_fast_math: bool = False,
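
Both compile hunks are the same rename: set_custom_compile_configs becomes set_compile_configs, still re-exported from cache_dit.compile. Callers upgrading from 0.2.15 only need the new name; a minimal sketch using the keyword arguments visible in the hunk (other parameters, if any, are not shown in this diff):

```python
# 0.2.15: from cache_dit.compile import set_custom_compile_configs
from cache_dit.compile import set_compile_configs

# Defaults shown here mirror the signature lines in the hunk above.
set_compile_configs(
    cuda_graphs=False,
    force_disable_compile_caches=False,
    use_fast_math=False,
)
```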