cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published in the public registry.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
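The listing above reflects the main layout change in this release: `cache_dit/cache_factory/*` becomes `cache_dit/caching/*`, `cache_dit/custom_ops/*` becomes `cache_dit/kernels/*`, and new `parallelism` and `quantize/backends` packages are added. The sketch below illustrates what the import-path migration implied by that rename might look like; the re-exported names are assumptions taken from the imports in the added `cache_dit/summary.py` shown next, not a verified public API.

```python
# Hypothetical before/after import sketch for the cache_factory -> caching rename.
# The symbols below are assumed from the imports used in cache_dit/summary.py;
# verify against the installed 1.0.14 package before relying on them.

# cache-dit 1.0.3 layout (old path):
# from cache_dit.cache_factory import CacheType, BlockAdapter

# cache-dit 1.0.14 layout (new path, as used by summary.py):
from cache_dit.caching import CacheType, BlockAdapter, BasicCacheConfig
from cache_dit.parallelism import ParallelismConfig

print(CacheType, BlockAdapter, BasicCacheConfig, ParallelismConfig)
```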
cache_dit/summary.py
ADDED
@@ -0,0 +1,593 @@
+import torch
+import dataclasses
+
+import numpy as np
+from pprint import pprint
+from diffusers import DiffusionPipeline
+
+from typing import Dict, Any, List, Union
+from cache_dit.caching import CacheType
+from cache_dit.caching import BlockAdapter
+from cache_dit.caching import BasicCacheConfig
+from cache_dit.caching import CalibratorConfig
+from cache_dit.caching import FakeDiffusionPipeline
+from cache_dit.parallelism import ParallelismConfig
+from cache_dit.logger import init_logger
+
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class CacheStats:
+    cache_options: dict = dataclasses.field(default_factory=dict)
+    # Dual Block Cache
+    cached_steps: list[int] = dataclasses.field(default_factory=list)
+    residual_diffs: dict[str, float] = dataclasses.field(default_factory=dict)
+    cfg_cached_steps: list[int] = dataclasses.field(default_factory=list)
+    cfg_residual_diffs: dict[str, float] = dataclasses.field(
+        default_factory=dict
+    )
+    # Dynamic Block Prune
+    pruned_steps: list[int] = dataclasses.field(default_factory=list)
+    pruned_blocks: list[int] = dataclasses.field(default_factory=list)
+    actual_blocks: list[int] = dataclasses.field(default_factory=list)
+    pruned_ratio: float = None
+    cfg_pruned_steps: list[int] = dataclasses.field(default_factory=list)
+    cfg_pruned_blocks: list[int] = dataclasses.field(default_factory=list)
+    cfg_actual_blocks: list[int] = dataclasses.field(default_factory=list)
+    cfg_pruned_ratio: float = None
+    # Parallelism Stats
+    parallelism_config: ParallelismConfig = None
+
+
+def summary(
+    adapter_or_others: Union[
+        BlockAdapter,
+        DiffusionPipeline,
+        FakeDiffusionPipeline,
+        torch.nn.Module,
+    ],
+    details: bool = False,
+    logging: bool = True,
+    **kwargs,
+) -> List[CacheStats]:
+    if adapter_or_others is None:
+        return [CacheStats()]
+
+    if isinstance(adapter_or_others, FakeDiffusionPipeline):
+        raise ValueError(
+            "Please pass DiffusionPipeline, BlockAdapter or transfomer, "
+            "not FakeDiffusionPipeline."
+        )
+
+    if not isinstance(adapter_or_others, BlockAdapter):
+        if not isinstance(adapter_or_others, DiffusionPipeline):
+            transformer = adapter_or_others  # transformer-only
+            transformer_2 = None
+        else:
+            transformer = adapter_or_others.transformer
+            transformer_2 = None  # Only for Wan2.2
+            if hasattr(adapter_or_others, "transformer_2"):
+                transformer_2 = adapter_or_others.transformer_2
+
+        if all(
+            (
+                not BlockAdapter.is_cached(transformer),
+                not BlockAdapter.is_parallelized(transformer),
+            )
+        ):
+            return [CacheStats()]
+
+        blocks_stats: List[CacheStats] = []
+        if BlockAdapter.is_cached(transformer):
+            for blocks in BlockAdapter.find_blocks(transformer):
+                blocks_stats.append(
+                    _summary(
+                        blocks,
+                        details=details,
+                        logging=logging,
+                        **kwargs,
+                    )
+                )
+
+        if transformer_2 is not None and BlockAdapter.is_cached(transformer_2):
+            for blocks in BlockAdapter.find_blocks(transformer_2):
+                blocks_stats.append(
+                    _summary(
+                        blocks,
+                        details=details,
+                        logging=logging,
+                        **kwargs,
+                    )
+                )
+
+        blocks_stats.append(
+            _summary(
+                transformer,
+                details=details,
+                logging=logging,
+                **kwargs,
+            )
+        )
+        if transformer_2 is not None:
+            blocks_stats.append(
+                _summary(
+                    transformer_2,
+                    details=details,
+                    logging=logging,
+                    **kwargs,
+                )
+            )
+
+        blocks_stats = [
+            stats
+            for stats in blocks_stats
+            if (stats.cache_options or stats.parallelism_config)
+        ]
+
+        return blocks_stats if len(blocks_stats) else [CacheStats()]
+
+    adapter = adapter_or_others
+    if not BlockAdapter.check_block_adapter(adapter):
+        return [CacheStats()]
+
+    blocks_stats = []
+    flatten_blocks = BlockAdapter.flatten(adapter.blocks)
+    for blocks in flatten_blocks:
+        blocks_stats.append(
+            _summary(
+                blocks,
+                details=details,
+                logging=logging,
+                **kwargs,
+            )
+        )
+
+    blocks_stats = [stats for stats in blocks_stats if stats.cache_options]
+
+    return blocks_stats if len(blocks_stats) else [CacheStats()]
+
+
+def strify(
+    adapter_or_others: Union[
+        BlockAdapter,
+        DiffusionPipeline,
+        FakeDiffusionPipeline,
+        torch.nn.Module,
+        CacheStats,
+        List[CacheStats],
+        Dict[str, Any],
+    ],
+) -> str:
+    if isinstance(adapter_or_others, FakeDiffusionPipeline):
+        raise ValueError(
+            "Please pass DiffusionPipeline, BlockAdapter or transfomer, "
+            "not FakeDiffusionPipeline."
+        )
+
+    parallelism_config: ParallelismConfig = None
+    if isinstance(adapter_or_others, BlockAdapter):
+        stats = summary(adapter_or_others, logging=False)[-1]
+        cache_options = stats.cache_options
+        cached_steps = len(stats.cached_steps)
+    elif isinstance(adapter_or_others, DiffusionPipeline):
+        stats = summary(adapter_or_others, logging=False)[-1]
+        cache_options = stats.cache_options
+        cached_steps = len(stats.cached_steps)
+    elif isinstance(adapter_or_others, torch.nn.Module):
+        stats = summary(adapter_or_others, logging=False)[-1]
+        cache_options = stats.cache_options
+        cached_steps = len(stats.cached_steps)
+    elif isinstance(adapter_or_others, CacheStats):
+        stats = adapter_or_others
+        cache_options = stats.cache_options
+        cached_steps = len(stats.cached_steps)
+    elif isinstance(adapter_or_others, list):
+        stats = adapter_or_others[0]
+        cache_options = stats.cache_options
+        cached_steps = len(stats.cached_steps)
+    elif isinstance(adapter_or_others, dict):
+
+        # Assume context_kwargs
+        cache_options = adapter_or_others
+        cached_steps = None
+        cache_type = cache_options.get("cache_type", CacheType.NONE)
+        stats = None
+        parallelism_config = cache_options.get("parallelism_config", None)
+
+        if cache_type == CacheType.NONE:
+            return "NONE"
+    else:
+        raise ValueError(
+            "Please set pipe_or_stats param as one of: "
+            "DiffusionPipeline | CacheStats | Dict[str, Any] | List[CacheStats]"
+            " | BlockAdapter | Transformer"
+        )
+
+    if stats is not None:
+        parallelism_config = stats.parallelism_config
+
+    if not cache_options and parallelism_config is None:
+        return "NONE"
+
+    def cache_str():
+        cache_config: BasicCacheConfig = cache_options.get("cache_config", None)
+        if cache_config is not None:
+            if cache_config.cache_type == CacheType.NONE:
+                return "NONE"
+            elif cache_config.cache_type == CacheType.DBCache:
+                return cache_config.strify()
+            elif cache_config.cache_type == CacheType.DBPrune:
+                pruned_ratio = stats.pruned_ratio
+                if pruned_ratio is not None:
+                    return f"{cache_config.strify()}_P{round(pruned_ratio * 100, 2)}"
+                return cache_config.strify()
+        return "NONE"
+
+    def calibrator_str():
+        calibrator_config: CalibratorConfig = cache_options.get(
+            "calibrator_config", None
+        )
+        if calibrator_config is not None:
+            return calibrator_config.strify()
+        return "T0O0"
+
+    def parallelism_str():
+        if parallelism_config is not None:
+            return f"_{parallelism_config.strify()}"
+        return ""
+
+    cache_type_str = f"{cache_str()}"
+    if cache_type_str != "NONE":
+        cache_type_str += f"_{calibrator_str()}"
+    cache_type_str += f"{parallelism_str()}"
+
+    if cached_steps:
+        cache_type_str += f"_S{cached_steps}"
+
+    return cache_type_str
+
+
+def _summary(
+    pipe_or_module: Union[
+        DiffusionPipeline,
+        torch.nn.Module,
+    ],
+    details: bool = False,
+    logging: bool = True,
+    **kwargs,
+) -> CacheStats:
+    cache_stats = CacheStats()
+
+    # Get stats from transformer
+    if not isinstance(pipe_or_module, torch.nn.Module):
+        assert hasattr(pipe_or_module, "transformer")
+        module = pipe_or_module.transformer
+        cls_name = module.__class__.__name__
+    else:
+        module = pipe_or_module
+
+    cls_name = module.__class__.__name__
+    if isinstance(module, torch.nn.ModuleList):
+        cls_name = module[0].__class__.__name__
+
+    if hasattr(module, "_context_kwargs"):
+        cache_options = module._context_kwargs
+        cache_stats.cache_options = cache_options
+        if logging:
+            print(f"\n🤗Context Options: {cls_name}\n\n{cache_options}")
+    else:
+        if logging:
+            logger.warning(f"Can't find Context Options for: {cls_name}")
+
+    if hasattr(module, "_parallelism_config"):
+        parallelism_config: ParallelismConfig = module._parallelism_config
+        cache_stats.parallelism_config = parallelism_config
+        if logging:
+            print(
+                f"\n🤖Parallelism Config: {cls_name}\n\n{parallelism_config.strify(True)}"
+            )
+    else:
+        if logging:
+            logger.warning(f"Can't find Parallelism Config for: {cls_name}")
+
+    if hasattr(module, "_cached_steps"):
+        cached_steps: list[int] = module._cached_steps
+        residual_diffs: dict[str, list | float] = dict(module._residual_diffs)
+
+        if hasattr(module, "_pruned_steps"):
+            pruned_steps: list[int] = module._pruned_steps
+            pruned_blocks: list[int] = module._pruned_blocks
+            actual_blocks: list[int] = module._actual_blocks
+            pruned_ratio: float = module._pruned_ratio
+        else:
+            pruned_steps = []
+            pruned_blocks = []
+            actual_blocks = []
+            pruned_ratio = None
+
+        cache_stats.cached_steps = cached_steps
+        cache_stats.residual_diffs = residual_diffs
+
+        cache_stats.pruned_steps = pruned_steps
+        cache_stats.pruned_blocks = pruned_blocks
+        cache_stats.actual_blocks = actual_blocks
+        cache_stats.pruned_ratio = pruned_ratio
+
+        if residual_diffs and logging:
+            diffs_values = list(residual_diffs.values())
+            if isinstance(diffs_values[0], list):
+                diffs_values = [v for sublist in diffs_values for v in sublist]
+            qmin = np.min(diffs_values)
+            q0 = np.percentile(diffs_values, 0)
+            q1 = np.percentile(diffs_values, 25)
+            q2 = np.percentile(diffs_values, 50)
+            q3 = np.percentile(diffs_values, 75)
+            q4 = np.percentile(diffs_values, 95)
+            qmax = np.max(diffs_values)
+
+            if pruned_ratio is not None:
+                print(
+                    f"\n⚡️Pruned Blocks and Residual Diffs Statistics: {cls_name}\n"
+                )
+
+                print(
+                    "| Pruned Blocks | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
+                )
+                print(
+                    "|---------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
+                )
+                print(
+                    f"| {sum(pruned_blocks):<13} | {round(q0, 3):<9} | {round(q1, 3):<9} "
+                    f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
+                    f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
+                )
+                print("")
+            else:
+                print(
+                    f"\n⚡️Cache Steps and Residual Diffs Statistics: {cls_name}\n"
+                )
+
+                print(
+                    "| Cache Steps | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
+                )
+                print(
+                    "|-------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
+                )
+                print(
+                    f"| {len(cached_steps):<11} | {round(q0, 3):<9} | {round(q1, 3):<9} "
+                    f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
+                    f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
+                )
+                print("")
+
+            if pruned_ratio is not None:
+                print(
+                    f"Dynamic Block Prune Ratio: {round(pruned_ratio * 100, 2)}% ({sum(pruned_blocks)}/{sum(actual_blocks)})\n"
+                )
+
+            if details:
+                if pruned_ratio is not None:
+                    print(
+                        f"📚Pruned Blocks and Residual Diffs Details: {cls_name}\n"
+                    )
+                    pprint(
+                        f"Pruned Blocks: {len(pruned_blocks)}, {pruned_blocks}",
+                    )
+                    pprint(
+                        f"Actual Blocks: {len(actual_blocks)}, {actual_blocks}",
+                    )
+                    pprint(
+                        f"Residual Diffs: {len(residual_diffs)}, {residual_diffs}",
+                        compact=True,
+                    )
+                else:
+                    print(
+                        f"📚Cache Steps and Residual Diffs Details: {cls_name}\n"
+                    )
+                    pprint(
+                        f"Cache Steps: {len(cached_steps)}, {cached_steps}",
+                    )
+                    pprint(
+                        f"Residual Diffs: {len(residual_diffs)}, {residual_diffs}",
+                        compact=True,
+                    )
+
+    if hasattr(module, "_cfg_cached_steps"):
+        cfg_cached_steps: list[int] = module._cfg_cached_steps
+        cfg_residual_diffs: dict[str, list | float] = dict(
+            module._cfg_residual_diffs
+        )
+
+        if hasattr(module, "_cfg_pruned_steps"):
+            cfg_pruned_steps: list[int] = module._cfg_pruned_steps
+            cfg_pruned_blocks: list[int] = module._cfg_pruned_blocks
+            cfg_actual_blocks: list[int] = module._cfg_actual_blocks
+            cfg_pruned_ratio: float = module._cfg_pruned_ratio
+        else:
+            cfg_pruned_steps = []
+            cfg_pruned_blocks = []
+            cfg_actual_blocks = []
+            cfg_pruned_ratio = None
+
+        cache_stats.cfg_cached_steps = cfg_cached_steps
+        cache_stats.cfg_residual_diffs = cfg_residual_diffs
+        cache_stats.cfg_pruned_steps = cfg_pruned_steps
+        cache_stats.cfg_pruned_blocks = cfg_pruned_blocks
+        cache_stats.cfg_actual_blocks = cfg_actual_blocks
+        cache_stats.cfg_pruned_ratio = cfg_pruned_ratio
+
+        if cfg_residual_diffs and logging:
+            cfg_diffs_values = list(cfg_residual_diffs.values())
+            if isinstance(cfg_diffs_values[0], list):
+                cfg_diffs_values = [
+                    v for sublist in cfg_diffs_values for v in sublist
+                ]
+            qmin = np.min(cfg_diffs_values)
+            q0 = np.percentile(cfg_diffs_values, 0)
+            q1 = np.percentile(cfg_diffs_values, 25)
+            q2 = np.percentile(cfg_diffs_values, 50)
+            q3 = np.percentile(cfg_diffs_values, 75)
+            q4 = np.percentile(cfg_diffs_values, 95)
+            qmax = np.max(cfg_diffs_values)
+
+            if cfg_pruned_ratio is not None:
+                print(
+                    f"\n⚡️CFG Pruned Blocks and Residual Diffs Statistics: {cls_name}\n"
+                )
+
+                print(
+                    "| CFG Pruned Blocks | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
+                )
+                print(
+                    "|-------------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
+                )
+                print(
+                    f"| {sum(cfg_pruned_blocks):<18} | {round(q0, 3):<9} | {round(q1, 3):<9} "
+                    f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
+                    f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
+                )
+                print("")
+            else:
+                print(
+                    f"\n⚡️CFG Cache Steps and Residual Diffs Statistics: {cls_name}\n"
+                )
+
+                print(
+                    "| CFG Cache Steps | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
+                )
+                print(
+                    "|-----------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
+                )
+                print(
+                    f"| {len(cfg_cached_steps):<15} | {round(q0, 3):<9} | {round(q1, 3):<9} "
+                    f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
+                    f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
+                )
+                print("")
+
+            if cfg_pruned_ratio is not None:
+                print(
+                    f"CFG Dynamic Block Prune Ratio: {round(cfg_pruned_ratio * 100, 2)}% ({sum(cfg_pruned_blocks)}/{sum(cfg_actual_blocks)})\n"
+                )
+
+            if details:
+                if cfg_pruned_ratio is not None:
+                    print(
+                        f"📚CFG Pruned Blocks and Residual Diffs Details: {cls_name}\n"
+                    )
+                    pprint(
+                        f"CFG Pruned Blocks: {len(cfg_pruned_blocks)}, {cfg_pruned_blocks}",
+                    )
+                    pprint(
+                        f"CFG Actual Blocks: {len(cfg_actual_blocks)}, {cfg_actual_blocks}",
+                    )
+                    pprint(
+                        f"CFG Residual Diffs: {len(cfg_residual_diffs)}, {cfg_residual_diffs}",
+                        compact=True,
+                    )
+                else:
+                    print(
+                        f"📚CFG Cache Steps and Residual Diffs Details: {cls_name}\n"
+                    )
+                    pprint(
+                        f"CFG Cache Steps: {len(cfg_cached_steps)}, {cfg_cached_steps}",
+                    )
+                    pprint(
+                        f"CFG Residual Diffs: {len(cfg_residual_diffs)}, {cfg_residual_diffs}",
+                        compact=True,
+                    )
+
+    return cache_stats
+
+
+def supported_matrix() -> str | None:
+    try:
+        from cache_dit.caching.block_adapters.block_registers import (
+            BlockAdapterRegistry,
+        )
+
+        _pipelines_supported_cache = BlockAdapterRegistry.supported_pipelines()[
+            1
+        ]
+        _pipelines_supported_cache += [
+            "LongCatVideo",  # not in diffusers, but supported
+        ]
+        from cache_dit.parallelism.backends.native_diffusers import (
+            ContextParallelismPlannerRegister,
+        )
+
+        _pipelines_supported_context_parallelism = (
+            ContextParallelismPlannerRegister.supported_planners()[1]
+        )
+        from cache_dit.parallelism.backends.native_pytorch import (
+            TensorParallelismPlannerRegister,
+        )
+
+        _pipelines_supported_tensor_parallelism = (
+            TensorParallelismPlannerRegister.supported_planners()[1]
+        )
+        # Add some special aliases since cp/tp planners use the name shortcut
+        # of Transformer only.
+        _pipelines_supported_context_parallelism += [
+            "Wan",
+            "LTX",
+            "VisualCloze",
+        ]
+        _pipelines_supported_tensor_parallelism += [
+            "Wan",
+            "VisualCloze",
+        ]
+
+        # Generate the supported matrix, markdown table format
+        matrix_lines: List[str] = []
+        header = "| Model | Cache | CP | TP | Model | Cache | CP | TP |"
+        matrix_lines.append(header)
+        matrix_lines.append("|:---|:---|:---|:---|:---|:---|:---|:---|")
+        half = (len(_pipelines_supported_cache) + 1) // 2
+        link = (
+            "https://github.com/vipshop/cache-dit/blob/main/examples/pipeline"
+        )
+        for i in range(half):
+            pipeline_left = _pipelines_supported_cache[i]
+            cp_support_left = (
+                "✅"
+                if pipeline_left in _pipelines_supported_context_parallelism
+                else "✖️"
+            )
+            tp_support_left = (
+                "✅"
+                if pipeline_left in _pipelines_supported_tensor_parallelism
+                else "✖️"
+            )
+            if i + half < len(_pipelines_supported_cache):
+                pipeline_right = _pipelines_supported_cache[i + half]
+                cp_support_right = (
+                    "✅"
+                    if pipeline_right
+                    in _pipelines_supported_context_parallelism
+                    else "✖️"
+                )
+                tp_support_right = (
+                    "✅"
+                    if pipeline_right in _pipelines_supported_tensor_parallelism
+                    else "✖️"
+                )
+            else:
+                pipeline_right = ""
+                cp_support_right = ""
+                tp_support_right = ""
+            line = (
+                f"| **🎉[{pipeline_left}]({link})** | ✅ | {cp_support_left} | {tp_support_left} "
+                f"| **🎉[{pipeline_right}]({link})** | ✅ | {cp_support_right} | {tp_support_right} | "
+            )
+            matrix_lines.append(line)
+
+        matrix_str = "\n".join(matrix_lines)
+
+        print("\nSupported Cache and Parallelism Matrix:\n")
+        print(matrix_str)
+        return matrix_str
+    except Exception:
+        return None