cache_dit-0.2.29-py3-none-any.whl → cache_dit-0.2.30-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.

cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.2.29'
32
- __version_tuple__ = version_tuple = (0, 2, 29)
31
+ __version__ = version = '0.2.30'
32
+ __version_tuple__ = version_tuple = (0, 2, 30)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -123,6 +123,7 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
123
123
  assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
124
124
  return BlockAdapter(
125
125
  pipe=pipe,
126
+ transformer=pipe.transformer,
126
127
  blocks=[
127
128
  pipe.transformer.transformer_blocks,
128
129
  pipe.transformer.single_transformer_blocks,
@@ -131,6 +132,8 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
131
132
  ForwardPattern.Pattern_0,
132
133
  ForwardPattern.Pattern_0,
133
134
  ],
135
+ # The type hint in diffusers is wrong
136
+ check_num_outputs=False,
134
137
  **kwargs,
135
138
  )
136
139
 
@@ -327,37 +330,6 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
327
330
  )
328
331
 
329
332
 
330
- @BlockAdapterRegistry.register("HunyuanDiT")
331
- def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
332
- from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
333
-
334
- assert isinstance(
335
- pipe.transformer,
336
- (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
337
- )
338
- return BlockAdapter(
339
- pipe=pipe,
340
- transformer=pipe.transformer,
341
- blocks=pipe.transformer.blocks,
342
- forward_pattern=ForwardPattern.Pattern_3,
343
- **kwargs,
344
- )
345
-
346
-
347
- @BlockAdapterRegistry.register("HunyuanDiTPAG")
348
- def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
349
- from diffusers import HunyuanDiT2DModel
350
-
351
- assert isinstance(pipe.transformer, HunyuanDiT2DModel)
352
- return BlockAdapter(
353
- pipe=pipe,
354
- transformer=pipe.transformer,
355
- blocks=pipe.transformer.blocks,
356
- forward_pattern=ForwardPattern.Pattern_3,
357
- **kwargs,
358
- )
359
-
360
-
361
333
  @BlockAdapterRegistry.register("Lumina")
362
334
  def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
363
335
  from diffusers import LuminaNextDiT2DModel
@@ -414,10 +386,12 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
414
386
  )
415
387
 
416
388
 
417
- @BlockAdapterRegistry.register("Sana")
389
+ @BlockAdapterRegistry.register("Sana", supported=False)
418
390
  def sana_adapter(pipe, **kwargs) -> BlockAdapter:
419
391
  from diffusers import SanaTransformer2DModel
420
392
 
393
+ # TODO: fix -> got multiple values for argument 'encoder_hidden_states'
394
+
421
395
  assert isinstance(pipe.transformer, SanaTransformer2DModel)
422
396
  return BlockAdapter(
423
397
  pipe=pipe,
@@ -428,20 +402,6 @@ def sana_adapter(pipe, **kwargs) -> BlockAdapter:
428
402
  )
429
403
 
430
404
 
431
- @BlockAdapterRegistry.register("ShapE")
432
- def shape_adapter(pipe, **kwargs) -> BlockAdapter:
433
- from diffusers import PriorTransformer
434
-
435
- assert isinstance(pipe.prior, PriorTransformer)
436
- return BlockAdapter(
437
- pipe=pipe,
438
- transformer=pipe.prior,
439
- blocks=pipe.prior.transformer_blocks,
440
- forward_pattern=ForwardPattern.Pattern_3,
441
- **kwargs,
442
- )
443
-
444
-
445
405
  @BlockAdapterRegistry.register("StableAudio")
446
406
  def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
447
407
  from diffusers import StableAudioDiTModel
@@ -459,21 +419,37 @@ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
459
419
  @BlockAdapterRegistry.register("VisualCloze")
460
420
  def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
461
421
  from diffusers import FluxTransformer2DModel
422
+ from cache_dit.utils import is_diffusers_at_least_0_3_5
462
423
 
463
424
  assert isinstance(pipe.transformer, FluxTransformer2DModel)
464
- return BlockAdapter(
465
- pipe=pipe,
466
- transformer=pipe.transformer,
467
- blocks=[
468
- pipe.transformer.transformer_blocks,
469
- pipe.transformer.single_transformer_blocks,
470
- ],
471
- forward_pattern=[
472
- ForwardPattern.Pattern_1,
473
- ForwardPattern.Pattern_3,
474
- ],
475
- **kwargs,
476
- )
425
+ if is_diffusers_at_least_0_3_5():
426
+ return BlockAdapter(
427
+ pipe=pipe,
428
+ transformer=pipe.transformer,
429
+ blocks=[
430
+ pipe.transformer.transformer_blocks,
431
+ pipe.transformer.single_transformer_blocks,
432
+ ],
433
+ forward_pattern=[
434
+ ForwardPattern.Pattern_1,
435
+ ForwardPattern.Pattern_1,
436
+ ],
437
+ **kwargs,
438
+ )
439
+ else:
440
+ return BlockAdapter(
441
+ pipe=pipe,
442
+ transformer=pipe.transformer,
443
+ blocks=[
444
+ pipe.transformer.transformer_blocks,
445
+ pipe.transformer.single_transformer_blocks,
446
+ ],
447
+ forward_pattern=[
448
+ ForwardPattern.Pattern_1,
449
+ ForwardPattern.Pattern_3,
450
+ ],
451
+ **kwargs,
452
+ )
477
453
 
478
454
 
479
455
  @BlockAdapterRegistry.register("AuraFlow")
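
The VisualCloze adapter now chooses its forward patterns based on the installed diffusers version, using the `is_diffusers_at_least_0_3_5` helper imported from `cache_dit.utils`: on newer diffusers both block lists use `Pattern_1`, while older releases keep `Pattern_3` for the single blocks. The helper's implementation is not part of this diff; a minimal sketch of such a version gate, assuming the name refers to diffusers >= 0.35.0, could look like:

```python
# Hypothetical sketch only -- the real helper lives in cache_dit.utils and may differ.
from importlib.metadata import version

from packaging.version import Version


def is_diffusers_at_least_0_3_5() -> bool:
    # Assumption: "0_3_5" means diffusers 0.35.0 or newer.
    return Version(version("diffusers")) >= Version("0.35.0")
```
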
@@ -511,8 +487,27 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
511
487
  )
512
488
 
513
489
 
514
- @BlockAdapterRegistry.register("HiDream")
490
+ @BlockAdapterRegistry.register("ShapE")
491
+ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
492
+ from diffusers import PriorTransformer
493
+
494
+ assert isinstance(pipe.prior, PriorTransformer)
495
+ return BlockAdapter(
496
+ pipe=pipe,
497
+ transformer=pipe.prior,
498
+ blocks=pipe.prior.transformer_blocks,
499
+ forward_pattern=ForwardPattern.Pattern_3,
500
+ **kwargs,
501
+ )
502
+
503
+
504
+ @BlockAdapterRegistry.register("HiDream", supported=True)
515
505
  def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
506
+ # NOTE: Need to patch Transformer forward to fully support
507
+ # double_stream_blocks and single_stream_blocks, namely, need
508
+ # to remove the logics inside the blocks forward loop:
509
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
510
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
516
511
  from diffusers import HiDreamImageTransformer2DModel
517
512
 
518
513
  assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
@@ -520,13 +515,47 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
520
515
  pipe=pipe,
521
516
  transformer=pipe.transformer,
522
517
  blocks=[
523
- pipe.transformer.double_stream_blocks,
518
+ # pipe.transformer.double_stream_blocks,
524
519
  pipe.transformer.single_stream_blocks,
525
520
  ],
526
521
  forward_pattern=[
527
- ForwardPattern.Pattern_4,
522
+ # ForwardPattern.Pattern_4,
528
523
  ForwardPattern.Pattern_3,
529
524
  ],
525
+ # The type hint in diffusers is wrong
530
526
  check_num_outputs=False,
531
527
  **kwargs,
532
528
  )
529
+
530
+
531
+ @BlockAdapterRegistry.register("HunyuanDiT", supported=False)
532
+ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
533
+ # TODO: Patch Transformer forward
534
+ from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
535
+
536
+ assert isinstance(
537
+ pipe.transformer,
538
+ (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
539
+ )
540
+ return BlockAdapter(
541
+ pipe=pipe,
542
+ transformer=pipe.transformer,
543
+ blocks=pipe.transformer.blocks,
544
+ forward_pattern=ForwardPattern.Pattern_3,
545
+ **kwargs,
546
+ )
547
+
548
+
549
+ @BlockAdapterRegistry.register("HunyuanDiTPAG", supported=False)
550
+ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
551
+ # TODO: Patch Transformer forward
552
+ from diffusers import HunyuanDiT2DModel
553
+
554
+ assert isinstance(pipe.transformer, HunyuanDiT2DModel)
555
+ return BlockAdapter(
556
+ pipe=pipe,
557
+ transformer=pipe.transformer,
558
+ blocks=pipe.transformer.blocks,
559
+ forward_pattern=ForwardPattern.Pattern_3,
560
+ **kwargs,
561
+ )
@@ -75,7 +75,7 @@ class BlockAdapter:
75
75
  List[List[ParamsModifier]],
76
76
  ] = None
77
77
 
78
- check_num_outputs: bool = True
78
+ check_num_outputs: bool = False
79
79
 
80
80
  # Pipeline Level Flags
81
81
  # Patch Functor: Flux, etc.
@@ -111,9 +111,9 @@ class BlockAdapter:
111
111
  def __post_init__(self):
112
112
  if self.skip_post_init:
113
113
  return
114
- assert any((self.pipe is not None, self.transformer is not None))
115
- self.maybe_fill_attrs()
116
- self.maybe_patchify()
114
+ if any((self.pipe is not None, self.transformer is not None)):
115
+ self.maybe_fill_attrs()
116
+ self.maybe_patchify()
117
117
 
118
118
  def maybe_fill_attrs(self):
119
119
  # NOTE: This func should be call before normalize.
@@ -130,7 +130,9 @@ class BlockAdapter:
130
130
  assert isinstance(blocks, torch.nn.ModuleList)
131
131
  blocks_name = None
132
132
  for attr_name in attr_names:
133
- if attr := getattr(transformer, attr_name, None):
133
+ if (
134
+ attr := getattr(transformer, attr_name, None)
135
+ ) is not None:
134
136
  if isinstance(attr, torch.nn.ModuleList) and id(
135
137
  attr
136
138
  ) == id(blocks):
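
The condition is rewritten from a truthiness test to an explicit `is not None` check because an empty `torch.nn.ModuleList` is falsy: an attribute that exists but currently holds zero blocks would otherwise be skipped. A minimal illustration of the difference:

```python
import torch

blocks = torch.nn.ModuleList()  # a real attribute that happens to be empty

if blocks:
    print("truthiness check passes")      # never reached: len(blocks) == 0
if blocks is not None:
    print("`is not None` check passes")   # reached: the attribute does exist
```
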
@@ -558,7 +560,7 @@ class BlockAdapter:
558
560
  assert isinstance(adapter[0], torch.nn.Module)
559
561
  return getattr(adapter[0], "_is_cached", False)
560
562
  else:
561
- raise TypeError(f"Can't check this type: {adapter}!")
563
+ raise TypeError(f"Can't check this type: {type(adapter)}!")
562
564
 
563
565
  @classmethod
564
566
  def nested_depth(cls, obj: Any):
@@ -1,4 +1,4 @@
1
- from typing import Any, Tuple, List, Dict
1
+ from typing import Any, Tuple, List, Dict, Callable
2
2
 
3
3
  from diffusers import DiffusionPipeline
4
4
  from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
@@ -9,20 +9,23 @@ logger = init_logger(__name__)
9
9
 
10
10
 
11
11
  class BlockAdapterRegistry:
12
- _adapters: Dict[str, BlockAdapter] = {}
13
- _predefined_adapters_has_spearate_cfg: List[str] = {
12
+ _adapters: Dict[str, Callable[..., BlockAdapter]] = {}
13
+ _predefined_adapters_has_spearate_cfg: List[str] = [
14
14
  "QwenImage",
15
15
  "Wan",
16
16
  "CogView4",
17
17
  "Cosmos",
18
18
  "SkyReelsV2",
19
19
  "Chroma",
20
- }
20
+ ]
21
21
 
22
22
  @classmethod
23
- def register(cls, name):
24
- def decorator(func):
25
- cls._adapters[name] = func
23
+ def register(cls, name: str, supported: bool = True):
24
+ def decorator(
25
+ func: Callable[..., BlockAdapter]
26
+ ) -> Callable[..., BlockAdapter]:
27
+ if supported:
28
+ cls._adapters[name] = func
26
29
  return func
27
30
 
28
31
  return decorator
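
With the new `supported` flag, `BlockAdapterRegistry.register` can keep an adapter function in the codebase while leaving it out of the active registry, which is how the Sana, HunyuanDiT and HunyuanDiTPAG adapters are parked earlier in this diff. A standalone sketch of the same decorator pattern (not the cache-dit class itself), assuming a plain dict-backed registry:

```python
from typing import Any, Callable, Dict

_ADAPTERS: Dict[str, Callable[..., Any]] = {}


def register(name: str, supported: bool = True):
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        if supported:
            _ADAPTERS[name] = func  # only supported adapters become discoverable
        return func  # the function itself stays importable either way

    return decorator


@register("Sana", supported=False)  # defined, but hidden from the registry
def sana_adapter(pipe, **kwargs):
    ...


assert "Sana" not in _ADAPTERS
```
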
@@ -4,7 +4,7 @@ import unittest
4
4
  import functools
5
5
 
6
6
  from contextlib import ExitStack
7
- from typing import Dict, List, Tuple, Any
7
+ from typing import Dict, List, Tuple, Any, Union, Callable
8
8
 
9
9
  from diffusers import DiffusionPipeline
10
10
 
@@ -14,7 +14,10 @@ from cache_dit.cache_factory import ParamsModifier
14
14
  from cache_dit.cache_factory import BlockAdapterRegistry
15
15
  from cache_dit.cache_factory import CachedContextManager
16
16
  from cache_dit.cache_factory import CachedBlocks
17
-
17
+ from cache_dit.cache_factory.cache_blocks.utils import (
18
+ patch_cached_stats,
19
+ remove_cached_stats,
20
+ )
18
21
  from cache_dit.logger import init_logger
19
22
 
20
23
  logger = init_logger(__name__)
@@ -29,9 +32,15 @@ class CachedAdapter:
29
32
  @classmethod
30
33
  def apply(
31
34
  cls,
32
- pipe_or_adapter: DiffusionPipeline | BlockAdapter,
35
+ pipe_or_adapter: Union[
36
+ DiffusionPipeline,
37
+ BlockAdapter,
38
+ ],
33
39
  **cache_context_kwargs,
34
- ) -> BlockAdapter:
40
+ ) -> Union[
41
+ DiffusionPipeline,
42
+ BlockAdapter,
43
+ ]:
35
44
  assert (
36
45
  pipe_or_adapter is not None
37
46
  ), "pipe or block_adapter can not both None!"
@@ -49,7 +58,7 @@ class CachedAdapter:
49
58
  return cls.cachify(
50
59
  block_adapter,
51
60
  **cache_context_kwargs,
52
- )
61
+ ).pipe
53
62
  else:
54
63
  raise ValueError(
55
64
  f"{pipe_or_adapter.__class__.__name__} is not officially supported "
@@ -82,7 +91,7 @@ class CachedAdapter:
82
91
  # 0. Must normalize block_adapter before apply cache
83
92
  block_adapter = BlockAdapter.normalize(block_adapter)
84
93
  if BlockAdapter.is_cached(block_adapter):
85
- return block_adapter.pipe
94
+ return block_adapter
86
95
 
87
96
  # 1. Apply cache on pipeline: wrap cache context, must
88
97
  # call create_context before mock_blocks.
@@ -98,36 +107,6 @@ class CachedAdapter:
98
107
 
99
108
  return block_adapter
100
109
 
101
- @classmethod
102
- def patch_params(
103
- cls,
104
- block_adapter: BlockAdapter,
105
- contexts_kwargs: List[Dict],
106
- ):
107
- block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
108
-
109
- params_shift = 0
110
- for i in range(len(block_adapter.transformer)):
111
-
112
- block_adapter.transformer[i]._forward_pattern = (
113
- block_adapter.forward_pattern
114
- )
115
- block_adapter.transformer[i]._has_separate_cfg = (
116
- block_adapter.has_separate_cfg
117
- )
118
- block_adapter.transformer[i]._cache_context_kwargs = (
119
- contexts_kwargs[params_shift]
120
- )
121
-
122
- blocks = block_adapter.blocks[i]
123
- for j in range(len(blocks)):
124
- blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
125
- blocks[j]._cache_context_kwargs = contexts_kwargs[
126
- params_shift + j
127
- ]
128
-
129
- params_shift += len(blocks)
130
-
131
110
  @classmethod
132
111
  def check_context_kwargs(
133
112
  cls,
@@ -153,7 +132,9 @@ class CachedAdapter:
153
132
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
154
133
  )
155
134
 
156
- if cache_type := cache_context_kwargs.pop("cache_type", None):
135
+ if (
136
+ cache_type := cache_context_kwargs.pop("cache_type", None)
137
+ ) is not None:
157
138
  assert (
158
139
  cache_type == CacheType.DBCache
159
140
  ), "Custom cache setting only support for DBCache now!"
@@ -210,14 +191,14 @@ class CachedAdapter:
210
191
  )
211
192
  )
212
193
  outputs = original_call(self, *args, **kwargs)
213
- cls.patch_stats(block_adapter)
194
+ cls.apply_stats_hooks(block_adapter)
214
195
  return outputs
215
196
 
216
197
  block_adapter.pipe.__class__.__call__ = new_call
217
198
  block_adapter.pipe.__class__._original_call = original_call
218
199
  block_adapter.pipe.__class__._is_cached = True
219
200
 
220
- cls.patch_params(block_adapter, contexts_kwargs)
201
+ cls.apply_params_hooks(block_adapter, contexts_kwargs)
221
202
 
222
203
  return block_adapter.pipe
223
204
 
@@ -261,33 +242,6 @@ class CachedAdapter:
261
242
 
262
243
  return flatten_contexts, contexts_kwargs
263
244
 
264
- @classmethod
265
- def patch_stats(
266
- cls,
267
- block_adapter: BlockAdapter,
268
- ):
269
- from cache_dit.cache_factory.cache_blocks.utils import (
270
- patch_cached_stats,
271
- )
272
-
273
- cache_manager = block_adapter.pipe._cache_manager
274
-
275
- for i in range(len(block_adapter.transformer)):
276
- patch_cached_stats(
277
- block_adapter.transformer[i],
278
- cache_context=block_adapter.unique_blocks_name[i][-1],
279
- cache_manager=cache_manager,
280
- )
281
- for blocks, unique_name in zip(
282
- block_adapter.blocks[i],
283
- block_adapter.unique_blocks_name[i],
284
- ):
285
- patch_cached_stats(
286
- blocks,
287
- cache_context=unique_name,
288
- cache_manager=cache_manager,
289
- )
290
-
291
245
  @classmethod
292
246
  def mock_blocks(
293
247
  cls,
@@ -405,3 +359,159 @@ class CachedAdapter:
405
359
  total_cached_blocks.append(cached_blocks_bind_context)
406
360
 
407
361
  return total_cached_blocks
362
+
363
+ @classmethod
364
+ def apply_params_hooks(
365
+ cls,
366
+ block_adapter: BlockAdapter,
367
+ contexts_kwargs: List[Dict],
368
+ ):
369
+ block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
370
+
371
+ params_shift = 0
372
+ for i in range(len(block_adapter.transformer)):
373
+
374
+ block_adapter.transformer[i]._forward_pattern = (
375
+ block_adapter.forward_pattern
376
+ )
377
+ block_adapter.transformer[i]._has_separate_cfg = (
378
+ block_adapter.has_separate_cfg
379
+ )
380
+ block_adapter.transformer[i]._cache_context_kwargs = (
381
+ contexts_kwargs[params_shift]
382
+ )
383
+
384
+ blocks = block_adapter.blocks[i]
385
+ for j in range(len(blocks)):
386
+ blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
387
+ blocks[j]._cache_context_kwargs = contexts_kwargs[
388
+ params_shift + j
389
+ ]
390
+
391
+ params_shift += len(blocks)
392
+
393
+ @classmethod
394
+ def apply_stats_hooks(
395
+ cls,
396
+ block_adapter: BlockAdapter,
397
+ ):
398
+ cache_manager = block_adapter.pipe._cache_manager
399
+
400
+ for i in range(len(block_adapter.transformer)):
401
+ patch_cached_stats(
402
+ block_adapter.transformer[i],
403
+ cache_context=block_adapter.unique_blocks_name[i][-1],
404
+ cache_manager=cache_manager,
405
+ )
406
+ for blocks, unique_name in zip(
407
+ block_adapter.blocks[i],
408
+ block_adapter.unique_blocks_name[i],
409
+ ):
410
+ patch_cached_stats(
411
+ blocks,
412
+ cache_context=unique_name,
413
+ cache_manager=cache_manager,
414
+ )
415
+
416
+ @classmethod
417
+ def maybe_release_hooks(
418
+ cls,
419
+ pipe_or_adapter: Union[
420
+ DiffusionPipeline,
421
+ BlockAdapter,
422
+ ],
423
+ ):
424
+ # release model hooks
425
+ def _release_blocks_hooks(blocks):
426
+ return
427
+
428
+ def _release_transformer_hooks(transformer):
429
+ if hasattr(transformer, "_original_forward"):
430
+ original_forward = transformer._original_forward
431
+ transformer.forward = original_forward.__get__(transformer)
432
+ del transformer._original_forward
433
+ if hasattr(transformer, "_is_cached"):
434
+ del transformer._is_cached
435
+
436
+ def _release_pipeline_hooks(pipe):
437
+ if hasattr(pipe, "_original_call"):
438
+ original_call = pipe.__class__._original_call
439
+ pipe.__class__.__call__ = original_call
440
+ del pipe.__class__._original_call
441
+ if hasattr(pipe, "_cache_manager"):
442
+ cache_manager = pipe._cache_manager
443
+ if isinstance(cache_manager, CachedContextManager):
444
+ cache_manager.clear_contexts()
445
+ del pipe._cache_manager
446
+ if hasattr(pipe, "_is_cached"):
447
+ del pipe.__class__._is_cached
448
+
449
+ cls.release_hooks(
450
+ pipe_or_adapter,
451
+ _release_blocks_hooks,
452
+ _release_transformer_hooks,
453
+ _release_pipeline_hooks,
454
+ )
455
+
456
+ # release params hooks
457
+ def _release_blocks_params(blocks):
458
+ if hasattr(blocks, "_forward_pattern"):
459
+ del blocks._forward_pattern
460
+ if hasattr(blocks, "_cache_context_kwargs"):
461
+ del blocks._cache_context_kwargs
462
+
463
+ def _release_transformer_params(transformer):
464
+ if hasattr(transformer, "_forward_pattern"):
465
+ del transformer._forward_pattern
466
+ if hasattr(transformer, "_has_separate_cfg"):
467
+ del transformer._has_separate_cfg
468
+ if hasattr(transformer, "_cache_context_kwargs"):
469
+ del transformer._cache_context_kwargs
470
+ for blocks in BlockAdapter.find_blocks(transformer):
471
+ _release_blocks_params(blocks)
472
+
473
+ def _release_pipeline_params(pipe):
474
+ if hasattr(pipe, "_cache_context_kwargs"):
475
+ del pipe._cache_context_kwargs
476
+
477
+ cls.release_hooks(
478
+ pipe_or_adapter,
479
+ _release_blocks_params,
480
+ _release_transformer_params,
481
+ _release_pipeline_params,
482
+ )
483
+
484
+ # release stats hooks
485
+ cls.release_hooks(
486
+ pipe_or_adapter,
487
+ remove_cached_stats,
488
+ remove_cached_stats,
489
+ remove_cached_stats,
490
+ )
491
+
492
+ @classmethod
493
+ def release_hooks(
494
+ cls,
495
+ pipe_or_adapter: Union[
496
+ DiffusionPipeline,
497
+ BlockAdapter,
498
+ ],
499
+ _release_blocks: Callable,
500
+ _release_transformer: Callable,
501
+ _release_pipeline: Callable,
502
+ ):
503
+ if isinstance(pipe_or_adapter, DiffusionPipeline):
504
+ pipe = pipe_or_adapter
505
+ _release_pipeline(pipe)
506
+ if hasattr(pipe, "transformer"):
507
+ _release_transformer(pipe.transformer)
508
+ if hasattr(pipe, "transformer_2"): # Wan 2.2
509
+ _release_transformer(pipe.transformer_2)
510
+ elif isinstance(pipe_or_adapter, BlockAdapter):
511
+ adapter = pipe_or_adapter
512
+ BlockAdapter.assert_normalized(adapter)
513
+ _release_pipeline(adapter.pipe)
514
+ for transformer in BlockAdapter.flatten(adapter.transformer):
515
+ _release_transformer(transformer)
516
+ for blocks in BlockAdapter.flatten(adapter.blocks):
517
+ _release_blocks(blocks)
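
The new `apply_params_hooks` / `apply_stats_hooks` / `maybe_release_hooks` trio makes enabling and disabling caching symmetric: every hook the cache installs (the patched pipeline `__call__`, the re-bound transformer `forward`, and the bookkeeping attributes) now has a matching release step. A standalone sketch of the underlying patch-and-restore idea, independent of the cache-dit classes:

```python
import torch


class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1


def apply_hook(module: torch.nn.Module):
    # Keep the original (unbound) forward, then bind a wrapper to this instance.
    module._original_forward = module.__class__.forward

    def hooked_forward(self, x):
        return self._original_forward(self, x) * 2

    module.forward = hooked_forward.__get__(module)
    module._is_cached = True


def release_hook(module: torch.nn.Module):
    # Mirror of the release helpers above: re-bind the original and drop the flags.
    if hasattr(module, "_original_forward"):
        module.forward = module._original_forward.__get__(module)
        del module._original_forward
    if hasattr(module, "_is_cached"):
        del module._is_cached


m = Toy()
apply_hook(m)
assert m(torch.zeros(1)).item() == 2.0
release_hook(m)
assert m(torch.zeros(1)).item() == 1.0
```
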
@@ -1,5 +1,6 @@
1
1
  import torch
2
2
 
3
+ from typing import Dict, Any
3
4
  from cache_dit.cache_factory import ForwardPattern
4
5
  from cache_dit.cache_factory.cache_blocks.pattern_base import (
5
6
  CachedBlocks_Pattern_Base,
@@ -31,7 +32,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
31
32
  # Call first `n` blocks to process the hidden states for
32
33
  # more stable diff calculation.
33
34
  # encoder_hidden_states: None Pattern 3, else 4, 5
34
- hidden_states, encoder_hidden_states = self.call_Fn_blocks(
35
+ hidden_states, new_encoder_hidden_states = self.call_Fn_blocks(
35
36
  hidden_states,
36
37
  *args,
37
38
  **kwargs,
@@ -60,11 +61,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
60
61
  if can_use_cache:
61
62
  self.cache_manager.add_cached_step()
62
63
  del Fn_hidden_states_residual
63
- hidden_states, encoder_hidden_states = (
64
+ hidden_states, new_encoder_hidden_states = (
64
65
  self.cache_manager.apply_cache(
65
66
  hidden_states,
66
- # None Pattern 3, else 4, 5
67
- encoder_hidden_states,
67
+ new_encoder_hidden_states, # encoder_hidden_states not use cache
68
68
  prefix=(
69
69
  f"{self.cache_prefix}_Bn_residual"
70
70
  if self.cache_manager.is_cache_residual()
@@ -80,12 +80,12 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
80
80
  torch._dynamo.graph_break()
81
81
  # Call last `n` blocks to further process the hidden states
82
82
  # for higher precision.
83
- hidden_states, encoder_hidden_states = self.call_Bn_blocks(
84
- hidden_states,
85
- encoder_hidden_states,
86
- *args,
87
- **kwargs,
88
- )
83
+ if self.cache_manager.Bn_compute_blocks() > 0:
84
+ hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
85
+ hidden_states,
86
+ *args,
87
+ **kwargs,
88
+ )
89
89
  else:
90
90
  self.cache_manager.set_Fn_buffer(
91
91
  Fn_hidden_states_residual,
@@ -99,19 +99,20 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
99
99
  )
100
100
  del Fn_hidden_states_residual
101
101
  torch._dynamo.graph_break()
102
+ old_encoder_hidden_states = new_encoder_hidden_states
102
103
  (
103
104
  hidden_states,
104
- encoder_hidden_states,
105
+ new_encoder_hidden_states,
105
106
  hidden_states_residual,
106
- # None Pattern 3, else 4, 5
107
- encoder_hidden_states_residual,
108
107
  ) = self.call_Mn_blocks( # middle
109
108
  hidden_states,
110
- # None Pattern 3, else 4, 5
111
- encoder_hidden_states,
112
109
  *args,
113
110
  **kwargs,
114
111
  )
112
+ if new_encoder_hidden_states is not None:
113
+ new_encoder_hidden_states_residual = (
114
+ new_encoder_hidden_states - old_encoder_hidden_states
115
+ )
115
116
  torch._dynamo.graph_break()
116
117
  if self.cache_manager.is_cache_residual():
117
118
  self.cache_manager.set_Bn_buffer(
@@ -119,34 +120,32 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
119
120
  prefix=f"{self.cache_prefix}_Bn_residual",
120
121
  )
121
122
  else:
122
- # TaylorSeer
123
123
  self.cache_manager.set_Bn_buffer(
124
124
  hidden_states,
125
125
  prefix=f"{self.cache_prefix}_Bn_hidden_states",
126
126
  )
127
+
127
128
  if self.cache_manager.is_encoder_cache_residual():
128
- self.cache_manager.set_Bn_encoder_buffer(
129
- # None Pattern 3, else 4, 5
130
- encoder_hidden_states_residual,
131
- prefix=f"{self.cache_prefix}_Bn_residual",
132
- )
129
+ if new_encoder_hidden_states is not None:
130
+ self.cache_manager.set_Bn_encoder_buffer(
131
+ new_encoder_hidden_states_residual,
132
+ prefix=f"{self.cache_prefix}_Bn_residual",
133
+ )
133
134
  else:
134
- # TaylorSeer
135
- self.cache_manager.set_Bn_encoder_buffer(
136
- # None Pattern 3, else 4, 5
137
- encoder_hidden_states,
138
- prefix=f"{self.cache_prefix}_Bn_hidden_states",
139
- )
135
+ if new_encoder_hidden_states is not None:
136
+ self.cache_manager.set_Bn_encoder_buffer(
137
+ new_encoder_hidden_states_residual,
138
+ prefix=f"{self.cache_prefix}_Bn_hidden_states",
139
+ )
140
140
  torch._dynamo.graph_break()
141
141
  # Call last `n` blocks to further process the hidden states
142
142
  # for higher precision.
143
- hidden_states, encoder_hidden_states = self.call_Bn_blocks(
144
- hidden_states,
145
- # None Pattern 3, else 4, 5
146
- encoder_hidden_states,
147
- *args,
148
- **kwargs,
149
- )
143
+ if self.cache_manager.Bn_compute_blocks() > 0:
144
+ hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
145
+ hidden_states,
146
+ *args,
147
+ **kwargs,
148
+ )
150
149
 
151
150
  torch._dynamo.graph_break()
152
151
 
@@ -154,12 +153,21 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
154
153
  hidden_states
155
154
  if self.forward_pattern.Return_H_Only
156
155
  else (
157
- (hidden_states, encoder_hidden_states)
156
+ (hidden_states, new_encoder_hidden_states)
158
157
  if self.forward_pattern.Return_H_First
159
- else (encoder_hidden_states, hidden_states)
158
+ else (new_encoder_hidden_states, hidden_states)
160
159
  )
161
160
  )
162
161
 
162
+ @torch.compiler.disable
163
+ def maybe_update_kwargs(
164
+ self, encoder_hidden_states, kwargs: Dict[str, Any]
165
+ ) -> Dict[str, Any]:
166
+ # if "encoder_hidden_states" in kwargs:
167
+ # kwargs["encoder_hidden_states"] = encoder_hidden_states
168
+ # return kwargs
169
+ return kwargs
170
+
163
171
  def call_Fn_blocks(
164
172
  self,
165
173
  hidden_states: torch.Tensor,
@@ -172,7 +180,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
172
180
  f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
173
181
  f"the number of transformer blocks {len(self.transformer_blocks)}"
174
182
  )
175
- encoder_hidden_states = None # Pattern 3
183
+ new_encoder_hidden_states = None
176
184
  for block in self._Fn_blocks():
177
185
  hidden_states = block(
178
186
  hidden_states,
@@ -180,25 +188,27 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
180
188
  **kwargs,
181
189
  )
182
190
  if not isinstance(hidden_states, torch.Tensor): # Pattern 4, 5
183
- hidden_states, encoder_hidden_states = hidden_states
191
+ hidden_states, new_encoder_hidden_states = hidden_states
184
192
  if not self.forward_pattern.Return_H_First:
185
- hidden_states, encoder_hidden_states = (
186
- encoder_hidden_states,
193
+ hidden_states, new_encoder_hidden_states = (
194
+ new_encoder_hidden_states,
187
195
  hidden_states,
188
196
  )
197
+ kwargs = self.maybe_update_kwargs(
198
+ new_encoder_hidden_states,
199
+ kwargs,
200
+ )
189
201
 
190
- return hidden_states, encoder_hidden_states
202
+ return hidden_states, new_encoder_hidden_states
191
203
 
192
204
  def call_Mn_blocks(
193
205
  self,
194
206
  hidden_states: torch.Tensor,
195
- # None Pattern 3, else 4, 5
196
- encoder_hidden_states: torch.Tensor | None,
197
207
  *args,
198
208
  **kwargs,
199
209
  ):
200
210
  original_hidden_states = hidden_states
201
- original_encoder_hidden_states = encoder_hidden_states
211
+ new_encoder_hidden_states = None
202
212
  for block in self._Mn_blocks():
203
213
  hidden_states = block(
204
214
  hidden_states,
@@ -206,44 +216,33 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
206
216
  **kwargs,
207
217
  )
208
218
  if not isinstance(hidden_states, torch.Tensor): # Pattern 4, 5
209
- hidden_states, encoder_hidden_states = hidden_states
219
+ hidden_states, new_encoder_hidden_states = hidden_states
210
220
  if not self.forward_pattern.Return_H_First:
211
- hidden_states, encoder_hidden_states = (
212
- encoder_hidden_states,
221
+ hidden_states, new_encoder_hidden_states = (
222
+ new_encoder_hidden_states,
213
223
  hidden_states,
214
224
  )
225
+ kwargs = self.maybe_update_kwargs(
226
+ new_encoder_hidden_states,
227
+ kwargs,
228
+ )
215
229
 
216
230
  # compute hidden_states residual
217
231
  hidden_states = hidden_states.contiguous()
218
232
  hidden_states_residual = hidden_states - original_hidden_states
219
- if (
220
- original_encoder_hidden_states is not None
221
- and encoder_hidden_states is not None
222
- ): # Pattern 4, 5
223
- encoder_hidden_states_residual = (
224
- encoder_hidden_states - original_encoder_hidden_states
225
- )
226
- else:
227
- encoder_hidden_states_residual = None # Pattern 3
228
233
 
229
234
  return (
230
235
  hidden_states,
231
- encoder_hidden_states,
236
+ new_encoder_hidden_states,
232
237
  hidden_states_residual,
233
- encoder_hidden_states_residual,
234
238
  )
235
239
 
236
240
  def call_Bn_blocks(
237
241
  self,
238
242
  hidden_states: torch.Tensor,
239
- # None Pattern 3, else 4, 5
240
- encoder_hidden_states: torch.Tensor | None,
241
243
  *args,
242
244
  **kwargs,
243
245
  ):
244
- if self.cache_manager.Bn_compute_blocks() == 0:
245
- return hidden_states, encoder_hidden_states
246
-
247
246
  assert self.cache_manager.Bn_compute_blocks() <= len(
248
247
  self.transformer_blocks
249
248
  ), (
@@ -264,11 +263,15 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
264
263
  **kwargs,
265
264
  )
266
265
  if not isinstance(hidden_states, torch.Tensor): # Pattern 4,5
267
- hidden_states, encoder_hidden_states = hidden_states
266
+ hidden_states, new_encoder_hidden_states = hidden_states
268
267
  if not self.forward_pattern.Return_H_First:
269
- hidden_states, encoder_hidden_states = (
270
- encoder_hidden_states,
268
+ hidden_states, new_encoder_hidden_states = (
269
+ new_encoder_hidden_states,
271
270
  hidden_states,
272
271
  )
272
+ kwargs = self.maybe_update_kwargs(
273
+ new_encoder_hidden_states,
274
+ kwargs,
275
+ )
273
276
 
274
- return hidden_states, encoder_hidden_states
277
+ return hidden_states, new_encoder_hidden_states
@@ -733,17 +733,15 @@ class CachedContextManager:
733
733
  encoder_prefix
734
734
  )
735
735
 
736
- assert (
737
- encoder_hidden_states_prev is not None
738
- ), f"{prefix}_encoder_buffer must be set before"
736
+ if encoder_hidden_states_prev is not None:
739
737
 
740
- if self.is_encoder_cache_residual():
741
- encoder_hidden_states = (
742
- encoder_hidden_states_prev + encoder_hidden_states
743
- )
744
- else:
745
- # If encoder cache is not residual, we use the encoder hidden states directly
746
- encoder_hidden_states = encoder_hidden_states_prev
738
+ if self.is_encoder_cache_residual():
739
+ encoder_hidden_states = (
740
+ encoder_hidden_states_prev + encoder_hidden_states
741
+ )
742
+ else:
743
+ # If encoder cache is not residual, we use the encoder hidden states directly
744
+ encoder_hidden_states = encoder_hidden_states_prev
747
745
 
748
746
  encoder_hidden_states = encoder_hidden_states.contiguous()
749
747
 
@@ -1,11 +1,9 @@
1
- import torch
2
- from typing import Any, Tuple, List
1
+ from typing import Any, Tuple, List, Union
3
2
  from diffusers import DiffusionPipeline
4
3
  from cache_dit.cache_factory.cache_types import CacheType
5
4
  from cache_dit.cache_factory.block_adapters import BlockAdapter
6
5
  from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
7
6
  from cache_dit.cache_factory.cache_adapters import CachedAdapter
8
- from cache_dit.cache_factory.cache_contexts import CachedContextManager
9
7
 
10
8
  from cache_dit.logger import init_logger
11
9
 
@@ -14,7 +12,10 @@ logger = init_logger(__name__)
14
12
 
15
13
  def enable_cache(
16
14
  # DiffusionPipeline or BlockAdapter
17
- pipe_or_adapter: DiffusionPipeline | BlockAdapter,
15
+ pipe_or_adapter: Union[
16
+ DiffusionPipeline,
17
+ BlockAdapter,
18
+ ],
18
19
  # Cache context kwargs
19
20
  Fn_compute_blocks: int = 8,
20
21
  Bn_compute_blocks: int = 0,
@@ -32,7 +33,10 @@ def enable_cache(
32
33
  taylorseer_cache_type: str = "residual",
33
34
  taylorseer_order: int = 2,
34
35
  **other_cache_context_kwargs,
35
- ) -> BlockAdapter:
36
+ ) -> Union[
37
+ DiffusionPipeline,
38
+ BlockAdapter,
39
+ ]:
36
40
  r"""
37
41
  Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
38
42
  that match the specific Input and Output patterns).
@@ -102,11 +106,11 @@ def enable_cache(
102
106
  >>> cache_dit.enable_cache(pipe) # One-line code with default cache options.
103
107
  >>> output = pipe(...) # Just call the pipe as normal.
104
108
  >>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
109
+ >>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
105
110
  """
106
-
107
111
  # Collect cache context kwargs
108
112
  cache_context_kwargs = other_cache_context_kwargs.copy()
109
- if cache_type := cache_context_kwargs.get("cache_type", None):
113
+ if (cache_type := cache_context_kwargs.get("cache_type", None)) is not None:
110
114
  if cache_type == CacheType.NONE:
111
115
  return pipe_or_adapter
112
116
 
@@ -145,79 +149,17 @@ def enable_cache(
145
149
 
146
150
 
147
151
  def disable_cache(
148
- # DiffusionPipeline or BlockAdapter
149
- pipe_or_adapter: DiffusionPipeline | BlockAdapter,
152
+ pipe_or_adapter: Union[
153
+ DiffusionPipeline,
154
+ BlockAdapter,
155
+ ],
150
156
  ):
151
- from cache_dit.cache_factory.cache_blocks.utils import (
152
- remove_cached_stats,
157
+ CachedAdapter.maybe_release_hooks(pipe_or_adapter)
158
+ logger.warning(
159
+ f"Cache Acceleration is disabled for: "
160
+ f"{pipe_or_adapter.__class__.__name__}."
153
161
  )
154
162
 
155
- def _disable_blocks(blocks: torch.nn.ModuleList):
156
- if blocks is None:
157
- return
158
- if hasattr(blocks, "_forward_pattern"):
159
- del blocks._forward_pattern
160
- if hasattr(blocks, "_cache_context_kwargs"):
161
- del blocks._cache_context_kwargs
162
- remove_cached_stats(blocks)
163
-
164
- def _disable_transformer(transformer: torch.nn.Module):
165
- if transformer is None or not BlockAdapter.is_cached(transformer):
166
- return
167
- if original_forward := getattr(transformer, "_original_forward"):
168
- transformer.forward = original_forward.__get__(transformer)
169
- del transformer._original_forward
170
- if hasattr(transformer, "_is_cached"):
171
- del transformer._is_cached
172
- if hasattr(transformer, "_forward_pattern"):
173
- del transformer._forward_pattern
174
- if hasattr(transformer, "_has_separate_cfg"):
175
- del transformer._has_separate_cfg
176
- if hasattr(transformer, "_cache_context_kwargs"):
177
- del transformer._cache_context_kwargs
178
- remove_cached_stats(transformer)
179
- for blocks in BlockAdapter.find_blocks(transformer):
180
- _disable_blocks(blocks)
181
-
182
- def _disable_pipe(pipe: DiffusionPipeline):
183
- if pipe is None or not BlockAdapter.is_cached(pipe):
184
- return
185
- if original_call := getattr(pipe, "_original_call"):
186
- pipe.__class__.__call__ = original_call
187
- del pipe.__class__._original_call
188
- if cache_manager := getattr(pipe, "_cache_manager"):
189
- assert isinstance(cache_manager, CachedContextManager)
190
- cache_manager.clear_contexts()
191
- del pipe._cache_manager
192
- if hasattr(pipe, "_is_cached"):
193
- del pipe.__class__._is_cached
194
- if hasattr(pipe, "_cache_context_kwargs"):
195
- del pipe._cache_context_kwargs
196
- remove_cached_stats(pipe)
197
-
198
- if isinstance(pipe_or_adapter, DiffusionPipeline):
199
- pipe = pipe_or_adapter
200
- _disable_pipe(pipe)
201
- if hasattr(pipe, "transformer"):
202
- _disable_transformer(pipe.transformer)
203
- if hasattr(pipe, "transformer_2"): # Wan 2.2
204
- _disable_transformer(pipe.transformer_2)
205
- pipe_cls_name = pipe.__class__.__name__
206
- logger.warning(f"Cache Acceleration is disabled for: {pipe_cls_name}")
207
- elif isinstance(pipe_or_adapter, BlockAdapter):
208
- # BlockAdapter
209
- adapter = pipe_or_adapter
210
- BlockAdapter.assert_normalized(adapter)
211
- _disable_pipe(adapter.pipe)
212
- for transformer in BlockAdapter.flatten(adapter.transformer):
213
- _disable_transformer(transformer)
214
- for blocks in BlockAdapter.flatten(adapter.blocks):
215
- _disable_blocks(blocks)
216
- pipe_cls_name = adapter.pipe.__class__.__name__
217
- logger.warning(f"Cache Acceleration is disabled for: {pipe_cls_name}")
218
- else:
219
- pass # do nothing
220
-
221
163
 
222
164
  def supported_pipelines(
223
165
  **kwargs,
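
`disable_cache` now delegates all cleanup to `CachedAdapter.maybe_release_hooks`, so the public API becomes a simple round trip, matching the updated docstring above (the model id below is illustrative):

```python
import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")  # any supported DiT pipeline

cache_dit.enable_cache(pipe)     # one-line enable with default cache options
output = pipe(...)               # call the pipe as normal
stats = cache_dit.summary(pipe)  # summary of cache acceleration stats
cache_dit.disable_cache(pipe)    # release the hooks and run the original pipe again
```
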
@@ -22,11 +22,11 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":
22
22
  if isinstance(type_hint, CacheType):
23
23
  return type_hint
24
24
 
25
- elif type_hint.lower() in (
26
- "dual_block_cache",
27
- "db_cache",
28
- "dbcache",
29
- "db",
25
+ elif type_hint.upper() in (
26
+ "DUAL_BLOCK_CACHE",
27
+ "DB_CACHE",
28
+ "DBCACHE",
29
+ "DB",
30
30
  ):
31
31
  return CacheType.DBCache
32
32
  return CacheType.NONE
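
Matching `type_hint.upper()` against upper-case aliases keeps the lookup effectively case-insensitive while making the intent explicit: any spelling of an alias normalizes to the same key. A standalone sketch of the behaviour (the enum values are stand-ins, not the actual cache-dit definitions):

```python
from enum import Enum


class CacheType(Enum):  # stand-in for cache_dit's CacheType
    NONE = 0
    DBCache = 1


def cache_type(type_hint) -> CacheType:
    if isinstance(type_hint, CacheType):
        return type_hint
    if type_hint.upper() in ("DUAL_BLOCK_CACHE", "DB_CACHE", "DBCACHE", "DB"):
        return CacheType.DBCache
    return CacheType.NONE


assert cache_type("DBCache") is CacheType.DBCache
assert cache_type("db_cache") is CacheType.DBCache  # lower case still matches after .upper()
```
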
@@ -56,7 +56,8 @@ class ChromaPatchFunctor(PatchFunctor):
56
56
  transformer.forward = __patch_transformer_forward__.__get__(
57
57
  transformer
58
58
  )
59
- transformer._is_patched = True
59
+
60
+ transformer._is_patched = is_patched # True or False
60
61
 
61
62
  cls_name = transformer.__class__.__name__
62
63
  logger.info(
@@ -57,7 +57,8 @@ class FluxPatchFunctor(PatchFunctor):
57
57
  transformer.forward = __patch_transformer_forward__.__get__(
58
58
  transformer
59
59
  )
60
- transformer._is_patched = True
60
+
61
+ transformer._is_patched = is_patched # True or False
61
62
 
62
63
  cls_name = transformer.__class__.__name__
63
64
  logger.info(
cache_dit/utils.py CHANGED
@@ -52,6 +52,9 @@ def summary(
52
52
  if hasattr(adapter_or_others, "transformer_2"):
53
53
  transformer_2 = adapter_or_others.transformer_2
54
54
 
55
+ if not BlockAdapter.is_cached(transformer):
56
+ return [CacheStats()]
57
+
55
58
  blocks_stats: List[CacheStats] = []
56
59
  for blocks in BlockAdapter.find_blocks(transformer):
57
60
  blocks_stats.append(
@@ -212,7 +215,8 @@ def _summary(
212
215
  if logging:
213
216
  print(f"\n🤗Cache Options: {cls_name}\n\n{cache_options}")
214
217
  else:
215
- logger.warning(f"Can't find Cache Options for: {cls_name}")
218
+ if logging:
219
+ logger.warning(f"Can't find Cache Options for: {cls_name}")
216
220
 
217
221
  if hasattr(module, "_cached_steps"):
218
222
  cached_steps: list[int] = module._cached_steps
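
`summary` now short-circuits with a default `CacheStats` list when the transformer has never been cached, and the "Can't find Cache Options" warning is only emitted when `logging` is enabled, so inspecting an untouched pipeline is quiet and safe. Assuming the public API shown in the README docstring (the model id is illustrative):

```python
import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")  # illustrative

# Not cached yet: summary() returns a default CacheStats list instead of probing hooks.
stats = cache_dit.summary(pipe)
```
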
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cache_dit
3
- Version: 0.2.29
3
+ Version: 0.2.30
4
4
  Summary: 🤗 A Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
5
5
  Author: DefTruth, vipshop.com, etc.
6
6
  Maintainer: DefTruth, vipshop.com, etc
@@ -59,7 +59,7 @@ Dynamic: requires-python
59
59
  🔥<b><a href="#unified">Unified Cache APIs</a> | <a href="#dbcache">DBCache</a> | <a href="#taylorseer">Hybrid TaylorSeer</a> | <a href="#cfg">Hybrid Cache CFG</a></b>🔥
60
60
  </p>
61
61
  <p align="center">
62
- 🎉Now, <b>cache-dit</b> covers <b>100%</b> Diffusers' <b>DiT-based</b> Pipelines🎉<br>
62
+ 🎉Now, <b>cache-dit</b> covers <b>mainstream</b> Diffusers' <b>DiT-based</b> Pipelines🎉<br>
63
63
  🔥<b><a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1/2.2</a> | <a href="#supported"> ... </a> | <a href="#supported">CogVideoX</a></b>🔥
64
64
  </p>
65
65
  </div>
@@ -87,7 +87,6 @@ Dynamic: requires-python
87
87
  <summary> Previous News </summary>
88
88
 
89
89
  - [2025-09-01] 📚[**Hybird Forward Pattern**](#unified) is supported! Please check [FLUX.1-dev](./examples/run_flux_adapter.py) as an example.
90
- - [2025-08-29] 🔥</b>Covers <b>100%</b> Diffusers' <b>DiT-based</b> Pipelines: **[BlockAdapter](#unified) + [Pattern Matching](#unified).**
91
90
  - [2025-08-10] 🔥[**FLUX.1-Kontext-dev**](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) is supported! Please refer [run_flux_kontext.py](./examples/pipeline/run_flux_kontext.py) as an example.
92
91
  - [2025-07-18] 🎉First caching mechanism in [🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/huggingface/flux-fast/pull/13).
93
92
 
@@ -130,19 +129,8 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
130
129
 
131
130
  <div id="supported"></div>
132
131
 
133
- ```python
134
- >>> import cache_dit
135
- >>> cache_dit.supported_pipelines()
136
- (31, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTXVideo*',
137
- 'Allegro*', 'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'SD3*',
138
- 'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'HunyuanDiT*', 'HunyuanDiTPAG*', 'Lumina*', 'Lumina2*',
139
- 'OmniGen*', 'PixArt*', 'Sana*', 'ShapE*', 'StableAudio*', 'VisualCloze*', 'AuraFlow*',
140
- 'Chroma*', 'HiDream*'])
141
- ```
142
-
143
132
  Currently, **cache-dit** library supports almost **Any** Diffusion Transformers (with **Transformer Blocks** that match the specific Input and Output **patterns**). Please check [🎉Unified Cache APIs](#unified) for more details. Here are just some of the tested models listed:
144
133
 
145
-
146
134
  - [🚀Qwen-Image-Edit](https://github.com/vipshop/cache-dit/raw/main/examples)
147
135
  - [🚀Qwen-Image](https://github.com/vipshop/cache-dit/raw/main/examples)
148
136
  - [🚀FLUX.1-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -154,35 +142,7 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
154
142
  - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
155
143
  - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
156
144
  - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
157
- - [🚀HunyuanDiT](https://github.com/vipshop/cache-dit/raw/main/examples)
158
145
 
159
- <details>
160
- <summary> More Pipelines </summary>
161
-
162
- - [🚀mochi-1-preview](https://github.com/vipshop/cache-dit/raw/main/examples)
163
- - [🚀LTXVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
164
- - [🚀Allegro](https://github.com/vipshop/cache-dit/raw/main/examples)
165
- - [🚀CogView3Plus](https://github.com/vipshop/cache-dit/raw/main/examples)
166
- - [🚀CogView4](https://github.com/vipshop/cache-dit/raw/main/examples)
167
- - [🚀Cosmos](https://github.com/vipshop/cache-dit/raw/main/examples)
168
- - [🚀EasyAnimate](https://github.com/vipshop/cache-dit/raw/main/examples)
169
- - [🚀SkyReelsV2](https://github.com/vipshop/cache-dit/raw/main/examples)
170
- - [🚀SD3](https://github.com/vipshop/cache-dit/raw/main/examples)
171
- - [🚀ConsisID](https://github.com/vipshop/cache-dit/raw/main/examples)
172
- - [🚀DiT](https://github.com/vipshop/cache-dit/raw/main/examples)
173
- - [🚀Amused](https://github.com/vipshop/cache-dit/raw/main/examples)
174
- - [🚀HunyuanDiTPAG](https://github.com/vipshop/cache-dit/raw/main/examples)
175
- - [🚀Lumina](https://github.com/vipshop/cache-dit/raw/main/examples)
176
- - [🚀Lumina2](https://github.com/vipshop/cache-dit/raw/main/examples)
177
- - [🚀OmniGen](https://github.com/vipshop/cache-dit/raw/main/examples)
178
- - [🚀PixArt](https://github.com/vipshop/cache-dit/raw/main/examples)
179
- - [🚀Sana](https://github.com/vipshop/cache-dit/raw/main/examples)
180
- - [🚀StableAudio](https://github.com/vipshop/cache-dit/raw/main/examples)
181
- - [🚀VisualCloze](https://github.com/vipshop/cache-dit/raw/main/examples)
182
- - [🚀AuraFlow](https://github.com/vipshop/cache-dit/raw/main/examples)
183
- - [🚀Chroma](https://github.com/vipshop/cache-dit/raw/main/examples)
184
- - [🚀HiDream](https://github.com/vipshop/cache-dit/raw/main/examples)
185
-
186
146
  </details>
187
147
 
188
148
  ## 🎉Unified Cache APIs
@@ -1,30 +1,30 @@
1
1
  cache_dit/__init__.py,sha256=kX9V-FegZG4c8LMwI4PTmMqH794MEW0pzDArdhC0cJw,1241
2
- cache_dit/_version.py,sha256=4_NDrwSRsA8gshfOOEHYB4RwOrbBlY3Re7Srt7YQl4M,706
2
+ cache_dit/_version.py,sha256=6uKAYeE03adIcUS0SDwp52AaQx0KO8z_-07D_lPHrz8,706
3
3
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
- cache_dit/utils.py,sha256=bMeZw377_mACEj3nV1tn5DTqypBsbUVvZWJYjNQQxPg,10399
4
+ cache_dit/utils.py,sha256=WK7eqgH6gCYNHXNLmWyxBDU0XSHTPg7CfOcyXlGXBqE,10510
5
5
  cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
6
6
  cache_dit/cache_factory/__init__.py,sha256=Iw6-iJLFbdzCsIDZXXOw371L-HPmoeZO_P9a3sDjP5s,1103
7
- cache_dit/cache_factory/cache_adapters.py,sha256=knNzV4BbCQmyiwsybFGl3LpTDEFtenykLb9-y_bbDpA,13905
8
- cache_dit/cache_factory/cache_interface.py,sha256=wy9QNNNNH4ncdyGepuvyJSJLbaTRCqqlvxtO6Os20yA,11317
9
- cache_dit/cache_factory/cache_types.py,sha256=FIFa6ZBfvvSMMHyBBhvarvgg2Y2wbRgITcG_uGylGe0,991
7
+ cache_dit/cache_factory/cache_adapters.py,sha256=TA_0mEHMdSQDrt4rYASeX4-BD8pJOznSJfMV3hkrGuk,17851
8
+ cache_dit/cache_factory/cache_interface.py,sha256=y1nY6R3MucRmAnG2UJRI_tIKrRk27FktGWLbfckf3zE,8543
9
+ cache_dit/cache_factory/cache_types.py,sha256=ooukxQRG55uTLmaZ0SKw6gIeY6SQHhMxkbv55uj2Sqk,991
10
10
  cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
11
11
  cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
12
- cache_dit/cache_factory/block_adapters/__init__.py,sha256=mtYPmsAYz4MGsMmanf6xZLaxZEkgE8gwB5mYhrM4nw4,15862
13
- cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=sRTAfhDxdj3VAm7ejyklGwM7HMimQSEgVroVkLx7CR8,20997
14
- cache_dit/cache_factory/block_adapters/block_registers.py,sha256=79HpWTX7PO2ynY8I-KnF6pa-ETV4Dlpbxn5wvp_iyvw,2387
12
+ cache_dit/cache_factory/block_adapters/__init__.py,sha256=EA-4mEVy-JJ5vRDo6C3nJIOXu0ZDNc6FQ-ZLAKHDtB0,17251
13
+ cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=OrKhuNdcGBCgSsPchdf4h32Ad-bQVUXNigMhPJ4cCvk,21069
14
+ cache_dit/cache_factory/block_adapters/block_registers.py,sha256=ZeN2wGPmuf2u3puSsBx8x-rl3wRo8-cWcuWNcrssVfA,2553
15
15
  cache_dit/cache_factory/cache_blocks/__init__.py,sha256=OWjnpJxA8EJVoRzuyb5miuiRphUFj831-bbtWsTDjnM,2750
16
16
  cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=ElMps6_7uI74tSF9GDR_dEI0bZEhdzcepM29xFWnYo8,428
17
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=CtBr6nvtAW8SAeEwPwiwWtPgrmwyb5ukb-j3IwFULJU,9953
17
+ cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=nf2f5wdxp6tfq9AhFyMyBeKiZfxh63WG1g8q-c2BBSg,10182
18
18
  cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=XSDy3hsaKbAZPGZY92YgGA0qLgjQyIX8irQkb2R5T2c,20331
19
19
  cache_dit/cache_factory/cache_blocks/utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
20
20
  cache_dit/cache_factory/cache_contexts/__init__.py,sha256=rqnJ5__zqnpVHK5A1OqWILpNh5Ss-0ZDTGgtxZMKGGo,250
21
21
  cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=N88WLdd4KE9DuMWmpX8URcF55E2zWNwcKMxgVYkxMJY,13691
22
- cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=Tgk2-VFEhUp-oe-TFHzXay_YgbU8v90_Nx2G17ZlTds,34937
22
+ cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=_NUXcMYYEIVfDHpc4HJr9RUjU5RUEkZmAgFGE8bh5Wc,34883
23
23
  cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=etSUIZzDvqW3ScKCbccTPcFaSmxV1T-xAXdk-p3e3wk,3802
24
24
  cache_dit/cache_factory/patch_functors/__init__.py,sha256=yK05iONMGILsTZ83ynrUUJtiJKJ_FDjxmVIzRLy416s,252
25
25
  cache_dit/cache_factory/patch_functors/functor_base.py,sha256=Ahk0fTfrHgNdEl-9JSkACvfyyv9G-Ei5OSz7XBIlX5o,357
26
- cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=Z0kPAib0TkXGzJIP9FRK559UlBVuGQSZIVFir6tHzJM,10022
27
- cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=w_QaUwv7l7ypvFxWHzjHjAafLr1PxQcgv5N7VFjr6N8,9531
26
+ cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=N3UzszCM55g3GHeVdyXkid2_n72VJrfqBM2gdtD52gw,10042
27
+ cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=rJsbGEIxWPGnZyGI4ZwLLBdg8u6ZItsOeh0UoD_bVwk,9551
28
28
  cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
29
29
  cache_dit/compile/utils.py,sha256=nN2OIrSdwRR5zGxJinKDqb07pXpvTNTF3g_OgLkeeBU,3858
30
30
  cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -39,9 +39,9 @@ cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70
39
39
  cache_dit/quantize/quantize_ao.py,sha256=mGspqYgQtenl3QnKPtsSYsSD7LbVX93f1M940bhXKLU,6066
40
40
  cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
41
41
  cache_dit/quantize/quantize_svdq.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
42
- cache_dit-0.2.29.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
43
- cache_dit-0.2.29.dist-info/METADATA,sha256=wagdLaiAIX7fs5Rsw89DMsj5KQl1t3zGsTosKJe2AlQ,24540
44
- cache_dit-0.2.29.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
- cache_dit-0.2.29.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
46
- cache_dit-0.2.29.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
47
- cache_dit-0.2.29.dist-info/RECORD,,
42
+ cache_dit-0.2.30.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
43
+ cache_dit-0.2.30.dist-info/METADATA,sha256=8Ln_X5fw14U3greCM7cSukrei1SRiMDpksFalg5ZBAU,22130
44
+ cache_dit-0.2.30.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
+ cache_dit-0.2.30.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
46
+ cache_dit-0.2.30.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
47
+ cache_dit-0.2.30.dist-info/RECORD,,