cache-dit 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registries and is provided for informational purposes only.
Potentially problematic release: this version of cache-dit has been flagged by the registry; see the registry page for details.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +95 -61
- cache_dit/cache_factory/block_adapters/block_adapters.py +27 -6
- cache_dit/cache_factory/block_adapters/block_registers.py +10 -7
- cache_dit/cache_factory/cache_adapters.py +177 -66
- cache_dit/cache_factory/cache_blocks/__init__.py +3 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +70 -67
- cache_dit/cache_factory/cache_blocks/pattern_base.py +13 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +8 -10
- cache_dit/cache_factory/cache_interface.py +19 -77
- cache_dit/cache_factory/cache_types.py +5 -5
- cache_dit/cache_factory/patch_functors/__init__.py +6 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +5 -3
- cache_dit/cache_factory/patch_functors/functor_flux.py +5 -3
- cache_dit/cache_factory/patch_functors/functor_hidream.py +412 -0
- cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py +213 -0
- cache_dit/utils.py +5 -1
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/METADATA +14 -48
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/RECORD +23 -21
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py
@@ -1,5 +1,6 @@
 import torch
 
+from typing import Dict, Any
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
@@ -31,7 +32,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
         # encoder_hidden_states: None Pattern 3, else 4, 5
-        hidden_states,
+        hidden_states, new_encoder_hidden_states = self.call_Fn_blocks(
             hidden_states,
             *args,
             **kwargs,
@@ -60,11 +61,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         if can_use_cache:
             self.cache_manager.add_cached_step()
             del Fn_hidden_states_residual
-            hidden_states,
+            hidden_states, new_encoder_hidden_states = (
                 self.cache_manager.apply_cache(
                     hidden_states,
-                    #
-                    encoder_hidden_states,
+                    new_encoder_hidden_states,  # encoder_hidden_states not use cache
                     prefix=(
                         f"{self.cache_prefix}_Bn_residual"
                         if self.cache_manager.is_cache_residual()
@@ -80,12 +80,12 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
-
-            hidden_states,
-
-
-
-
+            if self.cache_manager.Bn_compute_blocks() > 0:
+                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
         else:
             self.cache_manager.set_Fn_buffer(
                 Fn_hidden_states_residual,
@@ -99,19 +99,20 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
+            old_encoder_hidden_states = new_encoder_hidden_states
             (
                 hidden_states,
-
+                new_encoder_hidden_states,
                 hidden_states_residual,
-                # None Pattern 3, else 4, 5
-                encoder_hidden_states_residual,
             ) = self.call_Mn_blocks(  # middle
                 hidden_states,
-                # None Pattern 3, else 4, 5
-                encoder_hidden_states,
                 *args,
                 **kwargs,
             )
+            if new_encoder_hidden_states is not None:
+                new_encoder_hidden_states_residual = (
+                    new_encoder_hidden_states - old_encoder_hidden_states
+                )
             torch._dynamo.graph_break()
             if self.cache_manager.is_cache_residual():
                 self.cache_manager.set_Bn_buffer(
@@ -119,34 +120,32 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                # TaylorSeer
                 self.cache_manager.set_Bn_buffer(
                     hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
+
             if self.cache_manager.is_encoder_cache_residual():
-
-
-
-
-
+                if new_encoder_hidden_states is not None:
+                    self.cache_manager.set_Bn_encoder_buffer(
+                        new_encoder_hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_Bn_residual",
+                    )
             else:
-
-
-
-
-
-                )
+                if new_encoder_hidden_states is not None:
+                    self.cache_manager.set_Bn_encoder_buffer(
+                        new_encoder_hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_Bn_hidden_states",
+                    )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
-
-            hidden_states,
-
-
-
-
-            )
+            if self.cache_manager.Bn_compute_blocks() > 0:
+                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
 
         torch._dynamo.graph_break()
 
@@ -154,12 +153,21 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             hidden_states
             if self.forward_pattern.Return_H_Only
             else (
-                (hidden_states,
+                (hidden_states, new_encoder_hidden_states)
                 if self.forward_pattern.Return_H_First
-                else (
+                else (new_encoder_hidden_states, hidden_states)
             )
         )
 
+    @torch.compiler.disable
+    def maybe_update_kwargs(
+        self, encoder_hidden_states, kwargs: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        # if "encoder_hidden_states" in kwargs:
+        #     kwargs["encoder_hidden_states"] = encoder_hidden_states
+        # return kwargs
+        return kwargs
+
     def call_Fn_blocks(
         self,
         hidden_states: torch.Tensor,
@@ -172,7 +180,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-
+        new_encoder_hidden_states = None
        for block in self._Fn_blocks():
             hidden_states = block(
                 hidden_states,
@@ -180,25 +188,27 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-                hidden_states,
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states,
-
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )
 
-        return hidden_states,
+        return hidden_states, new_encoder_hidden_states
 
     def call_Mn_blocks(
         self,
         hidden_states: torch.Tensor,
-        # None Pattern 3, else 4, 5
-        encoder_hidden_states: torch.Tensor | None,
         *args,
         **kwargs,
     ):
         original_hidden_states = hidden_states
-
+        new_encoder_hidden_states = None
         for block in self._Mn_blocks():
             hidden_states = block(
                 hidden_states,
@@ -206,44 +216,33 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-                hidden_states,
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states,
-
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )
 
         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
         hidden_states_residual = hidden_states - original_hidden_states
-        if (
-            original_encoder_hidden_states is not None
-            and encoder_hidden_states is not None
-        ):  # Pattern 4, 5
-            encoder_hidden_states_residual = (
-                encoder_hidden_states - original_encoder_hidden_states
-            )
-        else:
-            encoder_hidden_states_residual = None  # Pattern 3
 
         return (
             hidden_states,
-
+            new_encoder_hidden_states,
             hidden_states_residual,
-            encoder_hidden_states_residual,
         )
 
     def call_Bn_blocks(
         self,
         hidden_states: torch.Tensor,
-        # None Pattern 3, else 4, 5
-        encoder_hidden_states: torch.Tensor | None,
         *args,
         **kwargs,
     ):
-        if self.cache_manager.Bn_compute_blocks() == 0:
-            return hidden_states, encoder_hidden_states
-
         assert self.cache_manager.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
@@ -264,11 +263,15 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
-                hidden_states,
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states,
-
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )
 
-        return hidden_states,
+        return hidden_states, new_encoder_hidden_states
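A recurring change above is how block outputs are normalized: per the diff's comments, Pattern 3 blocks return a single tensor, while Pattern 4/5 blocks return a (hidden_states, encoder_hidden_states) pair whose order depends on `Return_H_First`. A self-contained sketch of that unpacking logic, with a plain boolean standing in for the real `forward_pattern` object:

import torch
from typing import Optional, Tuple, Union

BlockOutput = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

def normalize_block_output(
    output: BlockOutput,
    return_h_first: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Mirrors the unpacking used in call_Fn/Mn/Bn_blocks above.
    new_encoder_hidden_states = None
    hidden_states = output
    if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
        hidden_states, new_encoder_hidden_states = hidden_states
        if not return_h_first:
            hidden_states, new_encoder_hidden_states = (
                new_encoder_hidden_states,
                hidden_states,
            )
    return hidden_states, new_encoder_hidden_states

h = torch.randn(2, 16, 64)
e = torch.randn(2, 77, 64)
assert normalize_block_output(h, True)[1] is None       # Pattern 3: tensor only
assert normalize_block_output((h, e), True)[0] is h     # Pattern 4/5, hidden states first
assert normalize_block_output((e, h), False)[0] is h    # Pattern 4/5, encoder states first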
cache_dit/cache_factory/cache_blocks/pattern_base.py
@@ -25,6 +25,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         transformer_blocks: torch.nn.ModuleList,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
         check_num_outputs: bool = True,
         # 1. Cache context configuration
         cache_prefix: str = None,  # maybe un-need.
@@ -38,6 +39,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         self.transformer = transformer
         self.transformer_blocks = transformer_blocks
         self.forward_pattern = forward_pattern
+        self.check_forward_pattern = check_forward_pattern
         self.check_num_outputs = check_num_outputs
         # 1. Cache context configuration
         self.cache_prefix = cache_prefix
@@ -52,6 +54,12 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         )
 
     def _check_forward_pattern(self):
+        if not self.check_forward_pattern:
+            logger.warning(
+                f"Skipped Forward Pattern Check: {self.forward_pattern}"
+            )
+            return
+
         assert (
             self.forward_pattern.Supported
             and self.forward_pattern in self._supported_patterns
@@ -59,6 +67,11 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
 
         if self.transformer_blocks is not None:
             for block in self.transformer_blocks:
+                # Special case for HiDreamBlock
+                if hasattr(block, "block"):
+                    if isinstance(block.block, torch.nn.Module):
+                        block = block.block
+
                 forward_parameters = set(
                     inspect.signature(block.forward).parameters.keys()
                 )
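The new special case in `_check_forward_pattern` unwraps container blocks that expose their real transformer block via a `.block` attribute (the comment names HiDreamBlock) before inspecting the forward signature. A self-contained sketch of the same unwrapping rule, using a hypothetical `WrapperBlock` stand-in rather than the real HiDream class:

import inspect
import torch

class WrapperBlock(torch.nn.Module):
    """Stand-in for a HiDream-style container that holds the real block as `.block`."""

    def __init__(self, block: torch.nn.Module):
        super().__init__()
        self.block = block

    def forward(self, *args, **kwargs):
        return self.block(*args, **kwargs)

block = WrapperBlock(torch.nn.Linear(8, 8))

# Same unwrapping rule as in _check_forward_pattern:
if hasattr(block, "block") and isinstance(block.block, torch.nn.Module):
    block = block.block

params = set(inspect.signature(block.forward).parameters.keys())
print(params)  # parameters of the *inner* block's forward, e.g. {'input'} for nn.Linear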
cache_dit/cache_factory/cache_contexts/cache_manager.py
@@ -733,17 +733,15 @@ class CachedContextManager:
             encoder_prefix
         )
 
-
-            encoder_hidden_states_prev is not None
-        ), f"{prefix}_encoder_buffer must be set before"
+        if encoder_hidden_states_prev is not None:
 
-
-
-
-
-
-
-
+            if self.is_encoder_cache_residual():
+                encoder_hidden_states = (
+                    encoder_hidden_states_prev + encoder_hidden_states
+                )
+            else:
+                # If encoder cache is not residual, we use the encoder hidden states directly
+                encoder_hidden_states = encoder_hidden_states_prev
 
         encoder_hidden_states = encoder_hidden_states.contiguous()
 
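The change above drops the hard assert that an encoder buffer was set and instead applies the cached encoder states only when present, either as a residual or as a direct replacement. A minimal sketch of that apply step with plain tensors (the helper name `apply_encoder_cache` is mine, not the manager's API):

from typing import Optional

import torch

def apply_encoder_cache(
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_prev: Optional[torch.Tensor],
    is_residual: bool,
) -> torch.Tensor:
    # Mirrors the new branch above: if no encoder buffer was cached,
    # the freshly computed states pass through instead of tripping an assert.
    if encoder_hidden_states_prev is not None:
        if is_residual:
            # Cached value is a residual: add it onto the fresh states.
            encoder_hidden_states = encoder_hidden_states_prev + encoder_hidden_states
        else:
            # Cached value is the hidden states themselves: use them directly.
            encoder_hidden_states = encoder_hidden_states_prev
    return encoder_hidden_states.contiguous()

x = torch.randn(1, 77, 768)
assert torch.allclose(apply_encoder_cache(x, torch.zeros_like(x), True), x)  # zero residual is a no-op
assert torch.equal(apply_encoder_cache(x, None, True), x)                    # no buffer: pass-through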
cache_dit/cache_factory/cache_interface.py
@@ -1,11 +1,9 @@
-import
-from typing import Any, Tuple, List
+from typing import Any, Tuple, List, Union
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.cache_factory.block_adapters import BlockAdapter
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
 from cache_dit.cache_factory.cache_adapters import CachedAdapter
-from cache_dit.cache_factory.cache_contexts import CachedContextManager
 
 from cache_dit.logger import init_logger
 
@@ -14,7 +12,10 @@ logger = init_logger(__name__)
 
 def enable_cache(
     # DiffusionPipeline or BlockAdapter
-    pipe_or_adapter:
+    pipe_or_adapter: Union[
+        DiffusionPipeline,
+        BlockAdapter,
+    ],
     # Cache context kwargs
     Fn_compute_blocks: int = 8,
     Bn_compute_blocks: int = 0,
@@ -32,7 +33,10 @@ def enable_cache(
     taylorseer_cache_type: str = "residual",
     taylorseer_order: int = 2,
     **other_cache_context_kwargs,
-) ->
+) -> Union[
+    DiffusionPipeline,
+    BlockAdapter,
+]:
     r"""
     Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
     that match the specific Input and Output patterns).
@@ -102,11 +106,11 @@ def enable_cache(
     >>> cache_dit.enable_cache(pipe) # One-line code with default cache options.
     >>> output = pipe(...) # Just call the pipe as normal.
    >>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
+    >>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
     """
-
     # Collect cache context kwargs
     cache_context_kwargs = other_cache_context_kwargs.copy()
-    if cache_type := cache_context_kwargs.get("cache_type", None):
+    if (cache_type := cache_context_kwargs.get("cache_type", None)) is not None:
         if cache_type == CacheType.NONE:
             return pipe_or_adapter
 
@@ -145,79 +149,17 @@ def enable_cache(
 
 
 def disable_cache(
-
-
+    pipe_or_adapter: Union[
+        DiffusionPipeline,
+        BlockAdapter,
+    ],
 ):
-
-
+    CachedAdapter.maybe_release_hooks(pipe_or_adapter)
+    logger.warning(
+        f"Cache Acceleration is disabled for: "
+        f"{pipe_or_adapter.__class__.__name__}."
     )
 
-    def _disable_blocks(blocks: torch.nn.ModuleList):
-        if blocks is None:
-            return
-        if hasattr(blocks, "_forward_pattern"):
-            del blocks._forward_pattern
-        if hasattr(blocks, "_cache_context_kwargs"):
-            del blocks._cache_context_kwargs
-        remove_cached_stats(blocks)
-
-    def _disable_transformer(transformer: torch.nn.Module):
-        if transformer is None or not BlockAdapter.is_cached(transformer):
-            return
-        if original_forward := getattr(transformer, "_original_forward"):
-            transformer.forward = original_forward.__get__(transformer)
-            del transformer._original_forward
-        if hasattr(transformer, "_is_cached"):
-            del transformer._is_cached
-        if hasattr(transformer, "_forward_pattern"):
-            del transformer._forward_pattern
-        if hasattr(transformer, "_has_separate_cfg"):
-            del transformer._has_separate_cfg
-        if hasattr(transformer, "_cache_context_kwargs"):
-            del transformer._cache_context_kwargs
-        remove_cached_stats(transformer)
-        for blocks in BlockAdapter.find_blocks(transformer):
-            _disable_blocks(blocks)
-
-    def _disable_pipe(pipe: DiffusionPipeline):
-        if pipe is None or not BlockAdapter.is_cached(pipe):
-            return
-        if original_call := getattr(pipe, "_original_call"):
-            pipe.__class__.__call__ = original_call
-            del pipe.__class__._original_call
-        if cache_manager := getattr(pipe, "_cache_manager"):
-            assert isinstance(cache_manager, CachedContextManager)
-            cache_manager.clear_contexts()
-            del pipe._cache_manager
-        if hasattr(pipe, "_is_cached"):
-            del pipe.__class__._is_cached
-        if hasattr(pipe, "_cache_context_kwargs"):
-            del pipe._cache_context_kwargs
-        remove_cached_stats(pipe)
-
-    if isinstance(pipe_or_adapter, DiffusionPipeline):
-        pipe = pipe_or_adapter
-        _disable_pipe(pipe)
-        if hasattr(pipe, "transformer"):
-            _disable_transformer(pipe.transformer)
-        if hasattr(pipe, "transformer_2"):  # Wan 2.2
-            _disable_transformer(pipe.transformer_2)
-        pipe_cls_name = pipe.__class__.__name__
-        logger.warning(f"Cache Acceleration is disabled for: {pipe_cls_name}")
-    elif isinstance(pipe_or_adapter, BlockAdapter):
-        # BlockAdapter
-        adapter = pipe_or_adapter
-        BlockAdapter.assert_normalized(adapter)
-        _disable_pipe(adapter.pipe)
-        for transformer in BlockAdapter.flatten(adapter.transformer):
-            _disable_transformer(transformer)
-        for blocks in BlockAdapter.flatten(adapter.blocks):
-            _disable_blocks(blocks)
-        pipe_cls_name = adapter.pipe.__class__.__name__
-        logger.warning(f"Cache Acceleration is disabled for: {pipe_cls_name}")
-    else:
-        pass  # do nothing
-
 
 def supported_pipelines(
     **kwargs,
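The doctest lines above describe the intended end-to-end flow of the public API. A minimal usage sketch based on those lines (the model id and prompt are placeholders, and loading details such as dtype/device are omitted):

import cache_dit
from diffusers import DiffusionPipeline

# Any supported Diffusion Transformer pipeline; the model id here is a placeholder.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

cache_dit.enable_cache(pipe)                  # one-line enable with default cache options
output = pipe("a small cabin in the woods")   # call the pipe as normal
stats = cache_dit.summary(pipe)               # summary of cache acceleration stats
cache_dit.disable_cache(pipe)                 # restore the original, uncached pipeline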
cache_dit/cache_factory/cache_types.py
@@ -22,11 +22,11 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":
     if isinstance(type_hint, CacheType):
         return type_hint
 
-    elif type_hint.
-        "
-        "
-        "
-        "
+    elif type_hint.upper() in (
+        "DUAL_BLOCK_CACHE",
+        "DB_CACHE",
+        "DBCACHE",
+        "DB",
     ):
         return CacheType.DBCache
     return CacheType.NONE
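The rewritten branch above normalizes string aliases case-insensitively via `.upper()`. A small sketch of how `cache_type` resolves its inputs, assuming both names are importable from `cache_dit.cache_factory.cache_types` as the file list suggests:

from cache_dit.cache_factory.cache_types import CacheType, cache_type

assert cache_type(CacheType.DBCache) is CacheType.DBCache  # enum values pass through unchanged
assert cache_type("dbcache") is CacheType.DBCache          # alias, case-insensitive after .upper()
assert cache_type("DB") is CacheType.DBCache
# Per the branch shown above, unrecognized strings fall through to CacheType.NONE.
assert cache_type("something_else") is CacheType.NONE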
cache_dit/cache_factory/patch_functors/__init__.py
@@ -3,3 +3,9 @@ from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
 from cache_dit.cache_factory.patch_functors.functor_chroma import (
     ChromaPatchFunctor,
 )
+from cache_dit.cache_factory.patch_functors.functor_hidream import (
+    HiDreamPatchFunctor,
+)
+from cache_dit.cache_factory.patch_functors.functor_hunyuan_dit import (
+    HunyuanDiTPatchFunctor,
+)
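Two new patch functors ship in this release (`functor_hidream.py`, +412 lines, and `functor_hunyuan_dit.py`, +213 lines, per the file list above), and the re-exports added here place them next to the existing ones, so all four should be importable from the package:

from cache_dit.cache_factory.patch_functors import (
    ChromaPatchFunctor,
    FluxPatchFunctor,
    HiDreamPatchFunctor,
    HunyuanDiTPatchFunctor,
)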
cache_dit/cache_factory/patch_functors/functor_chroma.py
@@ -46,8 +46,10 @@ class ChromaPatchFunctor(PatchFunctor):
                 block.forward = __patch_single_forward__.__get__(block)
                 is_patched = True
 
+        cls_name = transformer.__class__.__name__
+
         if is_patched:
-            logger.warning("Patched
+            logger.warning(f"Patched {cls_name} for cache-dit.")
             assert not getattr(transformer, "_is_parallelized", False), (
                 "Please call `cache_dit.enable_cache` before Parallelize, "
                 "the __patch_transformer_forward__ will overwrite the "
@@ -56,9 +58,9 @@ class ChromaPatchFunctor(PatchFunctor):
             transformer.forward = __patch_transformer_forward__.__get__(
                 transformer
             )
-            transformer._is_patched = True
 
-
+        transformer._is_patched = is_patched  # True or False
+
         logger.info(
             f"Applied {self.__class__.__name__} for {cls_name}, "
             f"Patch: {is_patched}."
cache_dit/cache_factory/patch_functors/functor_flux.py
@@ -47,8 +47,10 @@ class FluxPatchFunctor(PatchFunctor):
                 block.forward = __patch_single_forward__.__get__(block)
                 is_patched = True
 
+        cls_name = transformer.__class__.__name__
+
         if is_patched:
-            logger.warning("Patched
+            logger.warning(f"Patched {cls_name} for cache-dit.")
             assert not getattr(transformer, "_is_parallelized", False), (
                 "Please call `cache_dit.enable_cache` before Parallelize, "
                 "the __patch_transformer_forward__ will overwrite the "
@@ -57,9 +59,9 @@ class FluxPatchFunctor(PatchFunctor):
             transformer.forward = __patch_transformer_forward__.__get__(
                 transformer
             )
-            transformer._is_patched = True
 
-
+        transformer._is_patched = is_patched  # True or False
+
         logger.info(
             f"Applied {self.__class__.__name__} for {cls_name}, "
             f"Patch: {is_patched}."
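The Chroma and Flux functor changes follow the same pattern: the class name is captured once for logging, and `_is_patched` is set from the actual `is_patched` flag instead of unconditionally to `True`, so a transformer whose blocks were not recognized is no longer reported as patched. A generic sketch of that flag handling, with a plain module standing in for the transformer and a hypothetical `mark_patched` helper:

import logging
import torch

logger = logging.getLogger(__name__)

def mark_patched(transformer: torch.nn.Module, patched_blocks: int) -> torch.nn.Module:
    is_patched = patched_blocks > 0
    cls_name = transformer.__class__.__name__
    if is_patched:
        logger.warning(f"Patched {cls_name} for cache-dit.")
    # Record the real outcome instead of assuming success.
    transformer._is_patched = is_patched  # True or False
    logger.info(f"Applied patch functor for {cls_name}, Patch: {is_patched}.")
    return transformer

model = torch.nn.Linear(4, 4)  # stand-in transformer
mark_patched(model, patched_blocks=0)
assert model._is_patched is False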