cache-dit 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit has been flagged as possibly problematic.

Files changed (43)
  1. cache_dit/__init__.py +12 -0
  2. cache_dit/_version.py +16 -3
  3. cache_dit/cache_factory/.gitignore +2 -0
  4. cache_dit/cache_factory/__init__.py +52 -2
  5. cache_dit/cache_factory/cache_adapters.py +654 -0
  6. cache_dit/cache_factory/cache_blocks.py +487 -0
  7. cache_dit/cache_factory/{dual_block_cache/cache_context.py → cache_context.py} +11 -862
  8. cache_dit/cache_factory/patch/flux.py +249 -0
  9. cache_dit/cache_factory/utils.py +1 -1
  10. cache_dit/compile/__init__.py +1 -1
  11. cache_dit/compile/utils.py +1 -1
  12. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/METADATA +87 -204
  13. cache_dit-0.2.17.dist-info/RECORD +30 -0
  14. cache_dit/cache_factory/adapters.py +0 -169
  15. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -55
  16. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -87
  17. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -98
  18. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +0 -294
  19. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -87
  20. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/qwen_image.py +0 -88
  21. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +0 -97
  22. cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  23. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -51
  24. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -87
  25. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -98
  26. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +0 -294
  27. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -87
  28. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +0 -97
  29. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +0 -1005
  30. cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  31. cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
  32. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
  33. cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +0 -89
  34. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
  35. cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
  36. cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +0 -89
  37. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
  38. cache_dit-0.2.15.dist-info/RECORD +0 -50
  39. cache_dit/cache_factory/{dual_block_cache → patch}/__init__.py +0 -0
  40. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/WHEEL +0 -0
  41. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/entry_points.txt +0 -0
  42. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/licenses/LICENSE +0 -0
  43. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/top_level.txt +0 -0
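
The hunks below are the bodies of the six deleted `dual_block_cache/diffusers_adapters` modules (items 15-20 above); each hunk is labeled with its path. All of them follow the same recipe: wrap the transformer's block lists in `cache_context.DBCachedTransformerBlocks`, rebind the transformer's `forward` so the wrapped blocks are swapped in for the duration of each call, and wrap the pipeline's `__call__` so every run executes inside a fresh cache context.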
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py (deleted, -55)
@@ -1,55 +0,0 @@
- import importlib
-
- from diffusers import DiffusionPipeline
-
-
- def apply_db_cache_on_transformer(transformer, *args, **kwargs):
-     transformer_cls_name: str = transformer.__class__.__name__
-     if transformer_cls_name.startswith("Flux"):
-         adapter_name = "flux"
-     elif transformer_cls_name.startswith("Mochi"):
-         adapter_name = "mochi"
-     elif transformer_cls_name.startswith("CogVideoX"):
-         adapter_name = "cogvideox"
-     elif transformer_cls_name.startswith("Wan"):
-         adapter_name = "wan"
-     elif transformer_cls_name.startswith("HunyuanVideo"):
-         adapter_name = "hunyuan_video"
-     elif transformer_cls_name.startswith("QwenImage"):
-         adapter_name = "qwen_image"
-     else:
-         raise ValueError(
-             f"Unknown transformer class name: {transformer_cls_name}"
-         )
-
-     adapter_module = importlib.import_module(f".{adapter_name}", __package__)
-     apply_db_cache_on_transformer_fn = getattr(
-         adapter_module, "apply_db_cache_on_transformer"
-     )
-     return apply_db_cache_on_transformer_fn(transformer, *args, **kwargs)
-
-
- def apply_db_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
-     assert isinstance(pipe, DiffusionPipeline)
-
-     pipe_cls_name: str = pipe.__class__.__name__
-     if pipe_cls_name.startswith("Flux"):
-         adapter_name = "flux"
-     elif pipe_cls_name.startswith("Mochi"):
-         adapter_name = "mochi"
-     elif pipe_cls_name.startswith("CogVideoX"):
-         adapter_name = "cogvideox"
-     elif pipe_cls_name.startswith("Wan"):
-         adapter_name = "wan"
-     elif pipe_cls_name.startswith("HunyuanVideo"):
-         adapter_name = "hunyuan_video"
-     elif pipe_cls_name.startswith("QwenImage"):
-         adapter_name = "qwen_image"
-     else:
-         raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
-
-     adapter_module = importlib.import_module(f".{adapter_name}", __package__)
-     apply_db_cache_on_pipe_fn = getattr(
-         adapter_module, "apply_db_cache_on_pipe"
-     )
-     return apply_db_cache_on_pipe_fn(pipe, *args, **kwargs)
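
This dispatcher resolved an adapter module from the class-name prefix and loaded it lazily via `importlib`. A sketch of how the removed entry point was typically invoked under the 0.2.15 layout (the checkpoint id is illustrative, not from this diff):

```python
from diffusers import FluxPipeline

from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
    apply_db_cache_on_pipe,
)

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

# "FluxPipeline".startswith("Flux"), so importlib loads the sibling
# .flux module and forwards all kwargs to its apply_db_cache_on_pipe.
pipe = apply_db_cache_on_pipe(pipe, residual_diff_threshold=0.05)
```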
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py (deleted, -87)
@@ -1,87 +0,0 @@
- import functools
- import unittest
-
- import torch
- from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline
-
- from cache_dit.cache_factory.dual_block_cache import cache_context
-
-
- def apply_db_cache_on_transformer(
-     transformer: CogVideoXTransformer3DModel,
- ):
-     if getattr(transformer, "_is_cached", False):
-         return transformer
-
-     cached_transformer_blocks = torch.nn.ModuleList(
-         [
-             cache_context.DBCachedTransformerBlocks(
-                 transformer.transformer_blocks,
-                 transformer=transformer,
-             )
-         ]
-     )
-
-     original_forward = transformer.forward
-
-     @functools.wraps(transformer.__class__.forward)
-     def new_forward(
-         self,
-         *args,
-         **kwargs,
-     ):
-         with unittest.mock.patch.object(
-             self,
-             "transformer_blocks",
-             cached_transformer_blocks,
-         ):
-             return original_forward(
-                 *args,
-                 **kwargs,
-             )
-
-     transformer.forward = new_forward.__get__(transformer)
-
-     transformer._is_cached = True
-
-     return transformer
-
-
- def apply_db_cache_on_pipe(
-     pipe: DiffusionPipeline,
-     *,
-     shallow_patch: bool = False,
-     residual_diff_threshold=0.04,
-     downsample_factor=1,
-     warmup_steps=0,
-     max_cached_steps=-1,
-     **kwargs,
- ):
-     cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
-         default_attrs={
-             "residual_diff_threshold": residual_diff_threshold,
-             "downsample_factor": downsample_factor,
-             "warmup_steps": warmup_steps,
-             "max_cached_steps": max_cached_steps,
-         },
-         **kwargs,
-     )
-     if not getattr(pipe, "_is_cached", False):
-         original_call = pipe.__class__.__call__
-
-         @functools.wraps(original_call)
-         def new_call(self, *args, **kwargs):
-             with cache_context.cache_context(
-                 cache_context.create_cache_context(
-                     **cache_kwargs,
-                 )
-             ):
-                 return original_call(self, *args, **kwargs)
-
-         pipe.__class__.__call__ = new_call
-         pipe.__class__._is_cached = True
-
-     if not shallow_patch:
-         apply_db_cache_on_transformer(pipe.transformer)
-
-     return pipe
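
One detail worth pausing on: `new_forward.__get__(transformer)` uses the descriptor protocol to bind a plain function as a method of that one instance, so the class and its other instances keep the original `forward`. A standalone illustration of the binding trick (the class and names here are hypothetical, not from the package):

```python
class Greeter:
    def hello(self):
        return "original"


def patched(self):
    return f"patched {type(self).__name__}"


g = Greeter()
# Functions are descriptors: __get__ binds `patched` to `g`, so the
# override is scoped to this one instance, not the whole class.
g.hello = patched.__get__(g)
assert g.hello() == "patched Greeter"
assert Greeter().hello() == "original"
```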
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py (deleted, -98)
@@ -1,98 +0,0 @@
- import functools
- import unittest
-
- import torch
- from diffusers import DiffusionPipeline, FluxTransformer2DModel
-
- from cache_dit.cache_factory.dual_block_cache import cache_context
-
-
- def apply_db_cache_on_transformer(
-     transformer: FluxTransformer2DModel,
- ):
-     if getattr(transformer, "_is_cached", False):
-         return transformer
-
-     cached_transformer_blocks = torch.nn.ModuleList(
-         [
-             cache_context.DBCachedTransformerBlocks(
-                 transformer.transformer_blocks,
-                 transformer.single_transformer_blocks,
-                 transformer=transformer,
-                 return_hidden_states_first=False,
-             )
-         ]
-     )
-     dummy_single_transformer_blocks = torch.nn.ModuleList()
-
-     original_forward = transformer.forward
-
-     @functools.wraps(original_forward)
-     def new_forward(
-         self,
-         *args,
-         **kwargs,
-     ):
-         with (
-             unittest.mock.patch.object(
-                 self,
-                 "transformer_blocks",
-                 cached_transformer_blocks,
-             ),
-             unittest.mock.patch.object(
-                 self,
-                 "single_transformer_blocks",
-                 dummy_single_transformer_blocks,
-             ),
-         ):
-             return original_forward(
-                 *args,
-                 **kwargs,
-             )
-
-     transformer.forward = new_forward.__get__(transformer)
-
-     transformer._is_cached = True
-
-     return transformer
-
-
- def apply_db_cache_on_pipe(
-     pipe: DiffusionPipeline,
-     *,
-     shallow_patch: bool = False,
-     residual_diff_threshold=0.05,
-     downsample_factor=1,
-     warmup_steps=0,
-     max_cached_steps=-1,
-     **kwargs,
- ):
-     cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
-         default_attrs={
-             "residual_diff_threshold": residual_diff_threshold,
-             "downsample_factor": downsample_factor,
-             "warmup_steps": warmup_steps,
-             "max_cached_steps": max_cached_steps,
-         },
-         **kwargs,
-     )
-
-     if not getattr(pipe, "_is_cached", False):
-         original_call = pipe.__class__.__call__
-
-         @functools.wraps(original_call)
-         def new_call(self, *args, **kwargs):
-             with cache_context.cache_context(
-                 cache_context.create_cache_context(
-                     **cache_kwargs,
-                 )
-             ):
-                 return original_call(self, *args, **kwargs)
-
-         pipe.__class__.__call__ = new_call
-         pipe.__class__._is_cached = True
-
-     if not shallow_patch:
-         apply_db_cache_on_transformer(pipe.transformer)
-
-     return pipe
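
The Flux adapter differs from the CogVideoX one in that it folds both `transformer_blocks` and `single_transformer_blocks` into a single cached wrapper, then patches `single_transformer_blocks` to an empty `ModuleList` so the original `forward`'s loop over those blocks becomes a no-op and no block runs twice. `unittest.mock.patch.object` scopes the swap to the call and restores the real attribute on exit; the trick in isolation (the `Model` class here is hypothetical):

```python
import unittest.mock

import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.single_transformer_blocks = torch.nn.ModuleList(
            torch.nn.Identity() for _ in range(3)
        )

    def forward(self, x):
        for block in self.single_transformer_blocks:
            x = block(x)
        return x


m = Model()
with unittest.mock.patch.object(
    m, "single_transformer_blocks", torch.nn.ModuleList()
):
    assert len(m.single_transformer_blocks) == 0  # loop is now a no-op
assert len(m.single_transformer_blocks) == 3  # restored on exit
```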
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py (deleted, -294)
@@ -1,294 +0,0 @@
- import functools
- import unittest
- from typing import Any, Dict, Optional, Union
-
- import torch
- from diffusers import DiffusionPipeline, HunyuanVideoTransformer3DModel
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
- from diffusers.utils import (
-     scale_lora_layers,
-     unscale_lora_layers,
-     USE_PEFT_BACKEND,
- )
-
- from cache_dit.cache_factory.dual_block_cache import cache_context
- from cache_dit.logger import init_logger
-
- try:
-     from para_attn.para_attn_interface import SparseKVAttnMode
-
-     def is_sparse_kv_attn_available():
-         return True
-
- except ImportError:
-
-     class SparseKVAttnMode:
-         def __enter__(self):
-             pass
-
-         def __exit__(self, exc_type, exc_value, traceback):
-             pass
-
-     def is_sparse_kv_attn_available():
-         return False
-
-
- logger = init_logger(__name__)  # pylint: disable=invalid-name
-
-
- def apply_db_cache_on_transformer(
-     transformer: HunyuanVideoTransformer3DModel,
- ):
-     if getattr(transformer, "_is_cached", False):
-         return transformer
-
-     cached_transformer_blocks = torch.nn.ModuleList(
-         [
-             cache_context.DBCachedTransformerBlocks(
-                 transformer.transformer_blocks
-                 + transformer.single_transformer_blocks,
-                 transformer=transformer,
-             )
-         ]
-     )
-     dummy_single_transformer_blocks = torch.nn.ModuleList()
-
-     original_forward = transformer.forward
-
-     @functools.wraps(transformer.__class__.forward)
-     def new_forward(
-         self,
-         hidden_states: torch.Tensor,
-         timestep: torch.LongTensor,
-         encoder_hidden_states: torch.Tensor,
-         encoder_attention_mask: torch.Tensor,
-         pooled_projections: torch.Tensor,
-         guidance: torch.Tensor = None,
-         attention_kwargs: Optional[Dict[str, Any]] = None,
-         return_dict: bool = True,
-         **kwargs,
-     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
-         with (
-             unittest.mock.patch.object(
-                 self,
-                 "transformer_blocks",
-                 cached_transformer_blocks,
-             ),
-             unittest.mock.patch.object(
-                 self,
-                 "single_transformer_blocks",
-                 dummy_single_transformer_blocks,
-             ),
-         ):
-             if getattr(self, "_is_parallelized", False):
-                 return original_forward(
-                     hidden_states,
-                     timestep,
-                     encoder_hidden_states,
-                     encoder_attention_mask,
-                     pooled_projections,
-                     guidance=guidance,
-                     attention_kwargs=attention_kwargs,
-                     return_dict=return_dict,
-                     **kwargs,
-                 )
-             else:
-                 if attention_kwargs is not None:
-                     attention_kwargs = attention_kwargs.copy()
-                     lora_scale = attention_kwargs.pop("scale", 1.0)
-                 else:
-                     lora_scale = 1.0
-
-                 if USE_PEFT_BACKEND:
-                     # weight the lora layers by setting `lora_scale` for each PEFT layer
-                     scale_lora_layers(self, lora_scale)
-                 else:
-                     if (
-                         attention_kwargs is not None
-                         and attention_kwargs.get("scale", None) is not None
-                     ):
-                         logger.warning(
-                             "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                         )
-
-                 batch_size, num_channels, num_frames, height, width = (
-                     hidden_states.shape
-                 )
-                 p, p_t = self.config.patch_size, self.config.patch_size_t
-                 post_patch_num_frames = num_frames // p_t
-                 post_patch_height = height // p
-                 post_patch_width = width // p
-
-                 # 1. RoPE
-                 image_rotary_emb = self.rope(hidden_states)
-
-                 # 2. Conditional embeddings
-                 temb = self.time_text_embed(
-                     timestep, guidance, pooled_projections
-                 )
-                 hidden_states = self.x_embedder(hidden_states)
-                 encoder_hidden_states = self.context_embedder(
-                     encoder_hidden_states, timestep, encoder_attention_mask
-                 )
-
-                 # 3. Attention mask preparation
-                 latent_sequence_length = hidden_states.shape[1]
-                 latent_attention_mask = torch.ones(
-                     batch_size,
-                     1,
-                     latent_sequence_length,
-                     device=hidden_states.device,
-                     dtype=torch.bool,
-                 )  # [B, 1, N]
-                 attention_mask = torch.cat(
-                     [
-                         latent_attention_mask,
-                         encoder_attention_mask.unsqueeze(1).to(torch.bool),
-                     ],
-                     dim=-1,
-                 )  # [B, 1, N + M]
-
-                 with SparseKVAttnMode():
-                     # 4. Transformer blocks
-                     hidden_states, encoder_hidden_states = (
-                         self.call_transformer_blocks(
-                             hidden_states,
-                             encoder_hidden_states,
-                             temb,
-                             attention_mask,
-                             image_rotary_emb,
-                         )
-                     )
-
-                 # 5. Output projection
-                 hidden_states = self.norm_out(hidden_states, temb)
-                 hidden_states = self.proj_out(hidden_states)
-
-                 hidden_states = hidden_states.reshape(
-                     batch_size,
-                     post_patch_num_frames,
-                     post_patch_height,
-                     post_patch_width,
-                     -1,
-                     p_t,
-                     p,
-                     p,
-                 )
-                 hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
-                 hidden_states = (
-                     hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
-                 )
-
-                 hidden_states = hidden_states.to(timestep.dtype)
-
-                 if USE_PEFT_BACKEND:
-                     # remove `lora_scale` from each PEFT layer
-                     unscale_lora_layers(self, lora_scale)
-
-                 if not return_dict:
-                     return (hidden_states,)
-
-                 return Transformer2DModelOutput(sample=hidden_states)
-
-     transformer.forward = new_forward.__get__(transformer)
-
-     def call_transformer_blocks(
-         self, hidden_states, encoder_hidden_states, *args, **kwargs
-     ):
-         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-             def create_custom_forward(module, return_dict=None):
-                 def custom_forward(*inputs):
-                     if return_dict is not None:
-                         return module(*inputs, return_dict=return_dict)
-                     else:
-                         return module(*inputs)
-
-                 return custom_forward
-
-             ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
-
-             for block in self.transformer_blocks:
-                 hidden_states, encoder_hidden_states = (
-                     torch.utils.checkpoint.checkpoint(
-                         create_custom_forward(block),
-                         hidden_states,
-                         encoder_hidden_states,
-                         *args,
-                         **kwargs,
-                         **ckpt_kwargs,
-                     )
-                 )
-
-             for block in self.single_transformer_blocks:
-                 hidden_states, encoder_hidden_states = (
-                     torch.utils.checkpoint.checkpoint(
-                         create_custom_forward(block),
-                         hidden_states,
-                         encoder_hidden_states,
-                         *args,
-                         **kwargs,
-                         **ckpt_kwargs,
-                     )
-                 )
-
-         else:
-             for block in self.transformer_blocks:
-                 hidden_states, encoder_hidden_states = block(
-                     hidden_states, encoder_hidden_states, *args, **kwargs
-                 )
-
-             for block in self.single_transformer_blocks:
-                 hidden_states, encoder_hidden_states = block(
-                     hidden_states, encoder_hidden_states, *args, **kwargs
-                 )
-
-         return hidden_states, encoder_hidden_states
-
-     transformer.call_transformer_blocks = call_transformer_blocks.__get__(
-         transformer
-     )
-
-     transformer._is_cached = True
-
-     return transformer
-
-
- def apply_db_cache_on_pipe(
-     pipe: DiffusionPipeline,
-     *,
-     shallow_patch: bool = False,
-     residual_diff_threshold=0.06,
-     downsample_factor=1,
-     warmup_steps=0,
-     max_cached_steps=-1,
-     **kwargs,
- ):
-     cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
-         default_attrs={
-             "residual_diff_threshold": residual_diff_threshold,
-             "downsample_factor": downsample_factor,
-             "warmup_steps": warmup_steps,
-             "max_cached_steps": max_cached_steps,
-         },
-         **kwargs,
-     )
-     if not getattr(pipe, "_is_cached", False):
-         original_call = pipe.__class__.__call__
-
-         @functools.wraps(original_call)
-         def new_call(self, *args, **kwargs):
-             with cache_context.cache_context(
-                 cache_context.create_cache_context(
-                     **cache_kwargs,
-                 )
-             ):
-                 return original_call(self, *args, **kwargs)
-
-         pipe.__class__.__call__ = new_call
-         pipe.__class__._is_cached = True
-
-     if not shallow_patch:
-         apply_db_cache_on_transformer(pipe.transformer)
-
-     return pipe
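
The `try`/`except ImportError` prologue in this file gives the module a no-op stand-in for para_attn's `SparseKVAttnMode`, so the `with SparseKVAttnMode():` block runs whether or not the optional dependency is installed. The context-manager half of that fallback can also be expressed with the standard library; a minimal equivalent sketch (not how the deleted file actually wrote it):

```python
import contextlib

try:
    from para_attn.para_attn_interface import SparseKVAttnMode
except ImportError:
    # contextlib.nullcontext is a ready-made no-op context manager.
    SparseKVAttnMode = contextlib.nullcontext

with SparseKVAttnMode():
    pass  # body behaves identically when para_attn is absent
```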
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py (deleted, -87)
@@ -1,87 +0,0 @@
- import functools
- import unittest
-
- import torch
- from diffusers import DiffusionPipeline, MochiTransformer3DModel
-
- from cache_dit.cache_factory.dual_block_cache import cache_context
-
-
- def apply_db_cache_on_transformer(
-     transformer: MochiTransformer3DModel,
- ):
-     if getattr(transformer, "_is_cached", False):
-         return transformer
-
-     cached_transformer_blocks = torch.nn.ModuleList(
-         [
-             cache_context.DBCachedTransformerBlocks(
-                 transformer.transformer_blocks,
-                 transformer=transformer,
-             )
-         ]
-     )
-
-     original_forward = transformer.forward
-
-     @functools.wraps(transformer.__class__.forward)
-     def new_forward(
-         self,
-         *args,
-         **kwargs,
-     ):
-         with unittest.mock.patch.object(
-             self,
-             "transformer_blocks",
-             cached_transformer_blocks,
-         ):
-             return original_forward(
-                 *args,
-                 **kwargs,
-             )
-
-     transformer.forward = new_forward.__get__(transformer)
-
-     transformer._is_cached = True
-
-     return transformer
-
-
- def apply_db_cache_on_pipe(
-     pipe: DiffusionPipeline,
-     *,
-     shallow_patch: bool = False,
-     residual_diff_threshold=0.06,
-     downsample_factor=1,
-     warmup_steps=0,
-     max_cached_steps=-1,
-     **kwargs,
- ):
-     cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
-         default_attrs={
-             "residual_diff_threshold": residual_diff_threshold,
-             "downsample_factor": downsample_factor,
-             "warmup_steps": warmup_steps,
-             "max_cached_steps": max_cached_steps,
-         },
-         **kwargs,
-     )
-     if not getattr(pipe, "_is_cached", False):
-         original_call = pipe.__class__.__call__
-
-         @functools.wraps(original_call)
-         def new_call(self, *args, **kwargs):
-             with cache_context.cache_context(
-                 cache_context.create_cache_context(
-                     **cache_kwargs,
-                 )
-             ):
-                 return original_call(self, *args, **kwargs)
-
-         pipe.__class__.__call__ = new_call
-         pipe.__class__._is_cached = True
-
-     if not shallow_patch:
-         apply_db_cache_on_transformer(pipe.transformer)
-
-     return pipe
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/qwen_image.py (deleted, -88)
@@ -1,88 +0,0 @@
- import functools
- import unittest
-
- import torch
- from diffusers import QwenImagePipeline, QwenImageTransformer2DModel
-
- from cache_dit.cache_factory.dual_block_cache import cache_context
-
-
- def apply_db_cache_on_transformer(
-     transformer: QwenImageTransformer2DModel,
- ):
-     if getattr(transformer, "_is_cached", False):
-         return transformer
-
-     transformer_blocks = torch.nn.ModuleList(
-         [
-             cache_context.DBCachedTransformerBlocks(
-                 transformer.transformer_blocks,
-                 transformer=transformer,
-                 return_hidden_states_first=False,
-             )
-         ]
-     )
-
-     original_forward = transformer.forward
-
-     @functools.wraps(transformer.__class__.forward)
-     def new_forward(
-         self,
-         *args,
-         **kwargs,
-     ):
-         with unittest.mock.patch.object(
-             self,
-             "transformer_blocks",
-             transformer_blocks,
-         ):
-             return original_forward(
-                 *args,
-                 **kwargs,
-             )
-
-     transformer.forward = new_forward.__get__(transformer)
-
-     transformer._is_cached = True
-
-     return transformer
-
-
- def apply_db_cache_on_pipe(
-     pipe: QwenImagePipeline,
-     *,
-     shallow_patch: bool = False,
-     residual_diff_threshold=0.06,
-     downsample_factor=1,
-     warmup_steps=0,
-     max_cached_steps=-1,
-     **kwargs,
- ):
-     cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
-         default_attrs={
-             "residual_diff_threshold": residual_diff_threshold,
-             "downsample_factor": downsample_factor,
-             "warmup_steps": warmup_steps,
-             "max_cached_steps": max_cached_steps,
-         },
-         **kwargs,
-     )
-     if not getattr(pipe, "_is_cached", False):
-         original_call = pipe.__class__.__call__
-
-         @functools.wraps(original_call)
-         def new_call(self, *args, **kwargs):
-             with cache_context.cache_context(
-                 cache_context.create_cache_context(
-                     **cache_kwargs,
-                 )
-             ):
-                 return original_call(self, *args, **kwargs)
-
-         pipe.__class__.__call__ = new_call
-         pipe.__class__._is_cached = True
-
-     if not shallow_patch:
-         apply_db_cache_on_transformer(pipe.transformer)
-
-     return pipe
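
A final note that applies to every adapter above: the patches land at different levels. `forward` is rebound on the single transformer instance (via `__get__`), while `__call__` and the `_is_cached` flag are set on the pipeline class, so all instances of that pipeline class share the wrapper. The class-level half, demonstrated standalone (the `Pipe` class is hypothetical):

```python
class Pipe:
    def __call__(self):
        return "original"


a, b = Pipe(), Pipe()
original_call = Pipe.__call__


def new_call(self):
    return f"wrapped({original_call(self)})"


# Assigning on the class rewires dispatch for every instance,
# including ones created before the patch.
Pipe.__call__ = new_call
assert a() == b() == "wrapped(original)"
```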