cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
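The bulk of this release is a package-level rename: cache_dit.cache_factory becomes cache_dit.caching and cache_dit.custom_ops becomes cache_dit.kernels, with CachedContextManager replaced by ContextManager and CachedBlocks by UnifiedBlocks in the CachedAdapter hunks below. A minimal sketch of how downstream imports shift, using only paths that appear in the import hunk below (whether each symbol remains part of the stable public API is an assumption, not something this diff guarantees):

# Hypothetical downstream module, for illustration only.
# cache-dit 1.0.3 layout (removed in 1.0.14):
#   from cache_dit.cache_factory.cache_types import CacheType
#   from cache_dit.cache_factory.block_adapters import BlockAdapter, ParamsModifier
#   from cache_dit.cache_factory.cache_contexts import CachedContextManager, BasicCacheConfig
#   from cache_dit.cache_factory.cache_blocks import CachedBlocks

# cache-dit 1.0.14 layout (paths taken from the import hunk below; availability
# of each symbol is assumed from the diff, not verified against the wheel):
from cache_dit.caching.cache_types import CacheType
from cache_dit.caching.block_adapters import BlockAdapter, ParamsModifier
from cache_dit.caching.cache_contexts import ContextManager, BasicCacheConfig
from cache_dit.caching.cache_blocks import UnifiedBlocks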
@@ -1,19 +1,21 @@
+ import copy
  import torch
  import unittest
  import functools
  from contextlib import ExitStack
  from typing import Dict, List, Tuple, Any, Union, Callable, Optional

- from diffusers import DiffusionPipeline
-
- from cache_dit.cache_factory.cache_types import CacheType
- from cache_dit.cache_factory.block_adapters import BlockAdapter
- from cache_dit.cache_factory.block_adapters import ParamsModifier
- from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
- from cache_dit.cache_factory.cache_contexts import CachedContextManager
- from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
- from cache_dit.cache_factory.cache_contexts import CalibratorConfig
- from cache_dit.cache_factory.cache_blocks import CachedBlocks
+ from diffusers import DiffusionPipeline, ModelMixin
+
+ from cache_dit.caching.cache_types import CacheType
+ from cache_dit.caching.block_adapters import BlockAdapter
+ from cache_dit.caching.block_adapters import FakeDiffusionPipeline
+ from cache_dit.caching.block_adapters import ParamsModifier
+ from cache_dit.caching.block_adapters import BlockAdapterRegistry
+ from cache_dit.caching.cache_contexts import ContextManager
+ from cache_dit.caching.cache_contexts import BasicCacheConfig
+ from cache_dit.caching.cache_contexts import CalibratorConfig
+ from cache_dit.caching.cache_blocks import UnifiedBlocks
  from cache_dit.logger import init_logger

  logger = init_logger(__name__)
@@ -31,8 +33,11 @@ class CachedAdapter:
  pipe_or_adapter: Union[
  DiffusionPipeline,
  BlockAdapter,
+ # Transformer-only
+ torch.nn.Module,
+ ModelMixin,
  ],
- **cache_context_kwargs,
+ **context_kwargs,
  ) -> Union[
  DiffusionPipeline,
  BlockAdapter,
@@ -41,7 +46,9 @@ class CachedAdapter:
  pipe_or_adapter is not None
  ), "pipe or block_adapter can not both None!"

- if isinstance(pipe_or_adapter, DiffusionPipeline):
+ if isinstance(
+ pipe_or_adapter, (DiffusionPipeline, torch.nn.Module, ModelMixin)
+ ):
  if BlockAdapterRegistry.is_supported(pipe_or_adapter):
  logger.info(
  f"{pipe_or_adapter.__class__.__name__} is officially "
@@ -51,16 +58,22 @@ class CachedAdapter:
  block_adapter = BlockAdapterRegistry.get_adapter(
  pipe_or_adapter
  )
- if params_modifiers := cache_context_kwargs.pop(
+ assert block_adapter is not None, (
+ f"BlockAdapter for {pipe_or_adapter.__class__.__name__} "
+ "should not be None!"
+ )
+ if params_modifiers := context_kwargs.pop(
  "params_modifiers",
  None,
  ):
  block_adapter.params_modifiers = params_modifiers

- return cls.cachify(
- block_adapter,
- **cache_context_kwargs,
- ).pipe
+ block_adapter = cls.cachify(block_adapter, **context_kwargs)
+ if isinstance(pipe_or_adapter, DiffusionPipeline):
+ return block_adapter.pipe
+
+ return block_adapter.transformer
+
  else:
  raise ValueError(
  f"{pipe_or_adapter.__class__.__name__} is not officially supported "
@@ -72,21 +85,21 @@ class CachedAdapter:
  "Adapting Cache Acceleration using custom BlockAdapter!"
  )
  if pipe_or_adapter.params_modifiers is None:
- if params_modifiers := cache_context_kwargs.pop(
+ if params_modifiers := context_kwargs.pop(
  "params_modifiers", None
  ):
  pipe_or_adapter.params_modifiers = params_modifiers

  return cls.cachify(
  pipe_or_adapter,
- **cache_context_kwargs,
+ **context_kwargs,
  )

  @classmethod
  def cachify(
  cls,
  block_adapter: BlockAdapter,
- **cache_context_kwargs,
+ **context_kwargs,
  ) -> BlockAdapter:

  if block_adapter.auto:
@@ -103,14 +116,15 @@ class CachedAdapter:

  # 1. Apply cache on pipeline: wrap cache context, must
  # call create_context before mock_blocks.
- cls.create_context(
+ _, contexts_kwargs = cls.create_context(
  block_adapter,
- **cache_context_kwargs,
+ **context_kwargs,
  )

  # 2. Apply cache on transformer: mock cached blocks
  cls.mock_blocks(
  block_adapter,
+ contexts_kwargs,
  )

  return block_adapter
@@ -119,12 +133,10 @@ class CachedAdapter:
  def check_context_kwargs(
  cls,
  block_adapter: BlockAdapter,
- **cache_context_kwargs,
+ **context_kwargs,
  ):
- # Check cache_context_kwargs
- cache_config: BasicCacheConfig = cache_context_kwargs[
- "cache_config"
- ] # ref
+ # Check context_kwargs
+ cache_config: BasicCacheConfig = context_kwargs["cache_config"] # ref
  assert cache_config is not None, "cache_config can not be None."
  if cache_config.enable_separate_cfg is None:
  # Check cfg for some specific case if users don't set it as True
@@ -150,19 +162,23 @@ class CachedAdapter:
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
  )

- cache_type = cache_context_kwargs.pop("cache_type", None)
+ cache_type = context_kwargs.pop("cache_type", None)
  if cache_type is not None:
- assert (
- cache_type == CacheType.DBCache
- ), "Custom cache setting only support for DBCache now!"
+ assert isinstance(
+ cache_type, CacheType
+ ), f"cache_type must be CacheType Enum, but got {type(cache_type)}."
+ assert cache_type == cache_config.cache_type, (
+ f"cache_type from context_kwargs ({cache_type}) must be the same "
+ f"as that from cache_config ({cache_config.cache_type})."
+ )

- return cache_context_kwargs
+ return context_kwargs

  @classmethod
  def create_context(
  cls,
  block_adapter: BlockAdapter,
- **cache_context_kwargs,
+ **context_kwargs,
  ) -> Tuple[List[str], List[Dict[str, Any]]]:

  BlockAdapter.assert_normalized(block_adapter)
@@ -170,49 +186,71 @@ class CachedAdapter:
  if BlockAdapter.is_cached(block_adapter.pipe):
  return block_adapter.pipe

- # Check cache_context_kwargs
- cache_context_kwargs = cls.check_context_kwargs(
- block_adapter, **cache_context_kwargs
+ # Check context_kwargs
+ context_kwargs = cls.check_context_kwargs(
+ block_adapter, **context_kwargs
  )
- # Apply cache on pipeline: wrap cache context
- pipe_cls_name = block_adapter.pipe.__class__.__name__

  # Each Pipeline should have it's own context manager instance.
  # Different transformers (Wan2.2, etc) should shared the same
  # cache manager but with different cache context (according
  # to their unique instance id).
- cache_manager = CachedContextManager(
+ cache_config: BasicCacheConfig = context_kwargs.get(
+ "cache_config", None
+ )
+ assert cache_config is not None, "cache_config can not be None."
+ # Apply cache on pipeline: wrap cache context
+ pipe_cls_name = block_adapter.pipe.__class__.__name__
+ context_manager = ContextManager(
  name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
+ cache_type=cache_config.cache_type,
+ # Force use persistent_context for FakeDiffusionPipeline
+ persistent_context=isinstance(
+ block_adapter.pipe, FakeDiffusionPipeline
+ ),
  )
- block_adapter.pipe._cache_manager = cache_manager # instance level
-
  flatten_contexts, contexts_kwargs = cls.modify_context_params(
- block_adapter, **cache_context_kwargs
+ block_adapter, **context_kwargs
  )

- original_call = block_adapter.pipe.__class__.__call__
-
- @functools.wraps(original_call)
- def new_call(self, *args, **kwargs):
- with ExitStack() as stack:
- # cache context will be reset for each pipe inference
- for context_name, context_kwargs in zip(
- flatten_contexts, contexts_kwargs
- ):
- stack.enter_context(
- cache_manager.enter_context(
- cache_manager.reset_context(
- context_name,
- **context_kwargs,
- ),
+ block_adapter.pipe._context_manager = context_manager # instance level
+
+ if not context_manager.persistent_context:
+
+ original_call = block_adapter.pipe.__class__.__call__
+
+ @functools.wraps(original_call)
+ def new_call(self, *args, **kwargs):
+ with ExitStack() as stack:
+ # cache context will be reset for each pipe inference
+ for context_name, context_kwargs in zip(
+ flatten_contexts, contexts_kwargs
+ ):
+ stack.enter_context(
+ context_manager.enter_context(
+ context_manager.reset_context(
+ context_name,
+ **context_kwargs,
+ ),
+ )
  )
- )
- outputs = original_call(self, *args, **kwargs)
- cls.apply_stats_hooks(block_adapter)
- return outputs
+ outputs = original_call(self, *args, **kwargs)
+ cls.apply_stats_hooks(block_adapter)
+ return outputs
+
+ block_adapter.pipe.__class__.__call__ = new_call
+ block_adapter.pipe.__class__._original_call = original_call
+
+ else:
+ # Init persistent cache context for transformer
+ for context_name, context_kwargs in zip(
+ flatten_contexts, contexts_kwargs
+ ):
+ context_manager.reset_context(
+ context_name,
+ **context_kwargs,
+ )

- block_adapter.pipe.__class__.__call__ = new_call
- block_adapter.pipe.__class__._original_call = original_call
  block_adapter.pipe.__class__._is_cached = True

  cls.apply_params_hooks(block_adapter, contexts_kwargs)
@@ -223,14 +261,14 @@ class CachedAdapter:
  def modify_context_params(
  cls,
  block_adapter: BlockAdapter,
- **cache_context_kwargs,
+ **context_kwargs,
  ) -> Tuple[List[str], List[Dict[str, Any]]]:

  flatten_contexts = BlockAdapter.flatten(
  block_adapter.unique_blocks_name
  )
  contexts_kwargs = [
- cache_context_kwargs.copy()
+ copy.deepcopy(context_kwargs) # must deep copy
  for _ in range(
  len(flatten_contexts),
  )
@@ -251,9 +289,41 @@ class CachedAdapter:
  for i in range(
  min(len(contexts_kwargs), len(flatten_modifiers)),
  ):
- contexts_kwargs[i].update(
- flatten_modifiers[i]._context_kwargs,
- )
+ if "cache_config" in flatten_modifiers[i]._context_kwargs:
+ modifier_cache_config = flatten_modifiers[
+ i
+ ]._context_kwargs.get("cache_config", None)
+ modifier_calibrator_config = flatten_modifiers[
+ i
+ ]._context_kwargs.get("calibrator_config", None)
+ if modifier_cache_config is not None:
+ assert isinstance(
+ modifier_cache_config, BasicCacheConfig
+ ), (
+ f"cache_config must be BasicCacheConfig, but got "
+ f"{type(modifier_cache_config)}."
+ )
+ contexts_kwargs[i]["cache_config"].update(
+ **modifier_cache_config.as_dict()
+ )
+ if modifier_calibrator_config is not None:
+ assert isinstance(
+ modifier_calibrator_config, CalibratorConfig
+ ), (
+ f"calibrator_config must be CalibratorConfig, but got "
+ f"{type(modifier_calibrator_config)}."
+ )
+ if (
+ contexts_kwargs[i].get("calibrator_config", None)
+ is None
+ ):
+ contexts_kwargs[i][
+ "calibrator_config"
+ ] = modifier_calibrator_config
+ else:
+ contexts_kwargs[i]["calibrator_config"].update(
+ **modifier_calibrator_config.as_dict()
+ )
  cls._config_messages(**contexts_kwargs[i])

  return flatten_contexts, contexts_kwargs
@@ -267,7 +337,7 @@ class CachedAdapter:
  "calibrator_config", None
  )
  if cache_config is not None:
- message = f"Collected Cache Config: {cache_config.strify()}"
+ message = f"Collected Context Config: {cache_config.strify()}"
  if calibrator_config is not None:
  message += f", Calibrator Config: {calibrator_config.strify(details=True)}"
  else:
@@ -278,6 +348,7 @@ class CachedAdapter:
  def mock_blocks(
  cls,
  block_adapter: BlockAdapter,
+ contexts_kwargs: List[Dict],
  ) -> List[torch.nn.Module]:

  BlockAdapter.assert_normalized(block_adapter)
@@ -287,24 +358,28 @@ class CachedAdapter:

  # Apply cache on transformer: mock cached transformer blocks
  for (
- cached_blocks,
+ unified_blocks,
  transformer,
  blocks_name,
  unique_blocks_name,
  dummy_blocks_names,
  ) in zip(
- cls.collect_cached_blocks(block_adapter),
+ cls.collect_unified_blocks(
+ block_adapter,
+ contexts_kwargs,
+ ),
  block_adapter.transformer,
  block_adapter.blocks_name,
  block_adapter.unique_blocks_name,
  block_adapter.dummy_blocks_names,
  ):
  cls.mock_transformer(
- cached_blocks,
+ unified_blocks,
  transformer,
  blocks_name,
  unique_blocks_name,
  dummy_blocks_names,
+ block_adapter,
  )

  return block_adapter.transformer
@@ -312,11 +387,12 @@ class CachedAdapter:
  @classmethod
  def mock_transformer(
  cls,
- cached_blocks: Dict[str, torch.nn.ModuleList],
+ unified_blocks: Dict[str, torch.nn.ModuleList],
  transformer: torch.nn.Module,
  blocks_name: List[str],
  unique_blocks_name: List[str],
  dummy_blocks_names: List[str],
+ block_adapter: BlockAdapter,
  ) -> torch.nn.Module:
  dummy_blocks = torch.nn.ModuleList()

@@ -343,6 +419,8 @@ class CachedAdapter:
  # re-apply hooks to transformer after cache applied.
  # from diffusers.hooks.hooks import HookFunctionReference, HookRegistry
  # from diffusers.hooks.group_offloading import apply_group_offloading
+ context_manager: ContextManager = block_adapter.pipe._context_manager
+ assert isinstance(context_manager, ContextManager._supported_managers)

  def new_forward(self, *args, **kwargs):
  with ExitStack() as stack:
@@ -352,7 +430,7 @@ class CachedAdapter:
  ):
  stack.enter_context(
  unittest.mock.patch.object(
- self, name, cached_blocks[context_name]
+ self, name, unified_blocks[context_name]
  )
  )
  for dummy_name in dummy_blocks_names:
@@ -362,6 +440,13 @@ class CachedAdapter:
  )
  )
  outputs = original_forward(*args, **kwargs)
+
+ if (
+ context_manager.persistent_context
+ and context_manager.is_pre_refreshed()
+ ):
+ cls.apply_stats_hooks(block_adapter)
+
  return outputs

  def new_forward_with_hf_hook(self, *args, **kwargs):
@@ -388,46 +473,51 @@ class CachedAdapter:
  return transformer

  @classmethod
- def collect_cached_blocks(
+ def collect_unified_blocks(
  cls,
  block_adapter: BlockAdapter,
+ contexts_kwargs: List[Dict],
  ) -> List[Dict[str, torch.nn.ModuleList]]:

  BlockAdapter.assert_normalized(block_adapter)

  total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
- assert hasattr(block_adapter.pipe, "_cache_manager")
+ assert hasattr(block_adapter.pipe, "_context_manager")
  assert isinstance(
- block_adapter.pipe._cache_manager,
- CachedContextManager,
+ block_adapter.pipe._context_manager,
+ ContextManager._supported_managers,
  )

  for i in range(len(block_adapter.transformer)):

- cached_blocks_bind_context = {}
+ unified_blocks_bind_context = {}
  for j in range(len(block_adapter.blocks[i])):
- cached_blocks_bind_context[
+ cache_config: BasicCacheConfig = contexts_kwargs[
+ i * len(block_adapter.blocks[i]) + j
+ ]["cache_config"]
+ unified_blocks_bind_context[
  block_adapter.unique_blocks_name[i][j]
  ] = torch.nn.ModuleList(
  [
- CachedBlocks(
+ UnifiedBlocks(
  # 0. Transformer blocks configuration
  block_adapter.blocks[i][j],
  transformer=block_adapter.transformer[i],
  forward_pattern=block_adapter.forward_pattern[i][j],
  check_forward_pattern=block_adapter.check_forward_pattern,
  check_num_outputs=block_adapter.check_num_outputs,
- # 1. Cache context configuration
+ # 1. Cache/Prune context configuration
  cache_prefix=block_adapter.blocks_name[i][j],
  cache_context=block_adapter.unique_blocks_name[i][
  j
  ],
- cache_manager=block_adapter.pipe._cache_manager,
+ context_manager=block_adapter.pipe._context_manager,
+ cache_type=cache_config.cache_type,
  )
  ]
  )

- total_cached_blocks.append(cached_blocks_bind_context)
+ total_cached_blocks.append(unified_blocks_bind_context)

  return total_cached_blocks

@@ -437,7 +527,7 @@ class CachedAdapter:
  block_adapter: BlockAdapter,
  contexts_kwargs: List[Dict],
  ):
- block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
+ block_adapter.pipe._context_kwargs = contexts_kwargs[0]

  params_shift = 0
  for i in range(len(block_adapter.transformer)):
@@ -448,44 +538,43 @@ class CachedAdapter:
  block_adapter.transformer[i]._has_separate_cfg = (
  block_adapter.has_separate_cfg
  )
- block_adapter.transformer[i]._cache_context_kwargs = (
- contexts_kwargs[params_shift]
- )
+ block_adapter.transformer[i]._context_kwargs = contexts_kwargs[
+ params_shift
+ ]

  blocks = block_adapter.blocks[i]
  for j in range(len(blocks)):
  blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
- blocks[j]._cache_context_kwargs = contexts_kwargs[
- params_shift + j
- ]
+ blocks[j]._context_kwargs = contexts_kwargs[params_shift + j]

  params_shift += len(blocks)

  @classmethod
+ @torch.compiler.disable
  def apply_stats_hooks(
  cls,
  block_adapter: BlockAdapter,
  ):
- from cache_dit.cache_factory.cache_blocks import (
- patch_cached_stats,
+ from cache_dit.caching.cache_blocks import (
+ apply_stats,
  )

- cache_manager = block_adapter.pipe._cache_manager
+ context_manager = block_adapter.pipe._context_manager

  for i in range(len(block_adapter.transformer)):
- patch_cached_stats(
+ apply_stats(
  block_adapter.transformer[i],
  cache_context=block_adapter.unique_blocks_name[i][-1],
- cache_manager=cache_manager,
+ context_manager=context_manager,
  )
  for blocks, unique_name in zip(
  block_adapter.blocks[i],
  block_adapter.unique_blocks_name[i],
  ):
- patch_cached_stats(
+ apply_stats(
  blocks,
  cache_context=unique_name,
- cache_manager=cache_manager,
+ context_manager=context_manager,
  )

  @classmethod
@@ -513,11 +602,13 @@ class CachedAdapter:
  original_call = pipe.__class__._original_call
  pipe.__class__.__call__ = original_call
  del pipe.__class__._original_call
- if hasattr(pipe, "_cache_manager"):
- cache_manager = pipe._cache_manager
- if isinstance(cache_manager, CachedContextManager):
- cache_manager.clear_contexts()
- del pipe._cache_manager
+ if hasattr(pipe, "_context_manager"):
+ context_manager = pipe._context_manager
+ if isinstance(
+ context_manager, ContextManager._supported_managers
+ ):
+ context_manager.clear_contexts()
+ del pipe._context_manager
  if hasattr(pipe, "_is_cached"):
  del pipe.__class__._is_cached

@@ -532,22 +623,22 @@ class CachedAdapter:
  def _release_blocks_params(blocks):
  if hasattr(blocks, "_forward_pattern"):
  del blocks._forward_pattern
- if hasattr(blocks, "_cache_context_kwargs"):
- del blocks._cache_context_kwargs
+ if hasattr(blocks, "_context_kwargs"):
+ del blocks._context_kwargs

  def _release_transformer_params(transformer):
  if hasattr(transformer, "_forward_pattern"):
  del transformer._forward_pattern
  if hasattr(transformer, "_has_separate_cfg"):
  del transformer._has_separate_cfg
- if hasattr(transformer, "_cache_context_kwargs"):
- del transformer._cache_context_kwargs
+ if hasattr(transformer, "_context_kwargs"):
+ del transformer._context_kwargs
  for blocks in BlockAdapter.find_blocks(transformer):
  _release_blocks_params(blocks)

  def _release_pipeline_params(pipe):
- if hasattr(pipe, "_cache_context_kwargs"):
- del pipe._cache_context_kwargs
+ if hasattr(pipe, "_context_kwargs"):
+ del pipe._context_kwargs

  cls.release_hooks(
  pipe_or_adapter,
@@ -557,15 +648,24 @@ class CachedAdapter:
  )

  # release stats hooks
- from cache_dit.cache_factory.cache_blocks import (
- remove_cached_stats,
+ from cache_dit.caching.cache_blocks import (
+ remove_stats,
+ )
+
+ cls.release_hooks(
+ pipe_or_adapter, remove_stats, remove_stats, remove_stats
+ )
+
+ # maybe release parallelism stats
+ from cache_dit.parallelism.parallel_interface import (
+ remove_parallelism_stats,
  )

  cls.release_hooks(
  pipe_or_adapter,
- remove_cached_stats,
- remove_cached_stats,
- remove_cached_stats,
+ remove_parallelism_stats,
+ remove_parallelism_stats,
+ remove_parallelism_stats,
  )

  @classmethod