cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those published versions.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/offload_utils.py +115 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/caching/cache_interface.py +358 -0
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
- cache_dit/caching/utils.py +68 -0
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/metrics/metrics.py +3 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- cache_dit-1.0.14.dist-info/METADATA +301 -0
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
- cache_dit/cache_factory/cache_blocks/utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
- cache_dit/cache_factory/cache_interface.py +0 -217
- cache_dit/cache_factory/patch_functors/__init__.py +0 -12
- cache_dit/cache_factory/utils.py +0 -57
- cache_dit-0.3.2.dist-info/METADATA +0 -753
- cache_dit-0.3.2.dist-info/RECORD +0 -56
- cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
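The dominant change visible in the file list above is the rename of the `cache_dit.cache_factory` package to `cache_dit.caching` (plus `custom_ops` → `kernels` and the new `parallelism` and `quantize` backend trees). The snippet below is a hedged migration sketch for downstream imports: the old and new module paths are taken from the listing above, but the try/except fallback and the assumption that `enable_cache` and `ForwardPattern` keep their names in 1.0.14 are illustrative, not documented guarantees.

```python
# Hedged migration sketch for the cache_factory -> caching rename.
# Module paths come from the file listing above; the fallback shim and the
# re-used symbol names are assumptions, not an official compatibility layer.
try:
    # cache-dit >= 1.0.x layout
    from cache_dit.caching.cache_interface import enable_cache
    from cache_dit.caching.forward_pattern import ForwardPattern
except ImportError:
    # cache-dit 0.3.x layout
    from cache_dit.cache_factory.cache_interface import enable_cache
    from cache_dit.cache_factory.forward_pattern import ForwardPattern
```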
cache_dit/cache_factory/cache_blocks/pattern_base.py (removed in 1.0.14)
@@ -1,404 +0,0 @@
-import inspect
-import torch
-import torch.distributed as dist
-
-from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
-from cache_dit.cache_factory.cache_contexts.cache_manager import (
-    CachedContextManager,
-)
-from cache_dit.cache_factory import ForwardPattern
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-class CachedBlocks_Pattern_Base(torch.nn.Module):
-    _supported_patterns = [
-        ForwardPattern.Pattern_0,
-        ForwardPattern.Pattern_1,
-        ForwardPattern.Pattern_2,
-    ]
-
-    def __init__(
-        self,
-        # 0. Transformer blocks configuration
-        transformer_blocks: torch.nn.ModuleList,
-        transformer: torch.nn.Module = None,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
-        check_forward_pattern: bool = True,
-        check_num_outputs: bool = True,
-        # 1. Cache context configuration
-        cache_prefix: str = None,  # maybe un-need.
-        cache_context: CachedContext | str = None,
-        cache_manager: CachedContextManager = None,
-        **kwargs,
-    ):
-        super().__init__()
-
-        # 0. Transformer blocks configuration
-        self.transformer = transformer
-        self.transformer_blocks = transformer_blocks
-        self.forward_pattern = forward_pattern
-        self.check_forward_pattern = check_forward_pattern
-        self.check_num_outputs = check_num_outputs
-        # 1. Cache context configuration
-        self.cache_prefix = cache_prefix
-        self.cache_context = cache_context
-        self.cache_manager = cache_manager
-
-        self._check_forward_pattern()
-        logger.info(
-            f"Match Cached Blocks: {self.__class__.__name__}, for "
-            f"{self.cache_prefix}, cache_context: {self.cache_context}, "
-            f"cache_manager: {self.cache_manager.name}."
-        )
-
-    def _check_forward_pattern(self):
-        if not self.check_forward_pattern:
-            logger.warning(
-                f"Skipped Forward Pattern Check: {self.forward_pattern}"
-            )
-            return
-
-        assert (
-            self.forward_pattern.Supported
-            and self.forward_pattern in self._supported_patterns
-        ), f"Pattern {self.forward_pattern} is not supported now!"
-
-        if self.transformer_blocks is not None:
-            for block in self.transformer_blocks:
-                # Special case for HiDreamBlock
-                if hasattr(block, "block"):
-                    if isinstance(block.block, torch.nn.Module):
-                        block = block.block
-
-                forward_parameters = set(
-                    inspect.signature(block.forward).parameters.keys()
-                )
-
-                if self.check_num_outputs:
-                    num_outputs = str(
-                        inspect.signature(block.forward).return_annotation
-                    ).count("torch.Tensor")
-
-                    if num_outputs > 0:
-                        assert len(self.forward_pattern.Out) == num_outputs, (
-                            f"The number of block's outputs is {num_outputs} don't not "
-                            f"match the number of the pattern: {self.forward_pattern}, "
-                            f"Out: {len(self.forward_pattern.Out)}."
-                        )
-
-                for required_param in self.forward_pattern.In:
-                    assert (
-                        required_param in forward_parameters
-                    ), f"The input parameters must contains: {required_param}."
-
-    @torch.compiler.disable
-    def _check_cache_params(self):
-        assert self.cache_manager.Fn_compute_blocks() <= len(
-            self.transformer_blocks
-        ), (
-            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
-            f"the number of transformer blocks {len(self.transformer_blocks)}"
-        )
-        assert self.cache_manager.Bn_compute_blocks() <= len(
-            self.transformer_blocks
-        ), (
-            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
-            f"the number of transformer blocks {len(self.transformer_blocks)}"
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        # Use it's own cache context.
-        self.cache_manager.set_context(self.cache_context)
-        self._check_cache_params()
-
-        original_hidden_states = hidden_states
-        # Call first `n` blocks to process the hidden states for
-        # more stable diff calculation.
-        hidden_states, encoder_hidden_states = self.call_Fn_blocks(
-            hidden_states,
-            encoder_hidden_states,
-            *args,
-            **kwargs,
-        )
-
-        Fn_hidden_states_residual = hidden_states - original_hidden_states
-        del original_hidden_states
-
-        self.cache_manager.mark_step_begin()
-        # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = self.cache_manager.can_cache(
-            (
-                Fn_hidden_states_residual
-                if not self.cache_manager.is_l1_diff_enabled()
-                else hidden_states
-            ),
-            parallelized=self._is_parallelized(),
-            prefix=(
-                f"{self.cache_prefix}_Fn_residual"
-                if not self.cache_manager.is_l1_diff_enabled()
-                else f"{self.cache_prefix}_Fn_hidden_states"
-            ),
-        )
-
-        torch._dynamo.graph_break()
-        if can_use_cache:
-            self.cache_manager.add_cached_step()
-            del Fn_hidden_states_residual
-            hidden_states, encoder_hidden_states = (
-                self.cache_manager.apply_cache(
-                    hidden_states,
-                    encoder_hidden_states,
-                    prefix=(
-                        f"{self.cache_prefix}_Bn_residual"
-                        if self.cache_manager.is_cache_residual()
-                        else f"{self.cache_prefix}_Bn_hidden_states"
-                    ),
-                    encoder_prefix=(
-                        f"{self.cache_prefix}_Bn_residual"
-                        if self.cache_manager.is_encoder_cache_residual()
-                        else f"{self.cache_prefix}_Bn_hidden_states"
-                    ),
-                )
-            )
-            torch._dynamo.graph_break()
-            # Call last `n` blocks to further process the hidden states
-            # for higher precision.
-            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
-                hidden_states,
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
-        else:
-            self.cache_manager.set_Fn_buffer(
-                Fn_hidden_states_residual,
-                prefix=f"{self.cache_prefix}_Fn_residual",
-            )
-            if self.cache_manager.is_l1_diff_enabled():
-                # for hidden states L1 diff
-                self.cache_manager.set_Fn_buffer(
-                    hidden_states,
-                    f"{self.cache_prefix}_Fn_hidden_states",
-                )
-            del Fn_hidden_states_residual
-            torch._dynamo.graph_break()
-            (
-                hidden_states,
-                encoder_hidden_states,
-                hidden_states_residual,
-                encoder_hidden_states_residual,
-            ) = self.call_Mn_blocks(  # middle
-                hidden_states,
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
-            torch._dynamo.graph_break()
-            if self.cache_manager.is_cache_residual():
-                self.cache_manager.set_Bn_buffer(
-                    hidden_states_residual,
-                    prefix=f"{self.cache_prefix}_Bn_residual",
-                )
-            else:
-                self.cache_manager.set_Bn_buffer(
-                    hidden_states,
-                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
-                )
-
-            if self.cache_manager.is_encoder_cache_residual():
-                self.cache_manager.set_Bn_encoder_buffer(
-                    encoder_hidden_states_residual,
-                    prefix=f"{self.cache_prefix}_Bn_residual",
-                )
-            else:
-                self.cache_manager.set_Bn_encoder_buffer(
-                    encoder_hidden_states,
-                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
-                )
-            torch._dynamo.graph_break()
-            # Call last `n` blocks to further process the hidden states
-            # for higher precision.
-            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
-                hidden_states,
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
-
-        # patch cached stats for blocks or remove it.
-        torch._dynamo.graph_break()
-
-        return (
-            hidden_states
-            if self.forward_pattern.Return_H_Only
-            else (
-                (hidden_states, encoder_hidden_states)
-                if self.forward_pattern.Return_H_First
-                else (encoder_hidden_states, hidden_states)
-            )
-        )
-
-    @torch.compiler.disable
-    def _is_parallelized(self):
-        # Compatible with distributed inference.
-        return any(
-            (
-                all(
-                    (
-                        self.transformer is not None,
-                        getattr(self.transformer, "_is_parallelized", False),
-                    )
-                ),
-                (dist.is_initialized() and dist.get_world_size() > 1),
-            )
-        )
-
-    @torch.compiler.disable
-    def _is_in_cache_step(self):
-        # Check if the current step is in cache steps.
-        # If so, we can skip some Bn blocks and directly
-        # use the cached values.
-        return (
-            self.cache_manager.get_current_step()
-            in self.cache_manager.get_cached_steps()
-        ) or (
-            self.cache_manager.get_current_step()
-            in self.cache_manager.get_cfg_cached_steps()
-        )
-
-    @torch.compiler.disable
-    def _Fn_blocks(self):
-        # Select first `n` blocks to process the hidden states for
-        # more stable diff calculation.
-        # Fn: [0,...,n-1]
-        selected_Fn_blocks = self.transformer_blocks[
-            : self.cache_manager.Fn_compute_blocks()
-        ]
-        return selected_Fn_blocks
-
-    @torch.compiler.disable
-    def _Mn_blocks(self):  # middle blocks
-        # M(N-2n): only transformer_blocks [n,...,N-n], middle
-        if self.cache_manager.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
-            selected_Mn_blocks = self.transformer_blocks[
-                self.cache_manager.Fn_compute_blocks() :
-            ]
-        else:
-            selected_Mn_blocks = self.transformer_blocks[
-                self.cache_manager.Fn_compute_blocks() : -self.cache_manager.Bn_compute_blocks()
-            ]
-        return selected_Mn_blocks
-
-    @torch.compiler.disable
-    def _Bn_blocks(self):
-        # Bn: transformer_blocks [N-n+1,...,N-1]
-        selected_Bn_blocks = self.transformer_blocks[
-            -self.cache_manager.Bn_compute_blocks() :
-        ]
-        return selected_Bn_blocks
-
-    def call_Fn_blocks(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        for block in self._Fn_blocks():
-            hidden_states = block(
-                hidden_states,
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
-            if not isinstance(hidden_states, torch.Tensor):
-                hidden_states, encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
-                        hidden_states,
-                    )
-
-        return hidden_states, encoder_hidden_states
-
-    def call_Mn_blocks(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        original_hidden_states = hidden_states
-        original_encoder_hidden_states = encoder_hidden_states
-        for block in self._Mn_blocks():
-            hidden_states = block(
-                hidden_states,
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
-            if not isinstance(hidden_states, torch.Tensor):
-                hidden_states, encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
-                        hidden_states,
-                    )
-
-        # compute hidden_states residual
-        hidden_states = hidden_states.contiguous()
-
-        hidden_states_residual = hidden_states - original_hidden_states
-
-        if (
-            encoder_hidden_states is not None
-            and original_encoder_hidden_states is not None
-        ):
-            encoder_hidden_states = encoder_hidden_states.contiguous()
-            encoder_hidden_states_residual = (
-                encoder_hidden_states - original_encoder_hidden_states
-            )
-        else:
-            encoder_hidden_states_residual = None
-
-        return (
-            hidden_states,
-            encoder_hidden_states,
-            hidden_states_residual,
-            encoder_hidden_states_residual,
-        )
-
-    def call_Bn_blocks(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        if self.cache_manager.Bn_compute_blocks() == 0:
-            return hidden_states, encoder_hidden_states
-
-        for block in self._Bn_blocks():
-            hidden_states = block(
-                hidden_states,
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
-            if not isinstance(hidden_states, torch.Tensor):
-                hidden_states, encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
-                        hidden_states,
-                    )
-
-        return hidden_states, encoder_hidden_states
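For context on the removed `CachedBlocks_Pattern_Base` above: it splits the transformer's block list into first-n (Fn), middle (Mn), and last-n (Bn) groups, computes the residual after the Fn blocks, and either replays a cached Bn-side residual or runs the Mn blocks and refreshes the cache. The snippet below is a minimal, standalone sketch of just the Fn/Mn/Bn slicing, including the `x[:-0] == []` pitfall the original guards against; the helper name and the list-of-ints stand-in are illustrative and not part of cache-dit.

```python
from typing import Sequence, Tuple


def split_fn_mn_bn(
    blocks: Sequence,
    fn_compute_blocks: int,
    bn_compute_blocks: int,
) -> Tuple[list, list, list]:
    """Illustrative re-implementation of the Fn/Mn/Bn split used above."""
    fn = list(blocks[:fn_compute_blocks])
    if bn_compute_blocks == 0:
        # Mirror the original guard: blocks[n:-0] would be empty, not "the rest".
        mn = list(blocks[fn_compute_blocks:])
        bn = []
    else:
        mn = list(blocks[fn_compute_blocks:-bn_compute_blocks])
        bn = list(blocks[-bn_compute_blocks:])
    return fn, mn, bn


# Example: 28 stand-in "blocks" with the F8B0 default from the docstring below.
fn, mn, bn = split_fn_mn_bn(list(range(28)), fn_compute_blocks=8, bn_compute_blocks=0)
assert len(fn) == 8 and len(mn) == 20 and len(bn) == 0
```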
cache_dit/cache_factory/cache_blocks/utils.py (removed in 1.0.14)
@@ -1,41 +0,0 @@
-import torch
-
-from typing import Any
-from cache_dit.cache_factory import CachedContext
-from cache_dit.cache_factory import CachedContextManager
-
-
-def patch_cached_stats(
-    module: torch.nn.Module | Any,
-    cache_context: CachedContext | str = None,
-    cache_manager: CachedContextManager = None,
-):
-    # Patch the cached stats to the module, the cached stats
-    # will be reset for each calling of pipe.__call__(**kwargs).
-    if module is None or cache_manager is None:
-        return
-
-    if cache_context is not None:
-        cache_manager.set_context(cache_context)
-
-    # TODO: Patch more cached stats to the module
-    module._cached_steps = cache_manager.get_cached_steps()
-    module._residual_diffs = cache_manager.get_residual_diffs()
-    module._cfg_cached_steps = cache_manager.get_cfg_cached_steps()
-    module._cfg_residual_diffs = cache_manager.get_cfg_residual_diffs()
-
-
-def remove_cached_stats(
-    module: torch.nn.Module | Any,
-):
-    if module is None:
-        return
-
-    if hasattr(module, "_cached_steps"):
-        del module._cached_steps
-    if hasattr(module, "_residual_diffs"):
-        del module._residual_diffs
-    if hasattr(module, "_cfg_cached_steps"):
-        del module._cfg_cached_steps
-    if hasattr(module, "_cfg_residual_diffs"):
-        del module._cfg_residual_diffs
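The removed helpers above attach per-run statistics (`_cached_steps`, `_residual_diffs`, and their CFG counterparts) directly onto the patched module so they can be summarized later. Below is a hedged sketch of reading those attributes back after a run; the attribute names come from the code above, but the mapping type of `_residual_diffs` and where these stats live in 1.0.14 (which replaces this helper with `cache_dit/caching/utils.py` and `cache_dit/summary.py`) are assumptions.

```python
# Hedged sketch: inspect the stats that patch_cached_stats() attaches.
# `transformer` is whatever module the helper was called on; the attribute
# names are taken from the removed code above, the value types are assumed.
def report_cached_stats(transformer) -> None:
    cached_steps = getattr(transformer, "_cached_steps", [])
    print(f"cached steps: {len(cached_steps)} -> {cached_steps}")

    diffs = getattr(transformer, "_residual_diffs", None)
    if diffs:
        # Assumed to be a mapping of step -> residual L1 diff.
        values = list(diffs.values()) if hasattr(diffs, "values") else list(diffs)
        print(f"mean residual diff: {sum(values) / len(values):.4f}")
```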
cache_dit/cache_factory/cache_contexts/__init__.py (removed in 1.0.14)
@@ -1,14 +0,0 @@
-from cache_dit.cache_factory.cache_contexts.calibrators import (
-    Calibrator,
-    CalibratorBase,
-    CalibratorConfig,
-    TaylorSeerCalibratorConfig,
-    FoCaCalibratorConfig,
-)
-from cache_dit.cache_factory.cache_contexts.cache_context import (
-    CachedContext,
-    BasicCacheConfig,
-)
-from cache_dit.cache_factory.cache_contexts.cache_manager import (
-    CachedContextManager,
-)
cache_dit/cache_factory/cache_interface.py (removed in 1.0.14)
@@ -1,217 +0,0 @@
-from typing import Any, Tuple, List, Union, Optional
-from diffusers import DiffusionPipeline
-from cache_dit.cache_factory.cache_types import CacheType
-from cache_dit.cache_factory.block_adapters import BlockAdapter
-from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
-from cache_dit.cache_factory.cache_adapters import CachedAdapter
-from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
-from cache_dit.cache_factory.cache_contexts import CalibratorConfig
-from cache_dit.cache_factory.params_modifier import ParamsModifier
-
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-def enable_cache(
-    # DiffusionPipeline or BlockAdapter
-    pipe_or_adapter: Union[
-        DiffusionPipeline,
-        BlockAdapter,
-    ],
-    # Basic DBCache config: BasicCacheConfig
-    cache_config: BasicCacheConfig = BasicCacheConfig(),
-    # Calibrator config: TaylorSeerCalibratorConfig, etc.
-    calibrator_config: Optional[CalibratorConfig] = None,
-    # Modify cache context params for specific blocks.
-    params_modifiers: Optional[
-        Union[
-            ParamsModifier,
-            List[ParamsModifier],
-            List[List[ParamsModifier]],
-        ]
-    ] = None,
-    # Other cache context kwargs: Deprecated cache kwargs
-    **kwargs,
-) -> Union[
-    DiffusionPipeline,
-    BlockAdapter,
-]:
-    r"""
-    Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
-    that match the specific Input and Output patterns).
-
-    For a good balance between performance and precision, DBCache is configured by default
-    with F8B0, 8 warmup steps, and unlimited cached steps.
-
-    Args:
-        pipe_or_adapter (`DiffusionPipeline` or `BlockAdapter`, *required*):
-            The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
-            For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
-            for the usgae of BlockAdapter.
-        cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
-            Basic DBCache config for cache context, defaults to BasicCacheConfig(). The configurable params listed belows:
-            Fn_compute_blocks: (`int`, *required*, defaults to 8):
-                Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
-                at time step t, enabling the calculation of a more stable L1 diff and delivering more
-                accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
-                for more details of DBCache.
-            Bn_compute_blocks: (`int`, *required*, defaults to 0):
-                Further fuses approximate information in the **last n** Transformer blocks to enhance
-                prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
-                that use residual cache.
-            residual_diff_threshold (`float`, *required*, defaults to 0.08):
-                the value of residual diff threshold, a higher value leads to faster performance at the
-                cost of lower precision.
-            max_warmup_steps (`int`, *required*, defaults to 8):
-                DBCache does not apply the caching strategy when the number of running steps is less than
-                or equal to this value, ensuring the model sufficiently learns basic features during warmup.
-            max_cached_steps (`int`, *required*, defaults to -1):
-                DBCache disables the caching strategy when the previous cached steps exceed this value to
-                prevent precision degradation.
-            max_continuous_cached_steps (`int`, *required*, defaults to -1):
-                DBCache disables the caching strategy when the previous continous cached steps exceed this value to
-                prevent precision degradation.
-            enable_separate_cfg (`bool`, *required*, defaults to None):
-                Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
-                and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
-                CogVideoX, HunyuanVideo, Mochi, etc.
-            cfg_compute_first (`bool`, *required*, defaults to False):
-                Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
-                1, 3, 5, ... -> CFG step.
-            cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-                Compute separate diff values for CFG and non-CFG step, default True. If False, we will
-                use the computed diff from current non-CFG transformer step for current CFG step.
-        calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
-            Config for calibrator, if calibrator_config is not None, means that user want to use DBCache
-            with specific calibrator, such as taylorseer, foca, and so on.
-        params_modifiers ('ParamsModifier', *optional*, defaults to None):
-            Modify cache context params for specific blocks. The configurable params listed belows:
-            cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
-                The same as 'cache_config' param in cache_dit.enable_cache() interface.
-            calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
-                The same as 'calibrator_config' param in cache_dit.enable_cache() interface.
-            **kwargs: (`dict`, *optional*, defaults to {}):
-                The same as 'kwargs' param in cache_dit.enable_cache() interface.
-        kwargs (`dict`, *optional*, defaults to {})
-            Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
-            for more details.
-
-    Examples:
-    ```py
-    >>> import cache_dit
-    >>> from diffusers import DiffusionPipeline
-    >>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image") # Can be any diffusion pipeline
-    >>> cache_dit.enable_cache(pipe) # One-line code with default cache options.
-    >>> output = pipe(...) # Just call the pipe as normal.
-    >>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
-    >>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
-    """
-    # Collect cache context kwargs
-    cache_context_kwargs = {}
-    if (cache_type := cache_context_kwargs.pop("cache_type", None)) is not None:
-        if cache_type == CacheType.NONE:
-            return pipe_or_adapter
-
-    # WARNING: Deprecated cache config params. These parameters are now retained
-    # for backward compatibility but will be removed in the future.
-    deprecated_cache_kwargs = {
-        "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
-        "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
-        "max_warmup_steps": kwargs.get("max_warmup_steps", None),
-        "max_cached_steps": kwargs.get("max_cached_steps", None),
-        "max_continuous_cached_steps": kwargs.get(
-            "max_continuous_cached_steps", None
-        ),
-        "residual_diff_threshold": kwargs.get("residual_diff_threshold", None),
-        "enable_separate_cfg": kwargs.get("enable_separate_cfg", None),
-        "cfg_compute_first": kwargs.get("cfg_compute_first", None),
-        "cfg_diff_compute_separate": kwargs.get(
-            "cfg_diff_compute_separate", None
-        ),
-    }
-
-    deprecated_cache_kwargs = {
-        k: v for k, v in deprecated_cache_kwargs.items() if v is not None
-    }
-
-    if deprecated_cache_kwargs:
-        logger.warning(
-            "Manually settup DBCache context without BasicCacheConfig is "
-            "deprecated and will be removed in the future, please use "
-            "`cache_config` parameter instead!"
-        )
-        if cache_config is not None:
-            cache_config.update(**deprecated_cache_kwargs)
-        else:
-            cache_config = BasicCacheConfig(**deprecated_cache_kwargs)
-
-    if cache_config is not None:
-        cache_context_kwargs["cache_config"] = cache_config
-
-    # WARNING: Deprecated taylorseer params. These parameters are now retained
-    # for backward compatibility but will be removed in the future.
-    if (
-        kwargs.get("enable_taylorseer", None) is not None
-        or kwargs.get("enable_encoder_taylorseer", None) is not None
-    ):
-        logger.warning(
-            "Manually settup TaylorSeer calibrator without TaylorSeerCalibratorConfig is "
-            "deprecated and will be removed in the future, please use "
-            "`calibrator_config` parameter instead!"
-        )
-        from cache_dit.cache_factory.cache_contexts.calibrators import (
-            TaylorSeerCalibratorConfig,
-        )
-
-        calibrator_config = TaylorSeerCalibratorConfig(
-            enable_calibrator=kwargs.get("enable_taylorseer"),
-            enable_encoder_calibrator=kwargs.get("enable_encoder_taylorseer"),
-            calibrator_cache_type=kwargs.get(
-                "taylorseer_cache_type", "residual"
-            ),
-            taylorseer_order=kwargs.get("taylorseer_order", 1),
-        )
-
-    if calibrator_config is not None:
-        cache_context_kwargs["calibrator_config"] = calibrator_config
-
-    if params_modifiers is not None:
-        cache_context_kwargs["params_modifiers"] = params_modifiers
-
-    if isinstance(pipe_or_adapter, (DiffusionPipeline, BlockAdapter)):
-        return CachedAdapter.apply(
-            pipe_or_adapter,
-            **cache_context_kwargs,
-        )
-    else:
-        raise ValueError(
-            f"type: {type(pipe_or_adapter)} is not valid, "
-            "Please pass DiffusionPipeline or BlockAdapter"
-            "for the 1's position param: pipe_or_adapter"
-        )
-
-
-def disable_cache(
-    pipe_or_adapter: Union[
-        DiffusionPipeline,
-        BlockAdapter,
-    ],
-):
-    CachedAdapter.maybe_release_hooks(pipe_or_adapter)
-    logger.warning(
-        f"Cache Acceleration is disabled for: "
-        f"{pipe_or_adapter.__class__.__name__}."
-    )
-
-
-def supported_pipelines(
-    **kwargs,
-) -> Tuple[int, List[str]]:
-    return BlockAdapterRegistry.supported_pipelines(**kwargs)
-
-
-def get_adapter(
-    pipe: DiffusionPipeline | str | Any,
-) -> BlockAdapter:
-    return BlockAdapterRegistry.get_adapter(pipe)
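The removed `enable_cache()` above already warns that loose kwargs such as `Fn_compute_blocks=...` or `enable_taylorseer=True` are deprecated in favor of config objects. Below is a hedged usage sketch of that config-object style: the class and parameter names (`BasicCacheConfig`, `TaylorSeerCalibratorConfig`, `Fn_compute_blocks`, `residual_diff_threshold`, `taylorseer_order`, `summary`, `disable_cache`) all come from the removed code and docstring above, while the top-level re-export of the config classes from `cache_dit` is an assumption.

```python
import cache_dit
from diffusers import DiffusionPipeline

# Hedged sketch of the non-deprecated call style the warnings above point to.
# In 0.3.2 the config classes live under cache_dit.cache_factory.cache_contexts;
# in 1.0.14 the file list places them under cache_dit.caching.cache_contexts.
# Top-level re-exports (cache_dit.BasicCacheConfig, ...) are assumed here.
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")

cache_dit.enable_cache(
    pipe,
    cache_config=cache_dit.BasicCacheConfig(  # assumed re-export
        Fn_compute_blocks=8,                  # the "F8B0" default from the docstring
        Bn_compute_blocks=0,
        residual_diff_threshold=0.08,
        max_warmup_steps=8,
    ),
    calibrator_config=cache_dit.TaylorSeerCalibratorConfig(  # assumed re-export
        enable_calibrator=True,
        taylorseer_order=1,
    ),
)

output = pipe("a cup of coffee on a wooden table")  # call the pipe as normal
stats = cache_dit.summary(pipe)                     # summary() as in the docstring example
cache_dit.disable_cache(pipe)                       # restore the original pipe
```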
cache_dit/cache_factory/patch_functors/__init__.py (removed in 1.0.14)
@@ -1,12 +0,0 @@
-from cache_dit.cache_factory.patch_functors.functor_base import PatchFunctor
-from cache_dit.cache_factory.patch_functors.functor_dit import DiTPatchFunctor
-from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
-from cache_dit.cache_factory.patch_functors.functor_chroma import (
-    ChromaPatchFunctor,
-)
-from cache_dit.cache_factory.patch_functors.functor_hidream import (
-    HiDreamPatchFunctor,
-)
-from cache_dit.cache_factory.patch_functors.functor_hunyuan_dit import (
-    HunyuanDiTPatchFunctor,
-)