cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
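The listing reflects a package-level reorganization: `cache_factory` is renamed to `caching`, `custom_ops` to `kernels`, and the parallelism and quantize code gains dedicated backend subpackages. A minimal migration sketch for import paths follows; the new-path names mirror the imports used by `cache_dit/summary.py` below, while the old-path import is an assumption based on the removed `cache_factory/__init__.py` re-exports.

```python
# Hypothetical before/after for the cache_factory -> caching rename shown in the listing above.

# cache-dit 1.0.3 (old layout, assumed re-exports):
# from cache_dit.cache_factory import BlockAdapter, CacheType, BasicCacheConfig

# cache-dit 1.0.14 (new layout, as imported by cache_dit/summary.py):
from cache_dit.caching import BlockAdapter, CacheType, BasicCacheConfig
```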
cache_dit/summary.py ADDED
@@ -0,0 +1,593 @@
+ import torch
+ import dataclasses
+
+ import numpy as np
+ from pprint import pprint
+ from diffusers import DiffusionPipeline
+
+ from typing import Dict, Any, List, Union
+ from cache_dit.caching import CacheType
+ from cache_dit.caching import BlockAdapter
+ from cache_dit.caching import BasicCacheConfig
+ from cache_dit.caching import CalibratorConfig
+ from cache_dit.caching import FakeDiffusionPipeline
+ from cache_dit.parallelism import ParallelismConfig
+ from cache_dit.logger import init_logger
+
+
+ logger = init_logger(__name__)
+
+
+ @dataclasses.dataclass
+ class CacheStats:
+     cache_options: dict = dataclasses.field(default_factory=dict)
+     # Dual Block Cache
+     cached_steps: list[int] = dataclasses.field(default_factory=list)
+     residual_diffs: dict[str, float] = dataclasses.field(default_factory=dict)
+     cfg_cached_steps: list[int] = dataclasses.field(default_factory=list)
+     cfg_residual_diffs: dict[str, float] = dataclasses.field(
+         default_factory=dict
+     )
+     # Dynamic Block Prune
+     pruned_steps: list[int] = dataclasses.field(default_factory=list)
+     pruned_blocks: list[int] = dataclasses.field(default_factory=list)
+     actual_blocks: list[int] = dataclasses.field(default_factory=list)
+     pruned_ratio: float = None
+     cfg_pruned_steps: list[int] = dataclasses.field(default_factory=list)
+     cfg_pruned_blocks: list[int] = dataclasses.field(default_factory=list)
+     cfg_actual_blocks: list[int] = dataclasses.field(default_factory=list)
+     cfg_pruned_ratio: float = None
+     # Parallelism Stats
+     parallelism_config: ParallelismConfig = None
+
+
+ def summary(
+     adapter_or_others: Union[
+         BlockAdapter,
+         DiffusionPipeline,
+         FakeDiffusionPipeline,
+         torch.nn.Module,
+     ],
+     details: bool = False,
+     logging: bool = True,
+     **kwargs,
+ ) -> List[CacheStats]:
+     if adapter_or_others is None:
+         return [CacheStats()]
+
+     if isinstance(adapter_or_others, FakeDiffusionPipeline):
+         raise ValueError(
+             "Please pass DiffusionPipeline, BlockAdapter or transfomer, "
+             "not FakeDiffusionPipeline."
+         )
+
+     if not isinstance(adapter_or_others, BlockAdapter):
+         if not isinstance(adapter_or_others, DiffusionPipeline):
+             transformer = adapter_or_others  # transformer-only
+             transformer_2 = None
+         else:
+             transformer = adapter_or_others.transformer
+             transformer_2 = None  # Only for Wan2.2
+             if hasattr(adapter_or_others, "transformer_2"):
+                 transformer_2 = adapter_or_others.transformer_2
+
+         if all(
+             (
+                 not BlockAdapter.is_cached(transformer),
+                 not BlockAdapter.is_parallelized(transformer),
+             )
+         ):
+             return [CacheStats()]
+
+         blocks_stats: List[CacheStats] = []
+         if BlockAdapter.is_cached(transformer):
+             for blocks in BlockAdapter.find_blocks(transformer):
+                 blocks_stats.append(
+                     _summary(
+                         blocks,
+                         details=details,
+                         logging=logging,
+                         **kwargs,
+                     )
+                 )
+
+         if transformer_2 is not None and BlockAdapter.is_cached(transformer_2):
+             for blocks in BlockAdapter.find_blocks(transformer_2):
+                 blocks_stats.append(
+                     _summary(
+                         blocks,
+                         details=details,
+                         logging=logging,
+                         **kwargs,
+                     )
+                 )
+
+         blocks_stats.append(
+             _summary(
+                 transformer,
+                 details=details,
+                 logging=logging,
+                 **kwargs,
+             )
+         )
+         if transformer_2 is not None:
+             blocks_stats.append(
+                 _summary(
+                     transformer_2,
+                     details=details,
+                     logging=logging,
+                     **kwargs,
+                 )
+             )
+
+         blocks_stats = [
+             stats
+             for stats in blocks_stats
+             if (stats.cache_options or stats.parallelism_config)
+         ]
+
+         return blocks_stats if len(blocks_stats) else [CacheStats()]
+
+     adapter = adapter_or_others
+     if not BlockAdapter.check_block_adapter(adapter):
+         return [CacheStats()]
+
+     blocks_stats = []
+     flatten_blocks = BlockAdapter.flatten(adapter.blocks)
+     for blocks in flatten_blocks:
+         blocks_stats.append(
+             _summary(
+                 blocks,
+                 details=details,
+                 logging=logging,
+                 **kwargs,
+             )
+         )
+
+     blocks_stats = [stats for stats in blocks_stats if stats.cache_options]
+
+     return blocks_stats if len(blocks_stats) else [CacheStats()]
+
+
+ def strify(
+     adapter_or_others: Union[
+         BlockAdapter,
+         DiffusionPipeline,
+         FakeDiffusionPipeline,
+         torch.nn.Module,
+         CacheStats,
+         List[CacheStats],
+         Dict[str, Any],
+     ],
+ ) -> str:
+     if isinstance(adapter_or_others, FakeDiffusionPipeline):
+         raise ValueError(
+             "Please pass DiffusionPipeline, BlockAdapter or transfomer, "
+             "not FakeDiffusionPipeline."
+         )
+
+     parallelism_config: ParallelismConfig = None
+     if isinstance(adapter_or_others, BlockAdapter):
+         stats = summary(adapter_or_others, logging=False)[-1]
+         cache_options = stats.cache_options
+         cached_steps = len(stats.cached_steps)
+     elif isinstance(adapter_or_others, DiffusionPipeline):
+         stats = summary(adapter_or_others, logging=False)[-1]
+         cache_options = stats.cache_options
+         cached_steps = len(stats.cached_steps)
+     elif isinstance(adapter_or_others, torch.nn.Module):
+         stats = summary(adapter_or_others, logging=False)[-1]
+         cache_options = stats.cache_options
+         cached_steps = len(stats.cached_steps)
+     elif isinstance(adapter_or_others, CacheStats):
+         stats = adapter_or_others
+         cache_options = stats.cache_options
+         cached_steps = len(stats.cached_steps)
+     elif isinstance(adapter_or_others, list):
+         stats = adapter_or_others[0]
+         cache_options = stats.cache_options
+         cached_steps = len(stats.cached_steps)
+     elif isinstance(adapter_or_others, dict):
+
+         # Assume context_kwargs
+         cache_options = adapter_or_others
+         cached_steps = None
+         cache_type = cache_options.get("cache_type", CacheType.NONE)
+         stats = None
+         parallelism_config = cache_options.get("parallelism_config", None)
+
+         if cache_type == CacheType.NONE:
+             return "NONE"
+     else:
+         raise ValueError(
+             "Please set pipe_or_stats param as one of: "
+             "DiffusionPipeline | CacheStats | Dict[str, Any] | List[CacheStats]"
+             " | BlockAdapter | Transformer"
+         )
+
+     if stats is not None:
+         parallelism_config = stats.parallelism_config
+
+     if not cache_options and parallelism_config is None:
+         return "NONE"
+
+     def cache_str():
+         cache_config: BasicCacheConfig = cache_options.get("cache_config", None)
+         if cache_config is not None:
+             if cache_config.cache_type == CacheType.NONE:
+                 return "NONE"
+             elif cache_config.cache_type == CacheType.DBCache:
+                 return cache_config.strify()
+             elif cache_config.cache_type == CacheType.DBPrune:
+                 pruned_ratio = stats.pruned_ratio
+                 if pruned_ratio is not None:
+                     return f"{cache_config.strify()}_P{round(pruned_ratio * 100, 2)}"
+                 return cache_config.strify()
+         return "NONE"
+
+     def calibrator_str():
+         calibrator_config: CalibratorConfig = cache_options.get(
+             "calibrator_config", None
+         )
+         if calibrator_config is not None:
+             return calibrator_config.strify()
+         return "T0O0"
+
+     def parallelism_str():
+         if parallelism_config is not None:
+             return f"_{parallelism_config.strify()}"
+         return ""
+
+     cache_type_str = f"{cache_str()}"
+     if cache_type_str != "NONE":
+         cache_type_str += f"_{calibrator_str()}"
+     cache_type_str += f"{parallelism_str()}"
+
+     if cached_steps:
+         cache_type_str += f"_S{cached_steps}"
+
+     return cache_type_str
+
+
+ def _summary(
+     pipe_or_module: Union[
+         DiffusionPipeline,
+         torch.nn.Module,
+     ],
+     details: bool = False,
+     logging: bool = True,
+     **kwargs,
+ ) -> CacheStats:
+     cache_stats = CacheStats()
+
+     # Get stats from transformer
+     if not isinstance(pipe_or_module, torch.nn.Module):
+         assert hasattr(pipe_or_module, "transformer")
+         module = pipe_or_module.transformer
+         cls_name = module.__class__.__name__
+     else:
+         module = pipe_or_module
+
+     cls_name = module.__class__.__name__
+     if isinstance(module, torch.nn.ModuleList):
+         cls_name = module[0].__class__.__name__
+
+     if hasattr(module, "_context_kwargs"):
+         cache_options = module._context_kwargs
+         cache_stats.cache_options = cache_options
+         if logging:
+             print(f"\n🤗Context Options: {cls_name}\n\n{cache_options}")
+     else:
+         if logging:
+             logger.warning(f"Can't find Context Options for: {cls_name}")
+
+     if hasattr(module, "_parallelism_config"):
+         parallelism_config: ParallelismConfig = module._parallelism_config
+         cache_stats.parallelism_config = parallelism_config
+         if logging:
+             print(
+                 f"\n🤖Parallelism Config: {cls_name}\n\n{parallelism_config.strify(True)}"
+             )
+     else:
+         if logging:
+             logger.warning(f"Can't find Parallelism Config for: {cls_name}")
+
+     if hasattr(module, "_cached_steps"):
+         cached_steps: list[int] = module._cached_steps
+         residual_diffs: dict[str, list | float] = dict(module._residual_diffs)
+
+         if hasattr(module, "_pruned_steps"):
+             pruned_steps: list[int] = module._pruned_steps
+             pruned_blocks: list[int] = module._pruned_blocks
+             actual_blocks: list[int] = module._actual_blocks
+             pruned_ratio: float = module._pruned_ratio
+         else:
+             pruned_steps = []
+             pruned_blocks = []
+             actual_blocks = []
+             pruned_ratio = None
+
+         cache_stats.cached_steps = cached_steps
+         cache_stats.residual_diffs = residual_diffs
+
+         cache_stats.pruned_steps = pruned_steps
+         cache_stats.pruned_blocks = pruned_blocks
+         cache_stats.actual_blocks = actual_blocks
+         cache_stats.pruned_ratio = pruned_ratio
+
+         if residual_diffs and logging:
+             diffs_values = list(residual_diffs.values())
+             if isinstance(diffs_values[0], list):
+                 diffs_values = [v for sublist in diffs_values for v in sublist]
+             qmin = np.min(diffs_values)
+             q0 = np.percentile(diffs_values, 0)
+             q1 = np.percentile(diffs_values, 25)
+             q2 = np.percentile(diffs_values, 50)
+             q3 = np.percentile(diffs_values, 75)
+             q4 = np.percentile(diffs_values, 95)
+             qmax = np.max(diffs_values)
+
+             if pruned_ratio is not None:
+                 print(
+                     f"\n⚡️Pruned Blocks and Residual Diffs Statistics: {cls_name}\n"
+                 )
+
+                 print(
+                     "| Pruned Blocks | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
+                 )
+                 print(
+                     "|---------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
+                 )
+                 print(
+                     f"| {sum(pruned_blocks):<13} | {round(q0, 3):<9} | {round(q1, 3):<9} "
+                     f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
+                     f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
+                 )
+                 print("")
+             else:
+                 print(
+                     f"\n⚡️Cache Steps and Residual Diffs Statistics: {cls_name}\n"
+                 )
+
+                 print(
+                     "| Cache Steps | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
+                 )
+                 print(
+                     "|-------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
+                 )
+                 print(
+                     f"| {len(cached_steps):<11} | {round(q0, 3):<9} | {round(q1, 3):<9} "
+                     f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
+                     f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
+                 )
+                 print("")
+
+             if pruned_ratio is not None:
+                 print(
+                     f"Dynamic Block Prune Ratio: {round(pruned_ratio * 100, 2)}% ({sum(pruned_blocks)}/{sum(actual_blocks)})\n"
+                 )
+
+             if details:
+                 if pruned_ratio is not None:
+                     print(
+                         f"📚Pruned Blocks and Residual Diffs Details: {cls_name}\n"
+                     )
+                     pprint(
+                         f"Pruned Blocks: {len(pruned_blocks)}, {pruned_blocks}",
+                     )
+                     pprint(
+                         f"Actual Blocks: {len(actual_blocks)}, {actual_blocks}",
+                     )
+                     pprint(
+                         f"Residual Diffs: {len(residual_diffs)}, {residual_diffs}",
+                         compact=True,
+                     )
+                 else:
+                     print(
+                         f"📚Cache Steps and Residual Diffs Details: {cls_name}\n"
+                     )
+                     pprint(
+                         f"Cache Steps: {len(cached_steps)}, {cached_steps}",
+                     )
+                     pprint(
+                         f"Residual Diffs: {len(residual_diffs)}, {residual_diffs}",
+                         compact=True,
+                     )
+
+     if hasattr(module, "_cfg_cached_steps"):
+         cfg_cached_steps: list[int] = module._cfg_cached_steps
+         cfg_residual_diffs: dict[str, list | float] = dict(
+             module._cfg_residual_diffs
+         )
+
+         if hasattr(module, "_cfg_pruned_steps"):
+             cfg_pruned_steps: list[int] = module._cfg_pruned_steps
+             cfg_pruned_blocks: list[int] = module._cfg_pruned_blocks
+             cfg_actual_blocks: list[int] = module._cfg_actual_blocks
+             cfg_pruned_ratio: float = module._cfg_pruned_ratio
+         else:
+             cfg_pruned_steps = []
+             cfg_pruned_blocks = []
+             cfg_actual_blocks = []
+             cfg_pruned_ratio = None
+
+         cache_stats.cfg_cached_steps = cfg_cached_steps
+         cache_stats.cfg_residual_diffs = cfg_residual_diffs
+         cache_stats.cfg_pruned_steps = cfg_pruned_steps
+         cache_stats.cfg_pruned_blocks = cfg_pruned_blocks
+         cache_stats.cfg_actual_blocks = cfg_actual_blocks
+         cache_stats.cfg_pruned_ratio = cfg_pruned_ratio
+
+         if cfg_residual_diffs and logging:
+             cfg_diffs_values = list(cfg_residual_diffs.values())
+             if isinstance(cfg_diffs_values[0], list):
+                 cfg_diffs_values = [
+                     v for sublist in cfg_diffs_values for v in sublist
+                 ]
+             qmin = np.min(cfg_diffs_values)
+             q0 = np.percentile(cfg_diffs_values, 0)
+             q1 = np.percentile(cfg_diffs_values, 25)
+             q2 = np.percentile(cfg_diffs_values, 50)
+             q3 = np.percentile(cfg_diffs_values, 75)
+             q4 = np.percentile(cfg_diffs_values, 95)
+             qmax = np.max(cfg_diffs_values)
+
+             if cfg_pruned_ratio is not None:
+                 print(
+                     f"\n⚡️CFG Pruned Blocks and Residual Diffs Statistics: {cls_name}\n"
+                 )
+
+                 print(
+                     "| CFG Pruned Blocks | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
+                 )
+                 print(
+                     "|-------------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
+                 )
+                 print(
+                     f"| {sum(cfg_pruned_blocks):<18} | {round(q0, 3):<9} | {round(q1, 3):<9} "
+                     f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
+                     f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
+                 )
+                 print("")
+             else:
+                 print(
+                     f"\n⚡️CFG Cache Steps and Residual Diffs Statistics: {cls_name}\n"
+                 )
+
+                 print(
+                     "| CFG Cache Steps | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
+                 )
+                 print(
+                     "|-----------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
+                 )
+                 print(
+                     f"| {len(cfg_cached_steps):<15} | {round(q0, 3):<9} | {round(q1, 3):<9} "
+                     f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
+                     f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
+                 )
+                 print("")
+
+             if cfg_pruned_ratio is not None:
+                 print(
+                     f"CFG Dynamic Block Prune Ratio: {round(cfg_pruned_ratio * 100, 2)}% ({sum(cfg_pruned_blocks)}/{sum(cfg_actual_blocks)})\n"
+                 )
+
+             if details:
+                 if cfg_pruned_ratio is not None:
+                     print(
+                         f"📚CFG Pruned Blocks and Residual Diffs Details: {cls_name}\n"
+                     )
+                     pprint(
+                         f"CFG Pruned Blocks: {len(cfg_pruned_blocks)}, {cfg_pruned_blocks}",
+                     )
+                     pprint(
+                         f"CFG Actual Blocks: {len(cfg_actual_blocks)}, {cfg_actual_blocks}",
+                     )
+                     pprint(
+                         f"CFG Residual Diffs: {len(cfg_residual_diffs)}, {cfg_residual_diffs}",
+                         compact=True,
+                     )
+                 else:
+                     print(
+                         f"📚CFG Cache Steps and Residual Diffs Details: {cls_name}\n"
+                     )
+                     pprint(
+                         f"CFG Cache Steps: {len(cfg_cached_steps)}, {cfg_cached_steps}",
+                     )
+                     pprint(
+                         f"CFG Residual Diffs: {len(cfg_residual_diffs)}, {cfg_residual_diffs}",
+                         compact=True,
+                     )
+
+     return cache_stats
+
+
+ def supported_matrix() -> str | None:
+     try:
+         from cache_dit.caching.block_adapters.block_registers import (
+             BlockAdapterRegistry,
+         )
+
+         _pipelines_supported_cache = BlockAdapterRegistry.supported_pipelines()[
+             1
+         ]
+         _pipelines_supported_cache += [
+             "LongCatVideo",  # not in diffusers, but supported
+         ]
+         from cache_dit.parallelism.backends.native_diffusers import (
+             ContextParallelismPlannerRegister,
+         )
+
+         _pipelines_supported_context_parallelism = (
+             ContextParallelismPlannerRegister.supported_planners()[1]
+         )
+         from cache_dit.parallelism.backends.native_pytorch import (
+             TensorParallelismPlannerRegister,
+         )
+
+         _pipelines_supported_tensor_parallelism = (
+             TensorParallelismPlannerRegister.supported_planners()[1]
+         )
+         # Add some special aliases since cp/tp planners use the name shortcut
+         # of Transformer only.
+         _pipelines_supported_context_parallelism += [
+             "Wan",
+             "LTX",
+             "VisualCloze",
+         ]
+         _pipelines_supported_tensor_parallelism += [
+             "Wan",
+             "VisualCloze",
+         ]
+
+         # Generate the supported matrix, markdown table format
+         matrix_lines: List[str] = []
+         header = "| Model | Cache | CP | TP | Model | Cache | CP | TP |"
+         matrix_lines.append(header)
+         matrix_lines.append("|:---|:---|:---|:---|:---|:---|:---|:---|")
+         half = (len(_pipelines_supported_cache) + 1) // 2
+         link = (
+             "https://github.com/vipshop/cache-dit/blob/main/examples/pipeline"
+         )
+         for i in range(half):
+             pipeline_left = _pipelines_supported_cache[i]
+             cp_support_left = (
+                 "✅"
+                 if pipeline_left in _pipelines_supported_context_parallelism
+                 else "✖️"
+             )
+             tp_support_left = (
+                 "✅"
+                 if pipeline_left in _pipelines_supported_tensor_parallelism
+                 else "✖️"
+             )
+             if i + half < len(_pipelines_supported_cache):
+                 pipeline_right = _pipelines_supported_cache[i + half]
+                 cp_support_right = (
+                     "✅"
+                     if pipeline_right
+                     in _pipelines_supported_context_parallelism
+                     else "✖️"
+                 )
+                 tp_support_right = (
+                     "✅"
+                     if pipeline_right in _pipelines_supported_tensor_parallelism
+                     else "✖️"
+                 )
+             else:
+                 pipeline_right = ""
+                 cp_support_right = ""
+                 tp_support_right = ""
+             line = (
+                 f"| **🎉[{pipeline_left}]({link})** | ✅ | {cp_support_left} | {tp_support_left} "
+                 f"| **🎉[{pipeline_right}]({link})** | ✅ | {cp_support_right} | {tp_support_right} | "
+             )
+             matrix_lines.append(line)
+
+         matrix_str = "\n".join(matrix_lines)
+
+         print("\nSupported Cache and Parallelism Matrix:\n")
+         print(matrix_str)
+         return matrix_str
+     except Exception:
+         return None
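A minimal usage sketch for the new module. Assumptions: the helpers are imported directly from `cache_dit.summary` (whether they are also re-exported from the top-level `cache_dit` package is not shown in this diff), the pipeline id is a placeholder, and the step that actually enables caching or parallelism is elided because that API lives outside this file.

```python
import torch
from diffusers import DiffusionPipeline

from cache_dit.summary import CacheStats, strify, summary, supported_matrix

# Placeholder pipeline; any Diffusers pipeline exposing a `transformer` works.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# ... enable caching and/or parallelism on `pipe`, then run inference ...

# Per-blocks CacheStats; markdown tables are printed when logging=True.
stats: list[CacheStats] = summary(pipe, details=False, logging=True)

# Compact tag for log/file names, e.g. "NONE" when nothing is enabled.
print(strify(pipe))

# Markdown matrix of pipelines with Cache / CP / TP support (None on failure).
supported_matrix()
```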