cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/offload_utils.py +115 -0
  11. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  12. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  13. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  14. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  15. cache_dit/caching/cache_contexts/__init__.py +28 -0
  16. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
  18. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  21. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  22. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  23. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  24. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  25. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  26. cache_dit/caching/cache_interface.py +358 -0
  27. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  28. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  29. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  30. cache_dit/caching/patch_functors/__init__.py +15 -0
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  36. cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
  37. cache_dit/caching/utils.py +68 -0
  38. cache_dit/metrics/__init__.py +11 -0
  39. cache_dit/metrics/metrics.py +3 -0
  40. cache_dit/parallelism/__init__.py +3 -0
  41. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  57. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  58. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  59. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  60. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  61. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  62. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  68. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  69. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  70. cache_dit/parallelism/parallel_backend.py +26 -0
  71. cache_dit/parallelism/parallel_config.py +88 -0
  72. cache_dit/parallelism/parallel_interface.py +77 -0
  73. cache_dit/quantize/__init__.py +7 -0
  74. cache_dit/quantize/backends/__init__.py +1 -0
  75. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  76. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  77. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
  78. cache_dit/quantize/quantize_backend.py +0 -0
  79. cache_dit/quantize/quantize_config.py +0 -0
  80. cache_dit/quantize/quantize_interface.py +3 -16
  81. cache_dit/summary.py +593 -0
  82. cache_dit/utils.py +46 -290
  83. cache_dit-1.0.14.dist-info/METADATA +301 -0
  84. cache_dit-1.0.14.dist-info/RECORD +102 -0
  85. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  86. cache_dit/cache_factory/__init__.py +0 -28
  87. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  88. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  89. cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
  90. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  91. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
  92. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
  93. cache_dit/cache_factory/cache_blocks/utils.py +0 -41
  94. cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
  95. cache_dit/cache_factory/cache_interface.py +0 -217
  96. cache_dit/cache_factory/patch_functors/__init__.py +0 -12
  97. cache_dit/cache_factory/utils.py +0 -57
  98. cache_dit-0.3.2.dist-info/METADATA +0 -753
  99. cache_dit-0.3.2.dist-info/RECORD +0 -56
  100. cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
  101. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  102. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  103. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  104. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  105. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  106. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  107. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  108. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
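The most visible change in this release is the package reorganization: cache_dit.cache_factory becomes cache_dit.caching, custom_ops becomes kernels, and new parallelism/, quantize backend, and summary modules are added. For downstream code that imported from the old module paths, the rename implies an import migration roughly like the sketch below. This is inferred from the renamed files above and from the import changes in the expanded diff that follows (which, judging by the "class CachedAdapter" hunk context, appears to be the renamed cache_adapters/cache_adapter.py); it is not an official migration guide.

# cache-dit 0.3.2 module paths (removed in 1.0.14)
# from cache_dit.cache_factory.cache_types import CacheType
# from cache_dit.cache_factory.block_adapters import BlockAdapter
# from cache_dit.cache_factory.cache_contexts import BasicCacheConfig, CalibratorConfig

# cache-dit 1.0.14 module paths (per the renames above)
from cache_dit.caching.cache_types import CacheType
from cache_dit.caching.block_adapters import BlockAdapter
from cache_dit.caching.cache_contexts import BasicCacheConfig, CalibratorConfig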
@@ -1,25 +1,21 @@
+import copy
 import torch
-
 import unittest
 import functools
-
 from contextlib import ExitStack
-from typing import Dict, List, Tuple, Any, Union, Callable
-
-from diffusers import DiffusionPipeline
-
-from cache_dit.cache_factory.cache_types import CacheType
-from cache_dit.cache_factory.block_adapters import BlockAdapter
-from cache_dit.cache_factory.block_adapters import ParamsModifier
-from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
-from cache_dit.cache_factory.cache_contexts import CachedContextManager
-from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
-from cache_dit.cache_factory.cache_contexts import CalibratorConfig
-from cache_dit.cache_factory.cache_blocks import CachedBlocks
-from cache_dit.cache_factory.cache_blocks.utils import (
-    patch_cached_stats,
-    remove_cached_stats,
-)
+from typing import Dict, List, Tuple, Any, Union, Callable, Optional
+
+from diffusers import DiffusionPipeline, ModelMixin
+
+from cache_dit.caching.cache_types import CacheType
+from cache_dit.caching.block_adapters import BlockAdapter
+from cache_dit.caching.block_adapters import FakeDiffusionPipeline
+from cache_dit.caching.block_adapters import ParamsModifier
+from cache_dit.caching.block_adapters import BlockAdapterRegistry
+from cache_dit.caching.cache_contexts import ContextManager
+from cache_dit.caching.cache_contexts import BasicCacheConfig
+from cache_dit.caching.cache_contexts import CalibratorConfig
+from cache_dit.caching.cache_blocks import UnifiedBlocks
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
@@ -37,8 +33,11 @@ class CachedAdapter:
         pipe_or_adapter: Union[
             DiffusionPipeline,
             BlockAdapter,
+            # Transformer-only
+            torch.nn.Module,
+            ModelMixin,
         ],
-        **cache_context_kwargs,
+        **context_kwargs,
     ) -> Union[
         DiffusionPipeline,
         BlockAdapter,
@@ -47,7 +46,9 @@ class CachedAdapter:
             pipe_or_adapter is not None
         ), "pipe or block_adapter can not both None!"
 
-        if isinstance(pipe_or_adapter, DiffusionPipeline):
+        if isinstance(
+            pipe_or_adapter, (DiffusionPipeline, torch.nn.Module, ModelMixin)
+        ):
             if BlockAdapterRegistry.is_supported(pipe_or_adapter):
                 logger.info(
                     f"{pipe_or_adapter.__class__.__name__} is officially "
@@ -57,16 +58,22 @@ class CachedAdapter:
                 block_adapter = BlockAdapterRegistry.get_adapter(
                     pipe_or_adapter
                 )
-                if params_modifiers := cache_context_kwargs.pop(
+                assert block_adapter is not None, (
+                    f"BlockAdapter for {pipe_or_adapter.__class__.__name__} "
+                    "should not be None!"
+                )
+                if params_modifiers := context_kwargs.pop(
                     "params_modifiers",
                     None,
                 ):
                     block_adapter.params_modifiers = params_modifiers
 
-                return cls.cachify(
-                    block_adapter,
-                    **cache_context_kwargs,
-                ).pipe
+                block_adapter = cls.cachify(block_adapter, **context_kwargs)
+                if isinstance(pipe_or_adapter, DiffusionPipeline):
+                    return block_adapter.pipe
+
+                return block_adapter.transformer
+
             else:
                 raise ValueError(
                     f"{pipe_or_adapter.__class__.__name__} is not officially supported "
@@ -78,21 +85,21 @@ class CachedAdapter:
                 "Adapting Cache Acceleration using custom BlockAdapter!"
             )
             if pipe_or_adapter.params_modifiers is None:
-                if params_modifiers := cache_context_kwargs.pop(
+                if params_modifiers := context_kwargs.pop(
                     "params_modifiers", None
                 ):
                     pipe_or_adapter.params_modifiers = params_modifiers
 
             return cls.cachify(
                 pipe_or_adapter,
-                **cache_context_kwargs,
+                **context_kwargs,
             )
 
     @classmethod
     def cachify(
         cls,
         block_adapter: BlockAdapter,
-        **cache_context_kwargs,
+        **context_kwargs,
     ) -> BlockAdapter:
 
         if block_adapter.auto:
@@ -109,14 +116,15 @@ class CachedAdapter:
 
         # 1. Apply cache on pipeline: wrap cache context, must
         # call create_context before mock_blocks.
-        cls.create_context(
+        _, contexts_kwargs = cls.create_context(
             block_adapter,
-            **cache_context_kwargs,
+            **context_kwargs,
         )
 
         # 2. Apply cache on transformer: mock cached blocks
         cls.mock_blocks(
             block_adapter,
+            contexts_kwargs,
         )
 
         return block_adapter
@@ -125,12 +133,10 @@ class CachedAdapter:
     def check_context_kwargs(
         cls,
         block_adapter: BlockAdapter,
-        **cache_context_kwargs,
+        **context_kwargs,
     ):
-        # Check cache_context_kwargs
-        cache_config: BasicCacheConfig = cache_context_kwargs[
-            "cache_config"
-        ]  # ref
+        # Check context_kwargs
+        cache_config: BasicCacheConfig = context_kwargs["cache_config"]  # ref
         assert cache_config is not None, "cache_config can not be None."
         if cache_config.enable_separate_cfg is None:
             # Check cfg for some specific case if users don't set it as True
@@ -156,87 +162,113 @@ class CachedAdapter:
                     f"Pipeline: {block_adapter.pipe.__class__.__name__}."
                 )
 
-        cache_type = cache_context_kwargs.pop("cache_type", None)
+        cache_type = context_kwargs.pop("cache_type", None)
         if cache_type is not None:
-            assert (
-                cache_type == CacheType.DBCache
-            ), "Custom cache setting only support for DBCache now!"
+            assert isinstance(
+                cache_type, CacheType
+            ), f"cache_type must be CacheType Enum, but got {type(cache_type)}."
+            assert cache_type == cache_config.cache_type, (
+                f"cache_type from context_kwargs ({cache_type}) must be the same "
+                f"as that from cache_config ({cache_config.cache_type})."
+            )
 
-        return cache_context_kwargs
+        return context_kwargs
 
     @classmethod
     def create_context(
         cls,
         block_adapter: BlockAdapter,
-        **cache_context_kwargs,
-    ) -> DiffusionPipeline:
+        **context_kwargs,
+    ) -> Tuple[List[str], List[Dict[str, Any]]]:
 
         BlockAdapter.assert_normalized(block_adapter)
 
         if BlockAdapter.is_cached(block_adapter.pipe):
             return block_adapter.pipe
 
-        # Check cache_context_kwargs
-        cache_context_kwargs = cls.check_context_kwargs(
-            block_adapter, **cache_context_kwargs
+        # Check context_kwargs
+        context_kwargs = cls.check_context_kwargs(
+            block_adapter, **context_kwargs
         )
-        # Apply cache on pipeline: wrap cache context
-        pipe_cls_name = block_adapter.pipe.__class__.__name__
 
         # Each Pipeline should have it's own context manager instance.
         # Different transformers (Wan2.2, etc) should shared the same
         # cache manager but with different cache context (according
        # to their unique instance id).
-        cache_manager = CachedContextManager(
+        cache_config: BasicCacheConfig = context_kwargs.get(
+            "cache_config", None
+        )
+        assert cache_config is not None, "cache_config can not be None."
+        # Apply cache on pipeline: wrap cache context
+        pipe_cls_name = block_adapter.pipe.__class__.__name__
+        context_manager = ContextManager(
             name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
+            cache_type=cache_config.cache_type,
+            # Force use persistent_context for FakeDiffusionPipeline
+            persistent_context=isinstance(
+                block_adapter.pipe, FakeDiffusionPipeline
+            ),
         )
-        block_adapter.pipe._cache_manager = cache_manager  # instance level
-
         flatten_contexts, contexts_kwargs = cls.modify_context_params(
-            block_adapter, **cache_context_kwargs
+            block_adapter, **context_kwargs
         )
 
-        original_call = block_adapter.pipe.__class__.__call__
-
-        @functools.wraps(original_call)
-        def new_call(self, *args, **kwargs):
-            with ExitStack() as stack:
-                # cache context will be reset for each pipe inference
-                for context_name, context_kwargs in zip(
-                    flatten_contexts, contexts_kwargs
-                ):
-                    stack.enter_context(
-                        cache_manager.enter_context(
-                            cache_manager.reset_context(
-                                context_name,
-                                **context_kwargs,
-                            ),
+        block_adapter.pipe._context_manager = context_manager  # instance level
+
+        if not context_manager.persistent_context:
+
+            original_call = block_adapter.pipe.__class__.__call__
+
+            @functools.wraps(original_call)
+            def new_call(self, *args, **kwargs):
+                with ExitStack() as stack:
+                    # cache context will be reset for each pipe inference
+                    for context_name, context_kwargs in zip(
+                        flatten_contexts, contexts_kwargs
+                    ):
+                        stack.enter_context(
+                            context_manager.enter_context(
+                                context_manager.reset_context(
+                                    context_name,
+                                    **context_kwargs,
+                                ),
+                            )
                         )
-                )
-                outputs = original_call(self, *args, **kwargs)
-                cls.apply_stats_hooks(block_adapter)
-                return outputs
+                    outputs = original_call(self, *args, **kwargs)
+                    cls.apply_stats_hooks(block_adapter)
+                    return outputs
+
+            block_adapter.pipe.__class__.__call__ = new_call
+            block_adapter.pipe.__class__._original_call = original_call
+
+        else:
+            # Init persistent cache context for transformer
+            for context_name, context_kwargs in zip(
+                flatten_contexts, contexts_kwargs
+            ):
+                context_manager.reset_context(
+                    context_name,
+                    **context_kwargs,
+                )
 
-        block_adapter.pipe.__class__.__call__ = new_call
-        block_adapter.pipe.__class__._original_call = original_call
         block_adapter.pipe.__class__._is_cached = True
 
         cls.apply_params_hooks(block_adapter, contexts_kwargs)
 
-        return block_adapter.pipe
+        return flatten_contexts, contexts_kwargs
 
     @classmethod
     def modify_context_params(
         cls,
         block_adapter: BlockAdapter,
-        **cache_context_kwargs,
+        **context_kwargs,
     ) -> Tuple[List[str], List[Dict[str, Any]]]:
 
         flatten_contexts = BlockAdapter.flatten(
             block_adapter.unique_blocks_name
         )
         contexts_kwargs = [
-            cache_context_kwargs.copy()
+            copy.deepcopy(context_kwargs)  # must deep copy
             for _ in range(
                 len(flatten_contexts),
             )
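Aside: the per-call path above uses contextlib.ExitStack to enter one cache context per unique blocks name for the duration of a single pipeline call, while the persistent_context path (used for FakeDiffusionPipeline, i.e. transformer-only wrapping) creates the contexts once up front. A stand-alone illustration of the ExitStack pattern, with toy names rather than cache-dit APIs:

from contextlib import ExitStack, contextmanager

@contextmanager
def toy_context(name):
    # stands in for context_manager.enter_context(context_manager.reset_context(...))
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")

def call_once(context_names):
    # mirrors new_call: enter every context, run the wrapped call, then exit them all
    with ExitStack() as stack:
        for name in context_names:
            stack.enter_context(toy_context(name))
        return "outputs"

call_once(["transformer_blocks_0", "single_transformer_blocks_0"])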
@@ -257,9 +289,41 @@ class CachedAdapter:
         for i in range(
             min(len(contexts_kwargs), len(flatten_modifiers)),
         ):
-            contexts_kwargs[i].update(
-                flatten_modifiers[i]._context_kwargs,
-            )
+            if "cache_config" in flatten_modifiers[i]._context_kwargs:
+                modifier_cache_config = flatten_modifiers[
+                    i
+                ]._context_kwargs.get("cache_config", None)
+                modifier_calibrator_config = flatten_modifiers[
+                    i
+                ]._context_kwargs.get("calibrator_config", None)
+                if modifier_cache_config is not None:
+                    assert isinstance(
+                        modifier_cache_config, BasicCacheConfig
+                    ), (
+                        f"cache_config must be BasicCacheConfig, but got "
+                        f"{type(modifier_cache_config)}."
+                    )
+                    contexts_kwargs[i]["cache_config"].update(
+                        **modifier_cache_config.as_dict()
+                    )
+                if modifier_calibrator_config is not None:
+                    assert isinstance(
+                        modifier_calibrator_config, CalibratorConfig
+                    ), (
+                        f"calibrator_config must be CalibratorConfig, but got "
+                        f"{type(modifier_calibrator_config)}."
+                    )
+                    if (
+                        contexts_kwargs[i].get("calibrator_config", None)
+                        is None
+                    ):
+                        contexts_kwargs[i][
+                            "calibrator_config"
+                        ] = modifier_calibrator_config
+                    else:
+                        contexts_kwargs[i]["calibrator_config"].update(
+                            **modifier_calibrator_config.as_dict()
+                        )
             cls._config_messages(**contexts_kwargs[i])
 
         return flatten_contexts, contexts_kwargs
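Aside: this hunk changes how ParamsModifier entries are applied. Rather than dict-updating the whole per-context kwargs, a modifier's cache_config/calibrator_config is now merged field-by-field into the existing config object via update(**as_dict()). A toy illustration of that merge semantics only; the field names below are stand-ins and not the real BasicCacheConfig schema:

from dataclasses import dataclass, asdict

@dataclass
class ToyConfig:
    warmup_steps: int = 0
    diff_threshold: float = 0.08

    def as_dict(self):
        return asdict(self)

    def update(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

base = ToyConfig()                    # config shared by all blocks
override = ToyConfig(warmup_steps=8)  # per-block modifier
base.update(**override.as_dict())     # modifier fields win; base object is mutated in place
assert base.warmup_steps == 8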
@@ -273,7 +337,7 @@ class CachedAdapter:
             "calibrator_config", None
         )
         if cache_config is not None:
-            message = f"Collected Cache Config: {cache_config.strify()}"
+            message = f"Collected Context Config: {cache_config.strify()}"
             if calibrator_config is not None:
                 message += f", Calibrator Config: {calibrator_config.strify(details=True)}"
             else:
@@ -284,6 +348,7 @@ class CachedAdapter:
     def mock_blocks(
         cls,
         block_adapter: BlockAdapter,
+        contexts_kwargs: List[Dict],
     ) -> List[torch.nn.Module]:
 
         BlockAdapter.assert_normalized(block_adapter)
@@ -293,24 +358,28 @@ class CachedAdapter:
 
         # Apply cache on transformer: mock cached transformer blocks
         for (
-            cached_blocks,
+            unified_blocks,
             transformer,
             blocks_name,
             unique_blocks_name,
             dummy_blocks_names,
         ) in zip(
-            cls.collect_cached_blocks(block_adapter),
+            cls.collect_unified_blocks(
+                block_adapter,
+                contexts_kwargs,
+            ),
             block_adapter.transformer,
             block_adapter.blocks_name,
             block_adapter.unique_blocks_name,
             block_adapter.dummy_blocks_names,
         ):
             cls.mock_transformer(
-                cached_blocks,
+                unified_blocks,
                 transformer,
                 blocks_name,
                 unique_blocks_name,
                 dummy_blocks_names,
+                block_adapter,
             )
 
         return block_adapter.transformer
@@ -318,11 +387,12 @@ class CachedAdapter:
     @classmethod
     def mock_transformer(
         cls,
-        cached_blocks: Dict[str, torch.nn.ModuleList],
+        unified_blocks: Dict[str, torch.nn.ModuleList],
         transformer: torch.nn.Module,
         blocks_name: List[str],
         unique_blocks_name: List[str],
         dummy_blocks_names: List[str],
+        block_adapter: BlockAdapter,
     ) -> torch.nn.Module:
         dummy_blocks = torch.nn.ModuleList()
 
@@ -330,7 +400,28 @@ class CachedAdapter:
         original_forward = transformer.forward
         assert isinstance(dummy_blocks_names, list)
 
-        @functools.wraps(original_forward)
+        from accelerate import hooks
+
+        _hf_hook: Optional[hooks.ModelHook] = None
+
+        if getattr(transformer, "_hf_hook", None) is not None:
+            _hf_hook = transformer._hf_hook  # hooks from accelerate.hooks
+            if hasattr(transformer, "_old_forward"):
+                logger.warning(
+                    "_hf_hook is not None, so, we have to re-direct transformer's "
+                    f"original_forward({id(original_forward)}) to transformer's "
+                    f"_old_forward({id(transformer._old_forward)})"
+                )
+                original_forward = transformer._old_forward
+
+        # TODO: remove group offload hooks the re-apply after cache applied.
+        # hooks = _diffusers_hook.hooks.copy(); _diffusers_hook.hooks.clear()
+        # re-apply hooks to transformer after cache applied.
+        # from diffusers.hooks.hooks import HookFunctionReference, HookRegistry
+        # from diffusers.hooks.group_offloading import apply_group_offloading
+        context_manager: ContextManager = block_adapter.pipe._context_manager
+        assert isinstance(context_manager, ContextManager._supported_managers)
+
         def new_forward(self, *args, **kwargs):
             with ExitStack() as stack:
                 for name, context_name in zip(
@@ -339,7 +430,7 @@ class CachedAdapter:
                 ):
                     stack.enter_context(
                         unittest.mock.patch.object(
-                            self, name, cached_blocks[context_name]
+                            self, name, unified_blocks[context_name]
                         )
                     )
                 for dummy_name in dummy_blocks_names:
@@ -348,55 +439,85 @@ class CachedAdapter:
                             self, dummy_name, dummy_blocks
                         )
                    )
-                return original_forward(*args, **kwargs)
+                outputs = original_forward(*args, **kwargs)
+
+                if (
+                    context_manager.persistent_context
+                    and context_manager.is_pre_refreshed()
+                ):
+                    cls.apply_stats_hooks(block_adapter)
+
+                return outputs
+
+        def new_forward_with_hf_hook(self, *args, **kwargs):
+            # Compatible with model cpu offload
+            if _hf_hook is not None and hasattr(_hf_hook, "pre_forward"):
+                args, kwargs = _hf_hook.pre_forward(self, *args, **kwargs)
+
+            outputs = new_forward(self, *args, **kwargs)
+
+            if _hf_hook is not None and hasattr(_hf_hook, "post_forward"):
+                outputs = _hf_hook.post_forward(self, outputs)
+
+            return outputs
+
+        # NOTE: Still can't fully compatible with group offloading
+        transformer.forward = functools.update_wrapper(
+            functools.partial(new_forward_with_hf_hook, transformer),
+            new_forward_with_hf_hook,
+        )
 
-        transformer.forward = new_forward.__get__(transformer)
         transformer._original_forward = original_forward
         transformer._is_cached = True
 
         return transformer
 
     @classmethod
-    def collect_cached_blocks(
+    def collect_unified_blocks(
         cls,
         block_adapter: BlockAdapter,
+        contexts_kwargs: List[Dict],
     ) -> List[Dict[str, torch.nn.ModuleList]]:
 
         BlockAdapter.assert_normalized(block_adapter)
 
         total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
-        assert hasattr(block_adapter.pipe, "_cache_manager")
+        assert hasattr(block_adapter.pipe, "_context_manager")
         assert isinstance(
-            block_adapter.pipe._cache_manager,
-            CachedContextManager,
+            block_adapter.pipe._context_manager,
+            ContextManager._supported_managers,
        )
 
         for i in range(len(block_adapter.transformer)):
 
-            cached_blocks_bind_context = {}
+            unified_blocks_bind_context = {}
             for j in range(len(block_adapter.blocks[i])):
-                cached_blocks_bind_context[
+                cache_config: BasicCacheConfig = contexts_kwargs[
+                    i * len(block_adapter.blocks[i]) + j
+                ]["cache_config"]
+                unified_blocks_bind_context[
                     block_adapter.unique_blocks_name[i][j]
                 ] = torch.nn.ModuleList(
                     [
-                        CachedBlocks(
+                        UnifiedBlocks(
                            # 0. Transformer blocks configuration
                            block_adapter.blocks[i][j],
                            transformer=block_adapter.transformer[i],
                            forward_pattern=block_adapter.forward_pattern[i][j],
                            check_forward_pattern=block_adapter.check_forward_pattern,
                            check_num_outputs=block_adapter.check_num_outputs,
-                            # 1. Cache context configuration
+                            # 1. Cache/Prune context configuration
                            cache_prefix=block_adapter.blocks_name[i][j],
                            cache_context=block_adapter.unique_blocks_name[i][
                                j
                            ],
-                            cache_manager=block_adapter.pipe._cache_manager,
+                            context_manager=block_adapter.pipe._context_manager,
+                            cache_type=cache_config.cache_type,
                        )
                     ]
                 )
 
-            total_cached_blocks.append(cached_blocks_bind_context)
+            total_cached_blocks.append(unified_blocks_bind_context)
 
         return total_cached_blocks
 
@@ -406,7 +527,7 @@ class CachedAdapter:
         block_adapter: BlockAdapter,
         contexts_kwargs: List[Dict],
     ):
-        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
+        block_adapter.pipe._context_kwargs = contexts_kwargs[0]
 
         params_shift = 0
         for i in range(len(block_adapter.transformer)):
@@ -417,40 +538,43 @@ class CachedAdapter:
             block_adapter.transformer[i]._has_separate_cfg = (
                 block_adapter.has_separate_cfg
             )
-            block_adapter.transformer[i]._cache_context_kwargs = (
-                contexts_kwargs[params_shift]
-            )
+            block_adapter.transformer[i]._context_kwargs = contexts_kwargs[
+                params_shift
+            ]
 
             blocks = block_adapter.blocks[i]
             for j in range(len(blocks)):
                 blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
-                blocks[j]._cache_context_kwargs = contexts_kwargs[
-                    params_shift + j
-                ]
+                blocks[j]._context_kwargs = contexts_kwargs[params_shift + j]
 
             params_shift += len(blocks)
 
     @classmethod
+    @torch.compiler.disable
     def apply_stats_hooks(
         cls,
         block_adapter: BlockAdapter,
     ):
-        cache_manager = block_adapter.pipe._cache_manager
+        from cache_dit.caching.cache_blocks import (
+            apply_stats,
+        )
+
+        context_manager = block_adapter.pipe._context_manager
 
         for i in range(len(block_adapter.transformer)):
-            patch_cached_stats(
+            apply_stats(
                 block_adapter.transformer[i],
                 cache_context=block_adapter.unique_blocks_name[i][-1],
-                cache_manager=cache_manager,
+                context_manager=context_manager,
             )
             for blocks, unique_name in zip(
                 block_adapter.blocks[i],
                 block_adapter.unique_blocks_name[i],
            ):
-                patch_cached_stats(
+                apply_stats(
                     blocks,
                     cache_context=unique_name,
-                    cache_manager=cache_manager,
+                    context_manager=context_manager,
                )
 
     @classmethod
@@ -478,11 +602,13 @@ class CachedAdapter:
                 original_call = pipe.__class__._original_call
                 pipe.__class__.__call__ = original_call
                 del pipe.__class__._original_call
-            if hasattr(pipe, "_cache_manager"):
-                cache_manager = pipe._cache_manager
-                if isinstance(cache_manager, CachedContextManager):
-                    cache_manager.clear_contexts()
-                del pipe._cache_manager
+            if hasattr(pipe, "_context_manager"):
+                context_manager = pipe._context_manager
+                if isinstance(
+                    context_manager, ContextManager._supported_managers
+                ):
+                    context_manager.clear_contexts()
+                del pipe._context_manager
             if hasattr(pipe, "_is_cached"):
                 del pipe.__class__._is_cached
 
@@ -497,22 +623,22 @@ class CachedAdapter:
         def _release_blocks_params(blocks):
             if hasattr(blocks, "_forward_pattern"):
                 del blocks._forward_pattern
-            if hasattr(blocks, "_cache_context_kwargs"):
-                del blocks._cache_context_kwargs
+            if hasattr(blocks, "_context_kwargs"):
+                del blocks._context_kwargs
 
         def _release_transformer_params(transformer):
             if hasattr(transformer, "_forward_pattern"):
                 del transformer._forward_pattern
             if hasattr(transformer, "_has_separate_cfg"):
                 del transformer._has_separate_cfg
-            if hasattr(transformer, "_cache_context_kwargs"):
-                del transformer._cache_context_kwargs
+            if hasattr(transformer, "_context_kwargs"):
+                del transformer._context_kwargs
             for blocks in BlockAdapter.find_blocks(transformer):
                 _release_blocks_params(blocks)
 
         def _release_pipeline_params(pipe):
-            if hasattr(pipe, "_cache_context_kwargs"):
-                del pipe._cache_context_kwargs
+            if hasattr(pipe, "_context_kwargs"):
+                del pipe._context_kwargs
 
         cls.release_hooks(
             pipe_or_adapter,
@@ -522,11 +648,24 @@ class CachedAdapter:
         )
 
         # release stats hooks
+        from cache_dit.caching.cache_blocks import (
+            remove_stats,
+        )
+
+        cls.release_hooks(
+            pipe_or_adapter, remove_stats, remove_stats, remove_stats
+        )
+
+        # maybe release parallelism stats
+        from cache_dit.parallelism.parallel_interface import (
+            remove_parallelism_stats,
+        )
+
         cls.release_hooks(
             pipe_or_adapter,
-            remove_cached_stats,
-            remove_cached_stats,
-            remove_cached_stats,
+            remove_parallelism_stats,
+            remove_parallelism_stats,
+            remove_parallelism_stats,
         )
 
     @classmethod
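Taken together, the widened signature at the top of the diff (torch.nn.Module / ModelMixin accepted alongside DiffusionPipeline), the FakeDiffusionPipeline-driven persistent_context, and the "return block_adapter.transformer" branch indicate a transformer-only entry path in 1.0.14. A hedged usage sketch, assuming the public cache_dit.enable_cache() wrapper (presumably defined in the new caching/cache_interface.py) still forwards to CachedAdapter as in 0.3.x; exact keyword arguments and defaults may differ in 1.0.14:

import cache_dit
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

# Option A (0.3.2 and 1.0.14): pass the whole pipeline; __call__ is wrapped so
# cache contexts are reset on every pipeline invocation.
cache_dit.enable_cache(pipe)

# Option B (new in 1.0.x per this diff, pick one option only): pass just the
# transformer module; the adapter wraps it with a persistent cache context and
# returns the patched transformer.
# cache_dit.enable_cache(pipe.transformer)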