cache-dit 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

Files changed (29)
  1. cache_dit/__init__.py +3 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +8 -1
  4. cache_dit/cache_factory/cache_adapters/cache_adapter.py +90 -76
  5. cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
  6. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
  7. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +271 -36
  8. cache_dit/cache_factory/cache_blocks/pattern_base.py +286 -45
  9. cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
  10. cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
  11. cache_dit/cache_factory/cache_contexts/cache_config.py +102 -0
  12. cache_dit/cache_factory/cache_contexts/cache_context.py +26 -89
  13. cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
  14. cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
  15. cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
  16. cache_dit/cache_factory/cache_contexts/prune_config.py +69 -0
  17. cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
  18. cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
  19. cache_dit/cache_factory/cache_interface.py +23 -14
  20. cache_dit/cache_factory/cache_types.py +19 -2
  21. cache_dit/cache_factory/params_modifier.py +7 -7
  22. cache_dit/cache_factory/utils.py +38 -27
  23. cache_dit/utils.py +191 -54
  24. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/METADATA +14 -7
  25. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/RECORD +29 -24
  26. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/WHEEL +0 -0
  27. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py CHANGED
@@ -19,6 +19,8 @@ from cache_dit.cache_factory import ParamsModifier
19
19
  from cache_dit.cache_factory import ForwardPattern
20
20
  from cache_dit.cache_factory import PatchFunctor
21
21
  from cache_dit.cache_factory import BasicCacheConfig
22
+ from cache_dit.cache_factory import DBCacheConfig
23
+ from cache_dit.cache_factory import DBPruneConfig
22
24
  from cache_dit.cache_factory import CalibratorConfig
23
25
  from cache_dit.cache_factory import TaylorSeerCalibratorConfig
24
26
  from cache_dit.cache_factory import FoCaCalibratorConfig
@@ -30,6 +32,7 @@ from cache_dit.quantize import quantize
30
32
 
31
33
  NONE = CacheType.NONE
32
34
  DBCache = CacheType.DBCache
35
+ DBPrune = CacheType.DBPrune
33
36
 
34
37
  Pattern_0 = ForwardPattern.Pattern_0
35
38
  Pattern_1 = ForwardPattern.Pattern_1
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '1.0.2'
32
- __version_tuple__ = version_tuple = (1, 0, 2)
31
+ __version__ = version = '1.0.4'
32
+ __version_tuple__ = version_tuple = (1, 0, 4)
33
33
 
34
34
  __commit_id__ = commit_id = None
cache_dit/cache_factory/__init__.py CHANGED
@@ -9,14 +9,21 @@ from cache_dit.cache_factory.patch_functors import PatchFunctor
9
9
  from cache_dit.cache_factory.block_adapters import BlockAdapter
10
10
  from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
11
11
 
12
- from cache_dit.cache_factory.cache_contexts import CachedContext
13
12
  from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
13
+ from cache_dit.cache_factory.cache_contexts import DBCacheConfig
14
+ from cache_dit.cache_factory.cache_contexts import CachedContext
14
15
  from cache_dit.cache_factory.cache_contexts import CachedContextManager
16
+ from cache_dit.cache_factory.cache_contexts import DBPruneConfig
17
+ from cache_dit.cache_factory.cache_contexts import PrunedContext
18
+ from cache_dit.cache_factory.cache_contexts import PrunedContextManager
19
+ from cache_dit.cache_factory.cache_contexts import ContextManager
15
20
  from cache_dit.cache_factory.cache_contexts import CalibratorConfig
16
21
  from cache_dit.cache_factory.cache_contexts import TaylorSeerCalibratorConfig
17
22
  from cache_dit.cache_factory.cache_contexts import FoCaCalibratorConfig
18
23
 
19
24
  from cache_dit.cache_factory.cache_blocks import CachedBlocks
25
+ from cache_dit.cache_factory.cache_blocks import PrunedBlocks
26
+ from cache_dit.cache_factory.cache_blocks import UnifiedBlocks
20
27
 
21
28
  from cache_dit.cache_factory.cache_adapters import CachedAdapter
22
29
 
cache_dit/cache_factory/cache_adapters/cache_adapter.py CHANGED
@@ -10,10 +10,10 @@ from cache_dit.cache_factory.cache_types import CacheType
10
10
  from cache_dit.cache_factory.block_adapters import BlockAdapter
11
11
  from cache_dit.cache_factory.block_adapters import ParamsModifier
12
12
  from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
13
- from cache_dit.cache_factory.cache_contexts import CachedContextManager
13
+ from cache_dit.cache_factory.cache_contexts import ContextManager
14
14
  from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
