cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff compares publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (108)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/offload_utils.py +115 -0
  11. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  12. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  13. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  14. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  15. cache_dit/caching/cache_contexts/__init__.py +28 -0
  16. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
  18. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  21. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  22. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  23. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  24. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  25. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  26. cache_dit/caching/cache_interface.py +358 -0
  27. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  28. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  29. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  30. cache_dit/caching/patch_functors/__init__.py +15 -0
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  36. cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
  37. cache_dit/caching/utils.py +68 -0
  38. cache_dit/metrics/__init__.py +11 -0
  39. cache_dit/metrics/metrics.py +3 -0
  40. cache_dit/parallelism/__init__.py +3 -0
  41. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  57. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  58. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  59. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  60. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  61. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  62. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  68. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  69. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  70. cache_dit/parallelism/parallel_backend.py +26 -0
  71. cache_dit/parallelism/parallel_config.py +88 -0
  72. cache_dit/parallelism/parallel_interface.py +77 -0
  73. cache_dit/quantize/__init__.py +7 -0
  74. cache_dit/quantize/backends/__init__.py +1 -0
  75. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  76. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  77. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
  78. cache_dit/quantize/quantize_backend.py +0 -0
  79. cache_dit/quantize/quantize_config.py +0 -0
  80. cache_dit/quantize/quantize_interface.py +3 -16
  81. cache_dit/summary.py +593 -0
  82. cache_dit/utils.py +46 -290
  83. cache_dit-1.0.14.dist-info/METADATA +301 -0
  84. cache_dit-1.0.14.dist-info/RECORD +102 -0
  85. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  86. cache_dit/cache_factory/__init__.py +0 -28
  87. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  88. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  89. cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
  90. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  91. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
  92. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
  93. cache_dit/cache_factory/cache_blocks/utils.py +0 -41
  94. cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
  95. cache_dit/cache_factory/cache_interface.py +0 -217
  96. cache_dit/cache_factory/patch_functors/__init__.py +0 -12
  97. cache_dit/cache_factory/utils.py +0 -57
  98. cache_dit-0.3.2.dist-info/METADATA +0 -753
  99. cache_dit-0.3.2.dist-info/RECORD +0 -56
  100. cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
  101. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  102. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  103. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  104. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  105. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  106. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  107. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  108. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
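Most of the entries above are renames from cache_dit/cache_factory/… to cache_dit/caching/… (with custom_ops renamed to kernels), so the most visible downstream impact of 1.0.14 is likely import paths. A minimal before/after sketch, assuming the same public symbols are re-exported from the renamed package (the actual re-exports should be confirmed against cache_dit/caching/__init__.py):

    # cache-dit 0.3.2
    from cache_dit.cache_factory import BlockAdapter, CacheType

    # cache-dit 1.0.14 -- assumed equivalent after the cache_factory -> caching rename
    from cache_dit.caching import BlockAdapter, CacheType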
cache_dit/utils.py CHANGED
@@ -1,18 +1,10 @@
+import gc
+import time
 import torch
-import dataclasses
 import diffusers
 import builtins as __builtin__
 import contextlib
 
-import numpy as np
-from pprint import pprint
-from diffusers import DiffusionPipeline
-
-from typing import Dict, Any, List, Union
-from cache_dit.cache_factory import CacheType
-from cache_dit.cache_factory import BlockAdapter
-from cache_dit.cache_factory import BasicCacheConfig
-from cache_dit.cache_factory import CalibratorConfig
 from cache_dit.logger import init_logger
 
 
@@ -36,290 +28,54 @@ def is_diffusers_at_least_0_3_5() -> bool:
     return diffusers.__version__ >= "0.35.0"
 
 
-@dataclasses.dataclass
-class CacheStats:
-    cache_options: dict = dataclasses.field(default_factory=dict)
-    cached_steps: list[int] = dataclasses.field(default_factory=list)
-    residual_diffs: dict[str, float] = dataclasses.field(default_factory=dict)
-    cfg_cached_steps: list[int] = dataclasses.field(default_factory=list)
-    cfg_residual_diffs: dict[str, float] = dataclasses.field(
-        default_factory=dict
-    )
-
+@torch.compiler.disable
+def maybe_empty_cache():
+    try:
+        time.sleep(1)
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        time.sleep(1)
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+    except Exception:
+        pass
 
-def summary(
-    adapter_or_others: Union[
-        BlockAdapter,
-        DiffusionPipeline,
-        torch.nn.Module,
-    ],
-    details: bool = False,
-    logging: bool = True,
-    **kwargs,
-) -> List[CacheStats]:
-    if adapter_or_others is None:
-        return [CacheStats()]
 
-    if not isinstance(adapter_or_others, BlockAdapter):
-        if not isinstance(adapter_or_others, DiffusionPipeline):
-            transformer = adapter_or_others
-            transformer_2 = None
+@torch.compiler.disable
+def print_tensor(
+    x: torch.Tensor,
+    name: str,
+    dim: int = 1,
+    no_dist_shape: bool = True,
+    disable: bool = False,
+):
+    if disable:
+        return
+
+    x = x.contiguous()
+    if torch.distributed.is_initialized():
+        # all gather hidden_states and check values mean
+        gather_x = [
+            torch.zeros_like(x)
+            for _ in range(torch.distributed.get_world_size())
+        ]
+        torch.distributed.all_gather(gather_x, x)
+        gather_x = torch.cat(gather_x, dim=dim)
+
+        if not no_dist_shape:
+            x_shape = gather_x.shape
         else:
-            transformer = adapter_or_others.transformer
-            transformer_2 = None
-            if hasattr(adapter_or_others, "transformer_2"):
-                transformer_2 = adapter_or_others.transformer_2
-
-        if not BlockAdapter.is_cached(transformer):
-            return [CacheStats()]
-
-        blocks_stats: List[CacheStats] = []
-        for blocks in BlockAdapter.find_blocks(transformer):
-            blocks_stats.append(
-                _summary(
-                    blocks,
-                    details=details,
-                    logging=logging,
-                    **kwargs,
-                )
-            )
-
-        if transformer_2 is not None:
-            for blocks in BlockAdapter.find_blocks(transformer_2):
-                blocks_stats.append(
-                    _summary(
-                        blocks,
-                        details=details,
-                        logging=logging,
-                        **kwargs,
-                    )
-                )
+            x_shape = x.shape
 
-        blocks_stats.append(
-            _summary(
-                transformer,
-                details=details,
-                logging=logging,
-                **kwargs,
-            )
-        )
-        if transformer_2 is not None:
-            blocks_stats.append(
-                _summary(
-                    transformer_2,
-                    details=details,
-                    logging=logging,
-                    **kwargs,
-                )
-            )
-
-        blocks_stats = [stats for stats in blocks_stats if stats.cache_options]
-
-        return blocks_stats if len(blocks_stats) else [CacheStats()]
-
-    adapter = adapter_or_others
-    if not BlockAdapter.check_block_adapter(adapter):
-        return [CacheStats()]
-
-    blocks_stats = []
-    flatten_blocks = BlockAdapter.flatten(adapter.blocks)
-    for blocks in flatten_blocks:
-        blocks_stats.append(
-            _summary(
-                blocks,
-                details=details,
-                logging=logging,
-                **kwargs,
+        if torch.distributed.get_rank() == 0:
+            print(
+                f"{name}, mean: {gather_x.float().mean().item()}, "
+                f"std: {gather_x.float().std().item()}, shape: {x_shape}"
             )
-        )
-
-    blocks_stats = [stats for stats in blocks_stats if stats.cache_options]
-
-    return blocks_stats if len(blocks_stats) else [CacheStats()]
-
-
-def strify(
-    adapter_or_others: Union[
-        BlockAdapter,
-        DiffusionPipeline,
-        CacheStats,
-        List[CacheStats],
-        Dict[str, Any],
-    ],
-) -> str:
-    if isinstance(adapter_or_others, BlockAdapter):
-        stats = summary(adapter_or_others, logging=False)[-1]
-        cache_options = stats.cache_options
-        cached_steps = len(stats.cached_steps)
-    elif isinstance(adapter_or_others, DiffusionPipeline):
-        stats = summary(adapter_or_others, logging=False)[-1]
-        cache_options = stats.cache_options
-        cached_steps = len(stats.cached_steps)
-    elif isinstance(adapter_or_others, CacheStats):
-        stats = adapter_or_others
-        cache_options = stats.cache_options
-        cached_steps = len(stats.cached_steps)
-    elif isinstance(adapter_or_others, list):
-        stats = adapter_or_others[0]
-        cache_options = stats.cache_options
-        cached_steps = len(stats.cached_steps)
-    elif isinstance(adapter_or_others, dict):
-
-        # Assume cache_context_kwargs
-        cache_options = adapter_or_others
-        cached_steps = None
-        cache_type = cache_options.get("cache_type", CacheType.NONE)
-
-        if cache_type == CacheType.NONE:
-            return "NONE"
     else:
-        raise ValueError(
-            "Please set pipe_or_stats param as one of: "
-            "DiffusionPipeline | CacheStats | Dict[str, Any]"
-        )
-
-    if not cache_options:
-        return "NONE"
-
-    def basic_cache_str():
-        cache_config: BasicCacheConfig = cache_options.get("cache_config", None)
-        if cache_config is not None:
-            return cache_config.strify()
-        return "NONE"
-
-    def calibrator_str():
-        calibrator_config: CalibratorConfig = cache_options.get(
-            "calibrator_config", None
+        print(
+            f"{name}, mean: {x.float().mean().item()}, "
+            f"std: {x.float().std().item()}, shape: {x.shape}"
         )
-        if calibrator_config is not None:
-            return calibrator_config.strify()
-        return "T0O0"
-
-    cache_type_str = f"{basic_cache_str()}_{calibrator_str()}"
-
-    if cached_steps:
-        cache_type_str += f"_S{cached_steps}"
-
-    return cache_type_str
-
-
-def _summary(
-    pipe_or_module: Union[
-        DiffusionPipeline,
-        torch.nn.Module,
-    ],
-    details: bool = False,
-    logging: bool = True,
-    **kwargs,
-) -> CacheStats:
-    cache_stats = CacheStats()
-
-    if not isinstance(pipe_or_module, torch.nn.Module):
-        assert hasattr(pipe_or_module, "transformer")
-        module = pipe_or_module.transformer
-        cls_name = module.__class__.__name__
-    else:
-        module = pipe_or_module
-
-    cls_name = module.__class__.__name__
-    if isinstance(module, torch.nn.ModuleList):
-        cls_name = module[0].__class__.__name__
-
-    if hasattr(module, "_cache_context_kwargs"):
-        cache_options = module._cache_context_kwargs
-        cache_stats.cache_options = cache_options
-        if logging:
-            print(f"\n🤗Cache Options: {cls_name}\n\n{cache_options}")
-    else:
-        if logging:
-            logger.warning(f"Can't find Cache Options for: {cls_name}")
-
-    if hasattr(module, "_cached_steps"):
-        cached_steps: list[int] = module._cached_steps
-        residual_diffs: dict[str, float] = dict(module._residual_diffs)
-        cache_stats.cached_steps = cached_steps
-        cache_stats.residual_diffs = residual_diffs
-
-        if residual_diffs and logging:
-            diffs_values = list(residual_diffs.values())
-            qmin = np.min(diffs_values)
-            q0 = np.percentile(diffs_values, 0)
-            q1 = np.percentile(diffs_values, 25)
-            q2 = np.percentile(diffs_values, 50)
-            q3 = np.percentile(diffs_values, 75)
-            q4 = np.percentile(diffs_values, 95)
-            qmax = np.max(diffs_values)
-
-            print(
-                f"\n⚡️Cache Steps and Residual Diffs Statistics: {cls_name}\n"
-            )
-
-            print(
-                "| Cache Steps | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
-            )
-            print(
-                "|-------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
-            )
-            print(
-                f"| {len(cached_steps):<11} | {round(q0, 3):<9} | {round(q1, 3):<9} "
-                f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
-                f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
-            )
-            print("")
-
-            if details:
-                print(f"📚Cache Steps and Residual Diffs Details: {cls_name}\n")
-                pprint(
-                    f"Cache Steps: {len(cached_steps)}, {cached_steps}",
-                )
-                pprint(
-                    f"Residual Diffs: {len(residual_diffs)}, {residual_diffs}",
-                    compact=True,
-                )
-
-    if hasattr(module, "_cfg_cached_steps"):
-        cfg_cached_steps: list[int] = module._cfg_cached_steps
-        cfg_residual_diffs: dict[str, float] = dict(module._cfg_residual_diffs)
-        cache_stats.cfg_cached_steps = cfg_cached_steps
-        cache_stats.cfg_residual_diffs = cfg_residual_diffs
-
-        if cfg_residual_diffs and logging:
-            cfg_diffs_values = list(cfg_residual_diffs.values())
-            qmin = np.min(cfg_diffs_values)
-            q0 = np.percentile(cfg_diffs_values, 0)
-            q1 = np.percentile(cfg_diffs_values, 25)
-            q2 = np.percentile(cfg_diffs_values, 50)
-            q3 = np.percentile(cfg_diffs_values, 75)
-            q4 = np.percentile(cfg_diffs_values, 95)
-            qmax = np.max(cfg_diffs_values)
-
-            print(
-                f"\n⚡️CFG Cache Steps and Residual Diffs Statistics: {cls_name}\n"
-            )
-
-            print(
-                "| CFG Cache Steps | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
-            )
-            print(
-                "|-----------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
-            )
-            print(
-                f"| {len(cfg_cached_steps):<15} | {round(q0, 3):<9} | {round(q1, 3):<9} "
-                f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
-                f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
-            )
-            print("")
-
-            if details:
-                print(
-                    f"📚CFG Cache Steps and Residual Diffs Details: {cls_name}\n"
-                )
-                pprint(
-                    f"CFG Cache Steps: {len(cfg_cached_steps)}, {cfg_cached_steps}",
-                )
-                pprint(
-                    f"CFG Residual Diffs: {len(cfg_residual_diffs)}, {cfg_residual_diffs}",
-                    compact=True,
-                )
-
-    return cache_stats
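The summary/strify/CacheStats machinery deleted above has a likely new home: the file list adds cache_dit/summary.py (+593 lines), which suggests the reporting code moved rather than disappeared. The two helpers that replace it in utils.py are self-contained, so a usage sketch follows directly from their signatures (the tensor here is a hypothetical stand-in for a real activation):

    import torch
    from cache_dit.utils import maybe_empty_cache, print_tensor

    hidden_states = torch.randn(1, 4096, 64)  # hypothetical activation
    # Prints mean/std/shape; if torch.distributed is initialized, the tensor
    # is first all-gathered across ranks and only rank 0 prints.
    print_tensor(hidden_states, "hidden_states", dim=1)
    # Best-effort gc.collect() + torch.cuda.empty_cache()/ipc_collect(),
    # swallowing any exception (e.g. on CPU-only builds).
    maybe_empty_cache()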