15
15
  from cache_dit.cache_factory.cache_contexts import CalibratorConfig
16
- from cache_dit.cache_factory.cache_blocks import CachedBlocks
16
+ from cache_dit.cache_factory.cache_blocks import UnifiedBlocks
17
17
  from cache_dit.logger import init_logger
18
18
 
19
19
  logger = init_logger(__name__)
@@ -32,7 +32,7 @@ class CachedAdapter:
32
32
  DiffusionPipeline,
33
33
  BlockAdapter,
34
34
  ],
35
- **cache_context_kwargs,
35
+ **context_kwargs,
36
36
  ) -> Union[
37
37
  DiffusionPipeline,
38
38
  BlockAdapter,
@@ -51,7 +51,7 @@ class CachedAdapter:
51
51
  block_adapter = BlockAdapterRegistry.get_adapter(
52
52
  pipe_or_adapter
53
53
  )
54
- if params_modifiers := cache_context_kwargs.pop(
54
+ if params_modifiers := context_kwargs.pop(
55
55
  "params_modifiers",
56
56
  None,
57
57
  ):
@@ -59,7 +59,7 @@ class CachedAdapter:
59
59
 
60
60
  return cls.cachify(
61
61
  block_adapter,
62
- **cache_context_kwargs,
62
+ **context_kwargs,
63
63
  ).pipe
64
64
  else:
65
65
  raise ValueError(
@@ -72,21 +72,21 @@ class CachedAdapter:
72
72
  "Adapting Cache Acceleration using custom BlockAdapter!"
73
73
  )
74
74
  if pipe_or_adapter.params_modifiers is None:
75
- if params_modifiers := cache_context_kwargs.pop(
75
+ if params_modifiers := context_kwargs.pop(
76
76
  "params_modifiers", None
77
77
  ):
78
78
  pipe_or_adapter.params_modifiers = params_modifiers
79
79
 
80
80
  return cls.cachify(
81
81
  pipe_or_adapter,
82
- **cache_context_kwargs,
82
+ **context_kwargs,
83
83
  )
84
84
 
85
85
  @classmethod
86
86
  def cachify(
87
87
  cls,
88
88
  block_adapter: BlockAdapter,
89
- **cache_context_kwargs,
89
+ **context_kwargs,
90
90
  ) -> BlockAdapter:
91
91
 
92
92
  if block_adapter.auto:
@@ -103,14 +103,15 @@ class CachedAdapter:
103
103
 
104
104
  # 1. Apply cache on pipeline: wrap cache context, must
105
105
  # call create_context before mock_blocks.
106
- cls.create_context(
106
+ _, contexts_kwargs = cls.create_context(
107
107
  block_adapter,
108
- **cache_context_kwargs,
108
+ **context_kwargs,
109
109
  )
110
110
 
111
111
  # 2. Apply cache on transformer: mock cached blocks
112
112
  cls.mock_blocks(
113
113
  block_adapter,
114
+ contexts_kwargs,
114
115
  )
115
116
 
116
117
  return block_adapter
@@ -119,12 +120,10 @@ class CachedAdapter:
119
120
  def check_context_kwargs(
120
121
  cls,
121
122
  block_adapter: BlockAdapter,
122
- **cache_context_kwargs,
123
+ **context_kwargs,
123
124
  ):
124
- # Check cache_context_kwargs
125
- cache_config: BasicCacheConfig = cache_context_kwargs[
126
- "cache_config"
127
- ] # ref
125
+ # Check context_kwargs
126
+ cache_config: BasicCacheConfig = context_kwargs["cache_config"] # ref
128
127
  assert cache_config is not None, "cache_config can not be None."
129
128
  if cache_config.enable_separate_cfg is None:
130
129
  # Check cfg for some specific case if users don't set it as True
@@ -150,19 +149,23 @@ class CachedAdapter:
150
149
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
151
150
  )
152
151
 
153
- cache_type = cache_context_kwargs.pop("cache_type", None)
152
+ cache_type = context_kwargs.pop("cache_type", None)
154
153
  if cache_type is not None:
155
- assert (
156
- cache_type == CacheType.DBCache
157
- ), "Custom cache setting only support for DBCache now!"
154
+ assert isinstance(
155
+ cache_type, CacheType
156
+ ), f"cache_type must be CacheType Enum, but got {type(cache_type)}."
157
+ assert cache_type == cache_config.cache_type, (
158
+ f"cache_type from context_kwargs ({cache_type}) must be the same "
159
+ f"as that from cache_config ({cache_config.cache_type})."
160
+ )
158
161
 
159
- return cache_context_kwargs
162
+ return context_kwargs
160
163
 
161
164
  @classmethod
162
165
  def create_context(
163
166
  cls,
164
167
  block_adapter: BlockAdapter,
165
- **cache_context_kwargs,
168
+ **context_kwargs,
166
169
  ) -> Tuple[List[str], List[Dict[str, Any]]]:
167
170
 
168
171
  BlockAdapter.assert_normalized(block_adapter)
@@ -170,9 +173,9 @@ class CachedAdapter:
170
173
  if BlockAdapter.is_cached(block_adapter.pipe):
171
174
  return block_adapter.pipe
172
175
 
173
- # Check cache_context_kwargs
174
- cache_context_kwargs = cls.check_context_kwargs(
175
- block_adapter, **cache_context_kwargs
176
+ # Check context_kwargs
177
+ context_kwargs = cls.check_context_kwargs(
178
+ block_adapter, **context_kwargs
176
179
  )
177
180
  # Apply cache on pipeline: wrap cache context
178
181
  pipe_cls_name = block_adapter.pipe.__class__.__name__
@@ -181,13 +184,18 @@ class CachedAdapter:
181
184
  # Different transformers (Wan2.2, etc) should shared the same
182
185
  # cache manager but with different cache context (according
183
186
  # to their unique instance id).
184
- cache_manager = CachedContextManager(
187
+ cache_config: BasicCacheConfig = context_kwargs.get(
188
+ "cache_config", None
189
+ )
190
+ assert cache_config is not None, "cache_config can not be None."
191
+ context_manager = ContextManager(
185
192
  name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
193
+ cache_type=cache_config.cache_type,
186
194
  )
187
- block_adapter.pipe._cache_manager = cache_manager # instance level
195
+ block_adapter.pipe._context_manager = context_manager # instance level
188
196
 
189
197
  flatten_contexts, contexts_kwargs = cls.modify_context_params(
190
- block_adapter, **cache_context_kwargs
198
+ block_adapter, **context_kwargs
191
199
  )
192
200
 
193
201
  original_call = block_adapter.pipe.__class__.__call__
@@ -200,8 +208,8 @@ class CachedAdapter:
200
208
  flatten_contexts, contexts_kwargs
201
209
  ):
202
210
  stack.enter_context(
203
- cache_manager.enter_context(
204
- cache_manager.reset_context(
211
+ context_manager.enter_context(
212
+ context_manager.reset_context(
205
213
  context_name,
206
214
  **context_kwargs,
207
215
  ),
@@ -223,14 +231,14 @@ class CachedAdapter:
223
231
  def modify_context_params(
224
232
  cls,
225
233
  block_adapter: BlockAdapter,
226
- **cache_context_kwargs,
234
+ **context_kwargs,
227
235
  ) -> Tuple[List[str], List[Dict[str, Any]]]:
228
236
 
229
237
  flatten_contexts = BlockAdapter.flatten(
230
238
  block_adapter.unique_blocks_name
231
239
  )
232
240
  contexts_kwargs = [
233
- cache_context_kwargs.copy()
241
+ context_kwargs.copy()
234
242
  for _ in range(
235
243
  len(flatten_contexts),
236
244
  )
@@ -267,7 +275,7 @@ class CachedAdapter:
267
275
  "calibrator_config", None
268
276
  )
269
277
  if cache_config is not None:
270
- message = f"Collected Cache Config: {cache_config.strify()}"
278
+ message = f"Collected Context Config: {cache_config.strify()}"
271
279
  if calibrator_config is not None:
272
280
  message += f", Calibrator Config: {calibrator_config.strify(details=True)}"
273
281
  else:
@@ -278,6 +286,7 @@ class CachedAdapter:
278
286
  def mock_blocks(
279
287
  cls,
280
288
  block_adapter: BlockAdapter,
289
+ contexts_kwargs: List[Dict],
281
290
  ) -> List[torch.nn.Module]:
282
291
 
283
292
  BlockAdapter.assert_normalized(block_adapter)
@@ -287,20 +296,23 @@ class CachedAdapter:
287
296
 
288
297
  # Apply cache on transformer: mock cached transformer blocks
289
298
  for (
290
- cached_blocks,
299
+ unified_blocks,
291
300
  transformer,
292
301
  blocks_name,
293
302
  unique_blocks_name,
294
303
  dummy_blocks_names,
295
304
  ) in zip(
296
- cls.collect_cached_blocks(block_adapter),
305
+ cls.collect_unified_blocks(
306
+ block_adapter,
307
+ contexts_kwargs,
308
+ ),
297
309
  block_adapter.transformer,
298
310
  block_adapter.blocks_name,
299
311
  block_adapter.unique_blocks_name,
300
312
  block_adapter.dummy_blocks_names,
301
313
  ):
302
314
  cls.mock_transformer(
303
- cached_blocks,
315
+ unified_blocks,
304
316
  transformer,
305
317
  blocks_name,
306
318
  unique_blocks_name,
@@ -312,7 +324,7 @@ class CachedAdapter:
312
324
  @classmethod
313
325
  def mock_transformer(
314
326
  cls,
315
- cached_blocks: Dict[str, torch.nn.ModuleList],
327
+ unified_blocks: Dict[str, torch.nn.ModuleList],
316
328
  transformer: torch.nn.Module,
317
329
  blocks_name: List[str],
318
330
  unique_blocks_name: List[str],
@@ -352,7 +364,7 @@ class CachedAdapter:
352
364
  ):
353
365
  stack.enter_context(
354
366
  unittest.mock.patch.object(
355
- self, name, cached_blocks[context_name]
367
+ self, name, unified_blocks[context_name]
356
368
  )
357
369
  )
358
370
  for dummy_name in dummy_blocks_names:
@@ -388,46 +400,51 @@ class CachedAdapter:
388
400
  return transformer
389
401
 
390
402
  @classmethod
391
- def collect_cached_blocks(
403
+ def collect_unified_blocks(
392
404
  cls,
393
405
  block_adapter: BlockAdapter,
406
+ contexts_kwargs: List[Dict],
394
407
  ) -> List[Dict[str, torch.nn.ModuleList]]:
395
408
 
396
409
  BlockAdapter.assert_normalized(block_adapter)
397
410
 
398
411
  total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
399
- assert hasattr(block_adapter.pipe, "_cache_manager")
412
+ assert hasattr(block_adapter.pipe, "_context_manager")
400
413
  assert isinstance(
401
- block_adapter.pipe._cache_manager,
402
- CachedContextManager,
414
+ block_adapter.pipe._context_manager,
415
+ ContextManager._supported_managers,
403
416
  )
404
417
 
405
418
  for i in range(len(block_adapter.transformer)):
406
419
 
407
- cached_blocks_bind_context = {}
420
+ unified_blocks_bind_context = {}
408
421
  for j in range(len(block_adapter.blocks[i])):
409
- cached_blocks_bind_context[
422
+ cache_config: BasicCacheConfig = contexts_kwargs[
423
+ i * len(block_adapter.blocks[i]) + j
424
+ ]["cache_config"]
425
+ unified_blocks_bind_context[
410
426
  block_adapter.unique_blocks_name[i][j]
411
427
  ] = torch.nn.ModuleList(
412
428
  [
413
- CachedBlocks(
429
+ UnifiedBlocks(
414
430
  # 0. Transformer blocks configuration
415
431
  block_adapter.blocks[i][j],
416
432
  transformer=block_adapter.transformer[i],
417
433
  forward_pattern=block_adapter.forward_pattern[i][j],
418
434
  check_forward_pattern=block_adapter.check_forward_pattern,
419
435
  check_num_outputs=block_adapter.check_num_outputs,
420
- # 1. Cache context configuration
436
+ # 1. Cache/Prune context configuration
421
437
  cache_prefix=block_adapter.blocks_name[i][j],
422
438
  cache_context=block_adapter.unique_blocks_name[i][
423
439
  j
424
440
  ],
425
- cache_manager=block_adapter.pipe._cache_manager,
441
+ context_manager=block_adapter.pipe._context_manager,
442
+ cache_type=cache_config.cache_type,
426
443
  )
427
444
  ]
428
445
  )
429
446
 
430
- total_cached_blocks.append(cached_blocks_bind_context)
447
+ total_cached_blocks.append(unified_blocks_bind_context)
431
448
 
432
449
  return total_cached_blocks
433
450
 
@@ -437,7 +454,7 @@ class CachedAdapter:
437
454
  block_adapter: BlockAdapter,
438
455
  contexts_kwargs: List[Dict],
439
456
  ):
440
- block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
457
+ block_adapter.pipe._context_kwargs = contexts_kwargs[0]
441
458
 
442
459
  params_shift = 0
443
460
  for i in range(len(block_adapter.transformer)):
@@ -448,16 +465,14 @@ class CachedAdapter:
448
465
  block_adapter.transformer[i]._has_separate_cfg = (
449
466
  block_adapter.has_separate_cfg
450
467
  )
451
- block_adapter.transformer[i]._cache_context_kwargs = (
452
- contexts_kwargs[params_shift]
453
- )
468
+ block_adapter.transformer[i]._context_kwargs = contexts_kwargs[
469
+ params_shift
470
+ ]
454
471
 
455
472
  blocks = block_adapter.blocks[i]
456
473
  for j in range(len(blocks)):
457
474
  blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
458
- blocks[j]._cache_context_kwargs = contexts_kwargs[
459
- params_shift + j
460
- ]
475
+ blocks[j]._context_kwargs = contexts_kwargs[params_shift + j]
461
476
 
462
477
  params_shift += len(blocks)
463
478
 
@@ -467,25 +482,25 @@ class CachedAdapter:
467
482
  block_adapter: BlockAdapter,
468
483
  ):
469
484
  from cache_dit.cache_factory.cache_blocks import (
470
- patch_cached_stats,
485
+ apply_stats,
471
486
  )
472
487
 
473
- cache_manager = block_adapter.pipe._cache_manager
488
+ context_manager = block_adapter.pipe._context_manager
474
489
 
475
490
  for i in range(len(block_adapter.transformer)):
476
- patch_cached_stats(
491
+ apply_stats(
477
492
  block_adapter.transformer[i],
478
493
  cache_context=block_adapter.unique_blocks_name[i][-1],
479
- cache_manager=cache_manager,
494
+ context_manager=context_manager,
480
495
  )
481
496
  for blocks, unique_name in zip(
482
497
  block_adapter.blocks[i],
483
498
  block_adapter.unique_blocks_name[i],
484
499
  ):
485
- patch_cached_stats(
500
+ apply_stats(
486
501
  blocks,
487
502
  cache_context=unique_name,
488
- cache_manager=cache_manager,
503
+ context_manager=context_manager,
489
504
  )
490
505
 
491
506
  @classmethod
@@ -513,11 +528,13 @@ class CachedAdapter:
513
528
  original_call = pipe.__class__._original_call
514
529
  pipe.__class__.__call__ = original_call
515
530
  del pipe.__class__._original_call
516
- if hasattr(pipe, "_cache_manager"):
517
- cache_manager = pipe._cache_manager
518
- if isinstance(cache_manager, CachedContextManager):
519
- cache_manager.clear_contexts()
520
- del pipe._cache_manager
531
+ if hasattr(pipe, "_context_manager"):
532
+ context_manager = pipe._context_manager
533
+ if isinstance(
534
+ context_manager, ContextManager._supported_managers
535
+ ):
536
+ context_manager.clear_contexts()
537
+ del pipe._context_manager
521
538
  if hasattr(pipe, "_is_cached"):
522
539
  del pipe.__class__._is_cached
523
540
 
@@ -532,22 +549,22 @@ class CachedAdapter:
532
549
  def _release_blocks_params(blocks):
533
550
  if hasattr(blocks, "_forward_pattern"):
534
551
  del blocks._forward_pattern
535
- if hasattr(blocks, "_cache_context_kwargs"):
536
- del blocks._cache_context_kwargs
552
+ if hasattr(blocks, "_context_kwargs"):
553
+ del blocks._context_kwargs
537
554
 
538
555
  def _release_transformer_params(transformer):
539
556
  if hasattr(transformer, "_forward_pattern"):
540
557
  del transformer._forward_pattern
541
558
  if hasattr(transformer, "_has_separate_cfg"):
542
559
  del transformer._has_separate_cfg
543
- if hasattr(transformer, "_cache_context_kwargs"):
544
- del transformer._cache_context_kwargs
560
+ if hasattr(transformer, "_context_kwargs"):
561
+ del transformer._context_kwargs
545
562
  for blocks in BlockAdapter.find_blocks(transformer):
546
563
  _release_blocks_params(blocks)
547
564
 
548
565
  def _release_pipeline_params(pipe):
549
- if hasattr(pipe, "_cache_context_kwargs"):
550
- del pipe._cache_context_kwargs
566
+ if hasattr(pipe, "_context_kwargs"):
567
+ del pipe._context_kwargs
551
568
 
552
569
  cls.release_hooks(
553
570
  pipe_or_adapter,
@@ -558,14 +575,11 @@ class CachedAdapter:
558
575
 
559
576
  # release stats hooks
560
577
  from cache_dit.cache_factory.cache_blocks import (
561
- remove_cached_stats,
578
+ remove_stats,
562
579
  )
563
580
 
564
581
  cls.release_hooks(
565
- pipe_or_adapter,
566
- remove_cached_stats,
567
- remove_cached_stats,
568
- remove_cached_stats,
582
+ pipe_or_adapter, remove_stats, remove_stats, remove_stats
569
583
  )
570
584
 
571
585
  @classmethod