cache-dit 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.30'
-__version_tuple__ = version_tuple = (0, 2, 30)
+__version__ = version = '0.2.32'
+__version_tuple__ = version_tuple = (0, 2, 32)
 
 
 __commit_id__ = commit_id = None
@@ -254,7 +254,7 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("SD3")
+@BlockAdapterRegistry.register("StableDiffusion3")
 def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import SD3Transformer2DModel
 
@@ -501,7 +501,7 @@ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("HiDream", supported=True)
+@BlockAdapterRegistry.register("HiDream")
 def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
     # NOTE: Need to patch Transformer forward to fully support
     # double_stream_blocks and single_stream_blocks, namely, need
@@ -509,29 +509,32 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
     # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
     # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
     from diffusers import HiDreamImageTransformer2DModel
+    from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
 
     assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
     return BlockAdapter(
         pipe=pipe,
         transformer=pipe.transformer,
         blocks=[
-            # pipe.transformer.double_stream_blocks,
+            pipe.transformer.double_stream_blocks,
             pipe.transformer.single_stream_blocks,
         ],
         forward_pattern=[
-            # ForwardPattern.Pattern_4,
+            ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_3,
         ],
-        # The type hint in diffusers is wrong
-        check_num_outputs=False,
+        patch_functor=HiDreamPatchFunctor(),
+        # NOTE: The type hint in diffusers is wrong
+        check_forward_pattern=True,
+        check_num_outputs=True,
         **kwargs,
     )
 
 
-@BlockAdapterRegistry.register("HunyuanDiT", supported=False)
+@BlockAdapterRegistry.register("HunyuanDiT")
 def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
-    # TODO: Patch Transformer forward
     from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
+    from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
 
     assert isinstance(
         pipe.transformer,
@@ -542,14 +545,15 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        patch_functor=HunyuanDiTPatchFunctor(),
         **kwargs,
     )
 
 
-@BlockAdapterRegistry.register("HunyuanDiTPAG", supported=False)
+@BlockAdapterRegistry.register("HunyuanDiTPAG")
 def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
-    # TODO: Patch Transformer forward
     from diffusers import HunyuanDiT2DModel
+    from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
 
     assert isinstance(pipe.transformer, HunyuanDiT2DModel)
     return BlockAdapter(
@@ -557,5 +561,6 @@ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        patch_functor=HunyuanDiTPatchFunctor(),
         **kwargs,
    )
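The registrations above make HiDream, HunyuanDiT and HunyuanDiTPAG first-class supported adapters, each now carrying a patch functor. For orientation only, here is a minimal sketch of how such a registration is exercised from user code; the `HiDreamImagePipeline` import, the model id and the one-line `cache_dit.enable_cache()` call shape are assumptions based on the README and docstrings elsewhere in this diff, not part of the diff itself.

```python
# Sketch only (not part of this diff): enable cache-dit on a HiDream pipeline.
# The registered "HiDream" adapter above caches double_stream_blocks (Pattern_0)
# and single_stream_blocks (Pattern_3) after HiDreamPatchFunctor patches them.
import torch
import cache_dit
from diffusers import HiDreamImagePipeline  # assumed diffusers class

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",  # illustrative model id
    torch_dtype=torch.bfloat16,
).to("cuda")

# Looks up the adapter registered for this pipeline class, applies the patch
# functor, then wraps the transformer blocks with cached equivalents.
cache_dit.enable_cache(pipe)

image = pipe("a cabin in snowy mountains, golden hour").images[0]
image.save("hidream_cached.png")
```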
@@ -75,6 +75,7 @@ class BlockAdapter:
         List[List[ParamsModifier]],
     ] = None
 
+    check_forward_pattern: bool = True
     check_num_outputs: bool = False
 
     # Pipeline Level Flags
@@ -391,11 +392,20 @@ class BlockAdapter:
         forward_pattern: ForwardPattern,
         **kwargs,
     ) -> bool:
+
+        if not kwargs.get("check_forward_pattern", True):
+            return True
+
         assert (
             forward_pattern.Supported
             and forward_pattern in ForwardPattern.supported_patterns()
         ), f"Pattern {forward_pattern} is not support now!"
 
+        # NOTE: Special case for HiDreamBlock
+        if hasattr(block, "block"):
+            if isinstance(block.block, torch.nn.Module):
+                block = block.block
+
         forward_parameters = set(
             inspect.signature(block.forward).parameters.keys()
         )
@@ -425,6 +435,14 @@ class BlockAdapter:
         logging: bool = True,
         **kwargs,
     ) -> bool:
+
+        if not kwargs.get("check_forward_pattern", True):
+            if logging:
+                logger.warning(
+                    f"Skipped Forward Pattern Check: {forward_pattern}"
+                )
+            return True
+
         assert (
             forward_pattern.Supported
             and forward_pattern in ForwardPattern.supported_patterns()
@@ -531,6 +549,7 @@ class BlockAdapter:
                 blocks,
                 forward_pattern=forward_pattern,
                 check_num_outputs=adapter.check_num_outputs,
+                check_forward_pattern=adapter.check_forward_pattern,
             ), (
                 "No block forward pattern matched, "
                 f"supported lists: {ForwardPattern.supported_patterns()}"
@@ -114,7 +114,7 @@ class CachedAdapter:
         **cache_context_kwargs,
     ):
         # Check cache_context_kwargs
-        if not cache_context_kwargs["enable_spearate_cfg"]:
+        if cache_context_kwargs["enable_spearate_cfg"] is None:
             # Check cfg for some specific case if users don't set it as True
             if BlockAdapterRegistry.has_separate_cfg(block_adapter):
                 cache_context_kwargs["enable_spearate_cfg"] = True
@@ -131,6 +131,12 @@ class CachedAdapter:
                 f"register: {cache_context_kwargs['enable_spearate_cfg']}, "
                 f"Pipeline: {block_adapter.pipe.__class__.__name__}."
             )
+        else:
+            logger.info(
+                f"Use custom 'enable_spearate_cfg' from cache context "
+                f"kwargs: {cache_context_kwargs['enable_spearate_cfg']}. "
+                f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+            )
 
         if (
             cache_type := cache_context_kwargs.pop("cache_type", None)
@@ -345,6 +351,7 @@ class CachedAdapter:
                 block_adapter.blocks[i][j],
                 transformer=block_adapter.transformer[i],
                 forward_pattern=block_adapter.forward_pattern[i][j],
+                check_forward_pattern=block_adapter.check_forward_pattern,
                 check_num_outputs=block_adapter.check_num_outputs,
                 # 1. Cache context configuration
                 cache_prefix=block_adapter.blocks_name[i][j],
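The change above turns `enable_spearate_cfg` into a three-state setting. A small self-contained sketch that approximates the resolution order follows; the `None` branch is confirmed by the diff, while treating the registry default as the fallback for the non-matching case is an assumption.

```python
# Standalone sketch of the three-state resolution for `enable_spearate_cfg`
# (spelling kept as in the source). `has_separate_cfg` stands in for
# BlockAdapterRegistry.has_separate_cfg(block_adapter).
from typing import Optional


def resolve_enable_spearate_cfg(
    user_value: Optional[bool],
    has_separate_cfg: bool,
) -> bool:
    if user_value is None:
        # Unset: fall back to the per-pipeline registry default, e.g. True for
        # Wan 2.1 / Qwen-Image, False for CogVideoX / HunyuanVideo / Mochi.
        return has_separate_cfg
    # An explicit True/False from the user wins and is only logged.
    return user_value


assert resolve_enable_spearate_cfg(None, True) is True
assert resolve_enable_spearate_cfg(False, True) is False
assert resolve_enable_spearate_cfg(True, False) is True
```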
@@ -25,6 +25,7 @@ class CachedBlocks:
         transformer_blocks: torch.nn.ModuleList,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = None,
+        check_forward_pattern: bool = True,
         check_num_outputs: bool = True,
         # 1. Cache context configuration
         # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
@@ -45,6 +46,7 @@ class CachedBlocks:
                 transformer_blocks,
                 transformer=transformer,
                 forward_pattern=forward_pattern,
+                check_forward_pattern=check_forward_pattern,
                 check_num_outputs=check_num_outputs,
                 # 1. Cache context configuration
                 cache_prefix=cache_prefix,
@@ -58,6 +60,7 @@ class CachedBlocks:
                 transformer_blocks,
                 transformer=transformer,
                 forward_pattern=forward_pattern,
+                check_forward_pattern=check_forward_pattern,
                 check_num_outputs=check_num_outputs,
                 # 1. Cache context configuration
                 cache_prefix=cache_prefix,
@@ -25,6 +25,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         transformer_blocks: torch.nn.ModuleList,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
         check_num_outputs: bool = True,
         # 1. Cache context configuration
         cache_prefix: str = None,  # maybe un-need.
@@ -38,6 +39,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         self.transformer = transformer
         self.transformer_blocks = transformer_blocks
         self.forward_pattern = forward_pattern
+        self.check_forward_pattern = check_forward_pattern
         self.check_num_outputs = check_num_outputs
         # 1. Cache context configuration
         self.cache_prefix = cache_prefix
@@ -52,6 +54,12 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         )
 
     def _check_forward_pattern(self):
+        if not self.check_forward_pattern:
+            logger.warning(
+                f"Skipped Forward Pattern Check: {self.forward_pattern}"
+            )
+            return
+
         assert (
             self.forward_pattern.Supported
             and self.forward_pattern in self._supported_patterns
@@ -59,6 +67,11 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
 
         if self.transformer_blocks is not None:
             for block in self.transformer_blocks:
+                # Special case for HiDreamBlock
+                if hasattr(block, "block"):
+                    if isinstance(block.block, torch.nn.Module):
+                        block = block.block
+
                 forward_parameters = set(
                     inspect.signature(block.forward).parameters.keys()
                 )
@@ -332,12 +345,19 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
 
         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
-        encoder_hidden_states = encoder_hidden_states.contiguous()
 
         hidden_states_residual = hidden_states - original_hidden_states
-        encoder_hidden_states_residual = (
-            encoder_hidden_states - original_encoder_hidden_states
-        )
+
+        if (
+            encoder_hidden_states is not None
+            and original_encoder_hidden_states is not None
+        ):
+            encoder_hidden_states = encoder_hidden_states.contiguous()
+            encoder_hidden_states_residual = (
+                encoder_hidden_states - original_encoder_hidden_states
+            )
+        else:
+            encoder_hidden_states_residual = None
 
         return (
             hidden_states,
@@ -387,9 +407,16 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
             Bn_i_hidden_states_residual = (
                 hidden_states - Bn_i_original_hidden_states
             )
-            Bn_i_encoder_hidden_states_residual = (
-                encoder_hidden_states - Bn_i_original_encoder_hidden_states
-            )
+            if (
+                encoder_hidden_states is not None
+                and Bn_i_original_encoder_hidden_states is not None
+            ):
+                Bn_i_encoder_hidden_states_residual = (
+                    encoder_hidden_states
+                    - Bn_i_original_encoder_hidden_states
+                )
+            else:
+                Bn_i_encoder_hidden_states_residual = None
 
             # Save original_hidden_states for diff calculation.
             self.cache_manager.set_Bn_buffer(
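The guards above exist because some block patterns return only `hidden_states`, so the encoder-side residual has to be optional. A runnable sketch of the same residual bookkeeping in isolation (function name and tensors are illustrative, not from the source):

```python
# Runnable sketch of the residual bookkeeping guarded above: the encoder-side
# residual becomes None when a block pattern carries no encoder stream.
from typing import Optional, Tuple
import torch


def compute_residuals(
    hidden_states: torch.Tensor,
    original_hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    original_encoder_hidden_states: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    hidden_states_residual = hidden_states.contiguous() - original_hidden_states
    if (
        encoder_hidden_states is not None
        and original_encoder_hidden_states is not None
    ):
        encoder_residual = (
            encoder_hidden_states.contiguous() - original_encoder_hidden_states
        )
    else:
        encoder_residual = None  # blocks with no separate encoder stream
    return hidden_states_residual, encoder_residual


h, h0 = torch.randn(1, 16, 8), torch.randn(1, 16, 8)
res, enc_res = compute_residuals(h, h0)
assert enc_res is None
```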
@@ -24,7 +24,7 @@ def enable_cache(
     max_continuous_cached_steps: int = -1,
     residual_diff_threshold: float = 0.08,
     # Cache CFG or not
-    enable_spearate_cfg: bool = False,
+    enable_spearate_cfg: bool | None = None,
     cfg_compute_first: bool = False,
     cfg_diff_compute_separate: bool = True,
     # Hybird TaylorSeer
@@ -70,7 +70,7 @@ def enable_cache(
         residual_diff_threshold (`float`, *required*, defaults to 0.08):
             he value of residual diff threshold, a higher value leads to faster performance at the
             cost of lower precision.
-        enable_spearate_cfg (`bool`, *required*, defaults to False):
+        enable_spearate_cfg (`bool`, *required*, defaults to None):
             Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
             and non-CFG into single forward step, should set enable_spearate_cfg as False, for example:
             CogVideoX, HunyuanVideo, Mochi, etc.
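A usage sketch for the updated signature, assuming the pipeline is passed as the first argument as in the project README; the FLUX model id and the extra kwargs are illustrative, only the parameter names shown in the diff are confirmed.

```python
# Usage sketch: enable_spearate_cfg=None (new default) means "auto-detect from
# the adapter registry"; passing an explicit bool overrides it.
import torch
import cache_dit
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

cache_dit.enable_cache(
    pipe,
    residual_diff_threshold=0.08,
    # None (default): pick the per-pipeline setting automatically.
    # False: for pipelines that fuse CFG and non-CFG into one forward step
    # (CogVideoX, HunyuanVideo, Mochi, ...). True: force separate-CFG caching.
    enable_spearate_cfg=None,
)

image = pipe("a watercolor fox", num_inference_steps=28).images[0]
```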
@@ -3,3 +3,9 @@ from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
 from cache_dit.cache_factory.patch_functors.functor_chroma import (
     ChromaPatchFunctor,
 )
+from cache_dit.cache_factory.patch_functors.functor_hidream import (
+    HiDreamPatchFunctor,
+)
+from cache_dit.cache_factory.patch_functors.functor_hunyuan_dit import (
+    HunyuanDiTPatchFunctor,
+)
@@ -46,8 +46,10 @@ class ChromaPatchFunctor(PatchFunctor):
             block.forward = __patch_single_forward__.__get__(block)
             is_patched = True
 
+        cls_name = transformer.__class__.__name__
+
         if is_patched:
-            logger.warning("Patched Chroma for cache-dit.")
+            logger.warning(f"Patched {cls_name} for cache-dit.")
             assert not getattr(transformer, "_is_parallelized", False), (
                 "Please call `cache_dit.enable_cache` before Parallelize, "
                 "the __patch_transformer_forward__ will overwrite the "
@@ -59,7 +61,6 @@ class ChromaPatchFunctor(PatchFunctor):
 
         transformer._is_patched = is_patched  # True or False
 
-        cls_name = transformer.__class__.__name__
         logger.info(
             f"Applied {self.__class__.__name__} for {cls_name}, "
             f"Patch: {is_patched}."
@@ -47,8 +47,10 @@ class FluxPatchFunctor(PatchFunctor):
             block.forward = __patch_single_forward__.__get__(block)
             is_patched = True
 
+        cls_name = transformer.__class__.__name__
+
         if is_patched:
-            logger.warning("Patched Flux for cache-dit.")
+            logger.warning(f"Patched {cls_name} for cache-dit.")
             assert not getattr(transformer, "_is_parallelized", False), (
                 "Please call `cache_dit.enable_cache` before Parallelize, "
                 "the __patch_transformer_forward__ will overwrite the "
@@ -60,7 +62,6 @@ class FluxPatchFunctor(PatchFunctor):
 
         transformer._is_patched = is_patched  # True or False
 
-        cls_name = transformer.__class__.__name__
         logger.info(
             f"Applied {self.__class__.__name__} for {cls_name}, "
             f"Patch: {is_patched}."
@@ -0,0 +1,412 @@
+import torch
+from typing import Tuple, Optional, Dict, Any, Union, List
+from diffusers import HiDreamImageTransformer2DModel
+from diffusers.models.transformers.transformer_hidream_image import (
+    HiDreamBlock,
+    HiDreamImageTransformerBlock,
+    HiDreamImageSingleTransformerBlock,
+    Transformer2DModelOutput,
+)
+from diffusers.utils import (
+    deprecate,
+    USE_PEFT_BACKEND,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class HiDreamPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: HiDreamImageTransformer2DModel,
+        **kwargs,
+    ) -> HiDreamImageTransformer2DModel:
+        if hasattr(transformer, "_is_patched"):
+            return transformer
+
+        is_patched = False
+
+        _block_id = 0
+        for block in transformer.double_stream_blocks:
+            assert isinstance(block, HiDreamBlock)
+            block.forward = __patch_block_forward__.__get__(block)
+            # NOTE: Patch Inner block and block_id
+            _block = block.block
+            assert isinstance(_block, HiDreamImageTransformerBlock)
+            _block._block_id = _block_id
+            _block.forward = __patch_double_forward__.__get__(_block)
+            _block_id += 1
+
+        for block in transformer.single_stream_blocks:
+            assert isinstance(block, HiDreamBlock)
+            block.forward = __patch_block_forward__.__get__(block)
+            # NOTE: Patch Inner block and block_id
+            _block = block.block
+            assert isinstance(_block, HiDreamImageSingleTransformerBlock)
+            _block._block_id = _block_id
+            _block.forward = __patch_single_forward__.__get__(_block)
+            _block_id += 1
+
+        is_patched = True
+        cls_name = transformer.__class__.__name__
+
+        if is_patched:
+            logger.warning(f"Patched {cls_name} for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+
+        transformer._is_patched = is_patched  # True or False
+
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
+def __patch_double_forward__(
+    self: HiDreamImageTransformerBlock,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,  # initial_encoder_hidden_states
+    hidden_states_masks: Optional[torch.Tensor] = None,
+    temb: Optional[torch.Tensor] = None,
+    image_rotary_emb: torch.Tensor = None,
+    llama31_encoder_hidden_states: List[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Assume block_id was patched in transformer forward:
+    # for i, block in enumerate(blocks): block._block_id = i;
+    block_id = self._block_id
+    initial_encoder_hidden_states_seq_len = encoder_hidden_states.shape[1]
+    cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[block_id]
+    cur_encoder_hidden_states = torch.cat(
+        [encoder_hidden_states, cur_llama31_encoder_hidden_states],
+        dim=1,
+    )
+    encoder_hidden_states = cur_encoder_hidden_states
+
+    wtype = hidden_states.dtype
+    (
+        shift_msa_i,
+        scale_msa_i,
+        gate_msa_i,
+        shift_mlp_i,
+        scale_mlp_i,
+        gate_mlp_i,
+        shift_msa_t,
+        scale_msa_t,
+        gate_msa_t,
+        shift_mlp_t,
+        scale_mlp_t,
+        gate_mlp_t,
+    ) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1)
+
+    # 1. MM-Attention
+    norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+    norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
+    norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(
+        dtype=wtype
+    )
+    norm_encoder_hidden_states = (
+        norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t
+    )
+
+    attn_output_i, attn_output_t = self.attn1(
+        norm_hidden_states,
+        hidden_states_masks,
+        norm_encoder_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+    )
+
+    hidden_states = gate_msa_i * attn_output_i + hidden_states
+    encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states
+
+    # 2. Feed-forward
+    norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+    norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+    norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(
+        dtype=wtype
+    )
+    norm_encoder_hidden_states = (
+        norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
+    )
+
+    ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
+    ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
+    hidden_states = ff_output_i + hidden_states
+    encoder_hidden_states = ff_output_t + encoder_hidden_states
+
+    initial_encoder_hidden_states = encoder_hidden_states[
+        :, :initial_encoder_hidden_states_seq_len
+    ]
+
+    return hidden_states, initial_encoder_hidden_states
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
+def __patch_single_forward__(
+    self: HiDreamImageSingleTransformerBlock,
+    hidden_states: torch.Tensor,
+    hidden_states_masks: Optional[torch.Tensor] = None,
+    temb: Optional[torch.Tensor] = None,
+    image_rotary_emb: torch.Tensor = None,
+    llama31_encoder_hidden_states: List[torch.Tensor] = None,
+) -> torch.Tensor:
+    # Assume block_id was patched in transformer forward:
+    # for i, block in enumerate(blocks): block._block_id = i;
+    block_id = self._block_id
+    hidden_states_seq_len = hidden_states.shape[1]
+    cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[block_id]
+    hidden_states = torch.cat(
+        [hidden_states, cur_llama31_encoder_hidden_states], dim=1
+    )
+
+    wtype = hidden_states.dtype
+    (
+        shift_msa_i,
+        scale_msa_i,
+        gate_msa_i,
+        shift_mlp_i,
+        scale_mlp_i,
+        gate_mlp_i,
+    ) = self.adaLN_modulation(temb)[:, None].chunk(6, dim=-1)
+
+    # 1. MM-Attention
+    norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+    norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
+    attn_output_i = self.attn1(
+        norm_hidden_states,
+        hidden_states_masks,
+        image_rotary_emb=image_rotary_emb,
+    )
+    hidden_states = gate_msa_i * attn_output_i + hidden_states
+
+    # 2. Feed-forward
+    norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+    norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+    ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
+    hidden_states = ff_output_i + hidden_states
+
+    hidden_states = hidden_states[:, :hidden_states_seq_len]
+
+    return hidden_states
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
+def __patch_block_forward__(
+    self: HiDreamBlock,
+    hidden_states: torch.Tensor,
+    *args,
+    **kwargs,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    return self.block(hidden_states, *args, **kwargs)
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
+def __patch_transformer_forward__(
+    self: HiDreamImageTransformer2DModel,
+    hidden_states: torch.Tensor,
+    timesteps: torch.LongTensor = None,
+    encoder_hidden_states_t5: torch.Tensor = None,
+    encoder_hidden_states_llama3: torch.Tensor = None,
+    pooled_embeds: torch.Tensor = None,
+    img_ids: Optional[torch.Tensor] = None,
+    img_sizes: Optional[List[Tuple[int, int]]] = None,
+    hidden_states_masks: Optional[torch.Tensor] = None,
+    attention_kwargs: Optional[Dict[str, Any]] = None,
+    return_dict: bool = True,
+    **kwargs,
+) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+
+    if encoder_hidden_states is not None:
+        deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
+        deprecate("encoder_hidden_states", "0.35.0", deprecation_message)
+        encoder_hidden_states_t5 = encoder_hidden_states[0]
+        encoder_hidden_states_llama3 = encoder_hidden_states[1]
+
+    if (
+        img_ids is not None
+        and img_sizes is not None
+        and hidden_states_masks is None
+    ):
+        deprecation_message = "Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
+        deprecate("img_ids", "0.35.0", deprecation_message)
+
+    if hidden_states_masks is not None and (
+        img_ids is None or img_sizes is None
+    ):
+        raise ValueError(
+            "if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed."
+        )
+    elif hidden_states_masks is not None and hidden_states.ndim != 3:
+        raise ValueError(
+            "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
+        )
+
+    if attention_kwargs is not None:
+        attention_kwargs = attention_kwargs.copy()
+        lora_scale = attention_kwargs.pop("scale", 1.0)
+    else:
+        lora_scale = 1.0
+
+    if USE_PEFT_BACKEND:
+        # weight the lora layers by setting `lora_scale` for each PEFT layer
+        scale_lora_layers(self, lora_scale)
+    else:
+        if (
+            attention_kwargs is not None
+            and attention_kwargs.get("scale", None) is not None
+        ):
+            logger.warning(
+                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+            )
+
+    # spatial forward
+    batch_size = hidden_states.shape[0]
+    hidden_states_type = hidden_states.dtype
+
+    # Patchify the input
+    if hidden_states_masks is None:
+        hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(
+            hidden_states
+        )
+
+    # Embed the hidden states
+    hidden_states = self.x_embedder(hidden_states)
+
+    # 0. time
+    timesteps = self.t_embedder(timesteps, hidden_states_type)
+    p_embedder = self.p_embedder(pooled_embeds)
+    temb = timesteps + p_embedder
+
+    encoder_hidden_states = [
+        encoder_hidden_states_llama3[k] for k in self.config.llama_layers
+    ]
+
+    if self.caption_projection is not None:
+        new_encoder_hidden_states = []
+        for i, enc_hidden_state in enumerate(encoder_hidden_states):
+            enc_hidden_state = self.caption_projection[i](enc_hidden_state)
+            enc_hidden_state = enc_hidden_state.view(
+                batch_size, -1, hidden_states.shape[-1]
+            )
+            new_encoder_hidden_states.append(enc_hidden_state)
+        encoder_hidden_states = new_encoder_hidden_states
+        encoder_hidden_states_t5 = self.caption_projection[-1](
+            encoder_hidden_states_t5
+        )
+        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(
+            batch_size, -1, hidden_states.shape[-1]
+        )
+        encoder_hidden_states.append(encoder_hidden_states_t5)
+
+    txt_ids = torch.zeros(
+        batch_size,
+        encoder_hidden_states[-1].shape[1]
+        + encoder_hidden_states[-2].shape[1]
+        + encoder_hidden_states[0].shape[1],
+        3,
+        device=img_ids.device,
+        dtype=img_ids.dtype,
+    )
+    ids = torch.cat((img_ids, txt_ids), dim=1)
+    image_rotary_emb = self.pe_embedder(ids)
+
+    # 2. Blocks
+    # NOTE: block_id is no-need anymore.
+    initial_encoder_hidden_states = torch.cat(
+        [encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1
+    )
+    llama31_encoder_hidden_states = encoder_hidden_states
+    for bid, block in enumerate(self.double_stream_blocks):
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            hidden_states, initial_encoder_hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    initial_encoder_hidden_states,
+                    hidden_states_masks,
+                    temb,
+                    image_rotary_emb,
+                    llama31_encoder_hidden_states,
+                )
+            )
+        else:
+            hidden_states, initial_encoder_hidden_states = block(
+                hidden_states,
+                initial_encoder_hidden_states,  # encoder_hidden_states
+                hidden_states_masks=hidden_states_masks,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                llama31_encoder_hidden_states=llama31_encoder_hidden_states,
+            )
+
+    image_tokens_seq_len = hidden_states.shape[1]
+    hidden_states = torch.cat(
+        [hidden_states, initial_encoder_hidden_states], dim=1
+    )
+    if hidden_states_masks is not None:
+        # NOTE: Patched
+        cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[
+            self.double_stream_blocks[-1].block._block_id
+        ]
+        encoder_attention_mask_ones = torch.ones(
+            (
+                batch_size,
+                initial_encoder_hidden_states.shape[1]
+                + cur_llama31_encoder_hidden_states.shape[1],
+            ),
+            device=hidden_states_masks.device,
+            dtype=hidden_states_masks.dtype,
+        )
+        hidden_states_masks = torch.cat(
+            [hidden_states_masks, encoder_attention_mask_ones], dim=1
+        )
+
+    for bid, block in enumerate(self.single_stream_blocks):
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            hidden_states = self._gradient_checkpointing_func(
+                block,
+                hidden_states,
+                hidden_states_masks,
+                temb,
+                image_rotary_emb,
+                llama31_encoder_hidden_states,
+            )
+        else:
+            hidden_states = block(
+                hidden_states,
+                hidden_states_masks=hidden_states_masks,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                llama31_encoder_hidden_states=llama31_encoder_hidden_states,
+            )
+
+    hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
+    output = self.final_layer(hidden_states, temb)
+    output = self.unpatchify(output, img_sizes, self.training)
+    if hidden_states_masks is not None:
+        hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
+
+    if USE_PEFT_BACKEND:
+        # remove `lora_scale` from each PEFT layer
+        unscale_lora_layers(self, lora_scale)
+
+    if not return_dict:
+        return (output,)
+    return Transformer2DModelOutput(sample=output)
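For orientation, a sketch of applying the new functor directly; normally it is applied for you through the `HiDream` BlockAdapter (`patch_functor=HiDreamPatchFunctor()`), and the pipeline class and model id shown here are assumptions, not part of the diff.

```python
# Sketch only: applying HiDreamPatchFunctor by hand to a loaded transformer.
import torch
from diffusers import HiDreamImagePipeline  # assumed diffusers class
from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",  # illustrative model id
    torch_dtype=torch.bfloat16,
)

# apply() rebinds the block and transformer forwards; it is idempotent and
# returns early once `_is_patched` has been set on the transformer.
pipe.transformer = HiDreamPatchFunctor().apply(pipe.transformer)
assert getattr(pipe.transformer, "_is_patched", False)
```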
@@ -0,0 +1,213 @@
+import torch
+from typing import Optional, Union, List
+from diffusers import HunyuanDiT2DModel
+from diffusers.models.transformers.hunyuan_transformer_2d import (
+    HunyuanDiTBlock,
+    Transformer2DModelOutput,
+)
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class HunyuanDiTPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: HunyuanDiT2DModel,
+        **kwargs,
+    ) -> HunyuanDiT2DModel:
+        if hasattr(transformer, "_is_patched"):
+            return transformer
+
+        is_patched = False
+
+        num_layers = transformer.config.num_layers
+        layer_id = 0
+        for block in transformer.blocks:
+            assert isinstance(block, HunyuanDiTBlock)
+            block._num_layers = num_layers
+            block._layer_id = layer_id
+            block.forward = __patch_block_forward__.__get__(block)
+            layer_id += 1
+
+        is_patched = True
+
+        cls_name = transformer.__class__.__name__
+
+        if is_patched:
+            logger.warning(f"Patched {cls_name} for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+
+        transformer._is_patched = is_patched  # True or False
+
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
+def __patch_block_forward__(
+    self: HunyuanDiTBlock,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    temb: Optional[torch.Tensor] = None,
+    image_rotary_emb: torch.Tensor = None,
+    controlnet_block_samples: torch.Tensor = None,
+    skips: List[torch.Tensor] = [],
+) -> torch.Tensor:
+    # Notice that normalization is always applied before the real computation in the following blocks.
+    # 0. Long Skip Connection
+    num_layers = self._num_layers
+    layer_id = self._layer_id
+
+    if layer_id > num_layers // 2:
+        if controlnet_block_samples is not None:
+            skip = skips.pop() + controlnet_block_samples.pop()
+        else:
+            skip = skips.pop()
+    else:
+        skip = None
+
+    if self.skip_linear is not None:
+        cat = torch.cat([hidden_states, skip], dim=-1)
+        cat = self.skip_norm(cat)
+        hidden_states = self.skip_linear(cat)
+
+    # 1. Self-Attention
+    norm_hidden_states = self.norm1(
+        hidden_states, temb
+    )  # checked: self.norm1 is correct
+    attn_output = self.attn1(
+        norm_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+    )
+    hidden_states = hidden_states + attn_output
+
+    # 2. Cross-Attention
+    hidden_states = hidden_states + self.attn2(
+        self.norm2(hidden_states),
+        encoder_hidden_states=encoder_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+    )
+
+    # FFN Layer
+    mlp_inputs = self.norm3(hidden_states)
+    hidden_states = hidden_states + self.ff(mlp_inputs)
+
+    if layer_id < (num_layers // 2 - 1):
+        skips.append(hidden_states)
+
+    return hidden_states
+
+
+def __patch_transformer_forward__(
+    self: HunyuanDiT2DModel,
+    hidden_states,
+    timestep,
+    encoder_hidden_states=None,
+    text_embedding_mask=None,
+    encoder_hidden_states_t5=None,
+    text_embedding_mask_t5=None,
+    image_meta_size=None,
+    style=None,
+    image_rotary_emb=None,
+    controlnet_block_samples=None,
+    return_dict=True,
+) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    height, width = hidden_states.shape[-2:]
+
+    hidden_states = self.pos_embed(hidden_states)
+
+    temb = self.time_extra_emb(
+        timestep,
+        encoder_hidden_states_t5,
+        image_meta_size,
+        style,
+        hidden_dtype=timestep.dtype,
+    )  # [B, D]
+
+    # text projection
+    batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
+    encoder_hidden_states_t5 = self.text_embedder(
+        encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
+    )
+    encoder_hidden_states_t5 = encoder_hidden_states_t5.view(
+        batch_size, sequence_length, -1
+    )
+
+    encoder_hidden_states = torch.cat(
+        [encoder_hidden_states, encoder_hidden_states_t5], dim=1
+    )
+    text_embedding_mask = torch.cat(
+        [text_embedding_mask, text_embedding_mask_t5], dim=-1
+    )
+    text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
+
+    encoder_hidden_states = torch.where(
+        text_embedding_mask, encoder_hidden_states, self.text_embedding_padding
+    )
+
+    skips = []
+    for layer, block in enumerate(self.blocks):
+        hidden_states = block(
+            hidden_states,
+            temb=temb,
+            encoder_hidden_states=encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+            controlnet_block_samples=controlnet_block_samples,
+            skips=skips,
+        )  # (N, L, D)
+
+    if (
+        controlnet_block_samples is not None
+        and len(controlnet_block_samples) != 0
+    ):
+        raise ValueError(
+            "The number of controls is not equal to the number of skip connections."
+        )
+
+    # final layer
+    hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
+    hidden_states = self.proj_out(hidden_states)
+    # (N, L, patch_size ** 2 * out_channels)
+
+    # unpatchify: (N, out_channels, H, W)
+    patch_size = self.pos_embed.patch_size
+    height = height // patch_size
+    width = width // patch_size
+
+    hidden_states = hidden_states.reshape(
+        shape=(
+            hidden_states.shape[0],
+            height,
+            width,
+            patch_size,
+            patch_size,
+            self.out_channels,
+        )
+    )
+    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+    output = hidden_states.reshape(
+        shape=(
+            hidden_states.shape[0],
+            self.out_channels,
+            height * patch_size,
+            width * patch_size,
+        )
+    )
+    if not return_dict:
+        return (output,)
+    return Transformer2DModelOutput(sample=output)
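With this functor in place the `HunyuanDiT` adapter is registered as supported, so the normal one-line API is assumed to cover it; the pipeline class, model id and prompt below are illustrative assumptions rather than content of the diff.

```python
# Sketch only: caching a HunyuanDiT pipeline through the registered adapter.
import torch
import cache_dit
from diffusers import HunyuanDiTPipeline

pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",  # illustrative model id
    torch_dtype=torch.float16,
).to("cuda")

# The adapter applies HunyuanDiTPatchFunctor (blocks receive an explicit
# `skips` list for the long skip connections) before the blocks are cached.
cache_dit.enable_cache(pipe)

image = pipe(prompt="an ink painting of a crane at sunrise").images[0]
```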
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.30
+Version: 0.2.32
 Summary: 🤗 A Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -59,29 +59,37 @@ Dynamic: requires-python
 🔥<b><a href="#unified">Unified Cache APIs</a> | <a href="#dbcache">DBCache</a> | <a href="#taylorseer">Hybrid TaylorSeer</a> | <a href="#cfg">Hybrid Cache CFG</a></b>🔥
 </p>
 <p align="center">
-🎉Now, <b>cache-dit</b> covers <b>mainstream</b> Diffusers' <b>DiT-based</b> Pipelines🎉<br>
-🔥<b><a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1/2.2</a> | <a href="#supported"> ... </a> | <a href="#supported">CogVideoX</a></b>🔥
+🎉Now, <b>cache-dit</b> covers <b>most</b> mainstream Diffusers' <b>DiT</b> Pipelines🎉<br>
+🔥<a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Qwen-Image-Lightning</a> | <a href="#supported"> Wan 2.1/2.2 </a>🔥<br>
+🔥<a href="#supported">HunyuanVideo</a> | <a href="#supported">HunyuanDiT</a> | <a href="#supported">HiDream</a> | <a href="#supported">Mochi</a> | <a href="#supported">CogVideoX 1/1.5</a>🔥<br>
+🔥<a href="#supported">CogView3Plus</a> | <a href="#supported">CogView4</a> | <a href="#supported">Chroma</a> | <a href="#supported"> LTXVideo </a> | <a href="#supported">PixArt</a>🔥<br>
+🔥<a href="#supported">Cosmos</a> | <a href="#supported">SkyReelsV2</a> | <a href="#supported">VisualCloze</a> | <a href="#supported"> OmniGen </a> | <a href="#supported">Lumina 1/2</a>🔥<br>
+🔥<a href="#supported">Allegro</a> | <a href="#supported">EasyAnimate</a> | <a href="#supported">SD 3/3.5</a> | <a href="#supported"> ... </a> | <a href="#supported">DiT-XL</a>🔥
 </p>
 </div>
 <div align='center'>
-  <img src=./assets/gifs/wan2.2.C0_Q0_NONE.gif width=160px>
-  <img src=./assets/gifs/wan2.2.C1_Q0_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
-  <img src=./assets/gifs/wan2.2.C1_Q1_fp8_w8a8_dq_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
-  <p><b>🔥Wan2.2 MoE</b> Baseline | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:~2.0x↑🎉</b> | +FP8 DQ:<b>~2.4x↑🎉</b></p>
-  <img src=./assets/qwen-image.C0_Q0_NONE.png width=160px>
-  <img src=./assets/qwen-image.C1_Q0_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S23.png width=160px>
-  <img src=./assets/qwen-image.C1_Q1_fp8_w8a8_dq_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S18.png width=160px>
-  <p><b>🔥Qwen-Image</b> Baseline | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:~1.8x↑🎉</b> | +FP8 DQ:<b>~2.2x↑🎉</b><br>♥️ Please consider to leave a <b>⭐️ Star</b> to support us ~ ♥️</p>
-  </p>
+  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C0_Q0_NONE.gif width=160px>
+  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q0_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
+  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q1_fp8_w8a8_dq_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
+  <p><b>🔥Wan2.2 MoE</b> | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:~2.0x↑🎉</b> | +FP8 DQ:<b>~2.4x↑🎉</b></p>
+  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C0_Q0_NONE.png width=160px>
+  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C1_Q0_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S23.png width=160px>
+  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C1_Q1_fp8_w8a8_dq_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S18.png width=160px>
+  <p><b>🔥Qwen-Image</b> | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:~1.8x↑🎉</b> | +FP8 DQ:<b>~2.2x↑🎉</b></p>
+  <img src=./assets/qwen-image-lightning.4steps.C0_L1_Q0_NONE.png width=200px>
+  <img src=./assets/qwen-image-lightning.4steps.C0_L1_Q0_DBCACHE_F16B16_W2M1MC1_T0O2_R0.9_S1.png width=200px>
+  <p><b>🔥Qwen-Image-Lightning</b> 4 steps | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a></b> 3.5 steps:<b>~1.14x↑🎉</b>
+  <br>♥️ Please consider to leave a <b>⭐️ Star</b> to support us ~ ♥️</p>
 </div>
 
 ## 🔥News
 
+- [2025-09-08] 🔥[**Qwen-Image-Lightning**](./examples/pipeline/run_qwen_image_lightning.py) **7.1/3.5 steps🎉** inference with **[DBCache: F16B16](https://github.com/vipshop/cache-dit)**.
 - [2025-09-03] 🎉[**Wan2.2-MoE**](https://github.com/Wan-Video) **2.4x↑🎉** speedup! Please refer to [run_wan_2.2.py](./examples/pipeline/run_wan_2.2.py) as an example.
 - [2025-08-19] 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x↑🎉** speedup! Check the example: [run_qwen_image_edit.py](./examples/pipeline/run_qwen_image_edit.py).
 - [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
 - [2025-08-11] 🔥[**Qwen-Image**](https://github.com/QwenLM/Qwen-Image) **1.8x↑🎉** speedup! Please refer to [run_qwen_image.py](./examples/pipeline/run_qwen_image.py) as an example.
-- [2025-07-13] 🎉[**FLUX.1-Dev**](https://github.com/xlite-dev/flux-faster) **3.3x↑🎉** speedup! NVIDIA L20 with **[cache-dit](https://github.com/vipshop/cache-dit)** + `compile + FP8 DQ`.
+- [2025-07-13] 🎉[**FLUX.1-dev**](https://github.com/xlite-dev/flux-faster) **3.3x↑🎉** speedup! NVIDIA L20 with **[cache-dit](https://github.com/vipshop/cache-dit)** + **compile + FP8 DQ**.
 
 <details>
 <summary> Previous News </summary>
@@ -131,6 +139,7 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 
 Currently, **cache-dit** library supports almost **Any** Diffusion Transformers (with **Transformer Blocks** that match the specific Input and Output **patterns**). Please check [🎉Unified Cache APIs](#unified) for more details. Here are just some of the tested models listed:
 
+- [🚀Qwen-Image-Lightning](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Qwen-Image-Edit](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Qwen-Image](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀FLUX.1-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -141,7 +150,13 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
 - [🚀Wan2.2-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀mochi-1-preview](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀HunyuanDiT](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀HiDream-I1-Full](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀PixArt-Alpha](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀PixArt-Sigma](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀SD-3/3.5](https://github.com/vipshop/cache-dit/raw/main/examples)
 
 </details>
 
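A quick-start sketch matching the README claims above; the model id is illustrative and any pipeline from the list is assumed to work the same way through the unified API.

```python
# Quick-start sketch: one call on a supported diffusers pipeline.
import torch
import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image",  # illustrative model id
    torch_dtype=torch.bfloat16,
).to("cuda")

cache_dit.enable_cache(pipe)  # training-free DBCache acceleration
image = pipe("a cyberpunk street market at dusk").images[0]
```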
@@ -1,30 +1,32 @@
 cache_dit/__init__.py,sha256=kX9V-FegZG4c8LMwI4PTmMqH794MEW0pzDArdhC0cJw,1241
-cache_dit/_version.py,sha256=6uKAYeE03adIcUS0SDwp52AaQx0KO8z_-07D_lPHrz8,706
+cache_dit/_version.py,sha256=J0YTFDgdG9rY1Xk5pUbWWGgbT2rbSasvUHcntxayVtA,706
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
 cache_dit/utils.py,sha256=WK7eqgH6gCYNHXNLmWyxBDU0XSHTPg7CfOcyXlGXBqE,10510
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
 cache_dit/cache_factory/__init__.py,sha256=Iw6-iJLFbdzCsIDZXXOw371L-HPmoeZO_P9a3sDjP5s,1103
-cache_dit/cache_factory/cache_adapters.py,sha256=TA_0mEHMdSQDrt4rYASeX4-BD8pJOznSJfMV3hkrGuk,17851
-cache_dit/cache_factory/cache_interface.py,sha256=y1nY6R3MucRmAnG2UJRI_tIKrRk27FktGWLbfckf3zE,8543
+cache_dit/cache_factory/cache_adapters.py,sha256=dmNX68nBD52HtQvHnNAuSn1zjDWrQdycD0qXy-w-mwc,18212
+cache_dit/cache_factory/cache_interface.py,sha256=LpyCy-tQ_GcTRAYLpMMf9hFVIktABHI6CObn5Ll8bMw,8548
 cache_dit/cache_factory/cache_types.py,sha256=ooukxQRG55uTLmaZ0SKw6gIeY6SQHhMxkbv55uj2Sqk,991
 cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
 cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
-cache_dit/cache_factory/block_adapters/__init__.py,sha256=EA-4mEVy-JJ5vRDo6C3nJIOXu0ZDNc6FQ-ZLAKHDtB0,17251
-cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=OrKhuNdcGBCgSsPchdf4h32Ad-bQVUXNigMhPJ4cCvk,21069
+cache_dit/cache_factory/block_adapters/__init__.py,sha256=OZM5vJwmQIkoIwVmMxKXiHqKvs31NyAva1Z91C_ko3w,17547
+cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=EQBiJYyoInKU1ND69wTm7M0n5Ja4I8QW01SgRpBjSn8,21671
 cache_dit/cache_factory/block_adapters/block_registers.py,sha256=ZeN2wGPmuf2u3puSsBx8x-rl3wRo8-cWcuWNcrssVfA,2553
-cache_dit/cache_factory/cache_blocks/__init__.py,sha256=OWjnpJxA8EJVoRzuyb5miuiRphUFj831-bbtWsTDjnM,2750
+cache_dit/cache_factory/cache_blocks/__init__.py,sha256=08Ox7kD05lkRKCOsVTdEZeKAWBheqpxfrAT1Nz7eclI,2916
 cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=ElMps6_7uI74tSF9GDR_dEI0bZEhdzcepM29xFWnYo8,428
 cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=nf2f5wdxp6tfq9AhFyMyBeKiZfxh63WG1g8q-c2BBSg,10182
-cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=XSDy3hsaKbAZPGZY92YgGA0qLgjQyIX8irQkb2R5T2c,20331
+cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=f1ojREQcDoBtDG3dzl8t1g_Vru8140LVDRPWlY-kAXw,21311
 cache_dit/cache_factory/cache_blocks/utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
 cache_dit/cache_factory/cache_contexts/__init__.py,sha256=rqnJ5__zqnpVHK5A1OqWILpNh5Ss-0ZDTGgtxZMKGGo,250
 cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=N88WLdd4KE9DuMWmpX8URcF55E2zWNwcKMxgVYkxMJY,13691
 cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=_NUXcMYYEIVfDHpc4HJr9RUjU5RUEkZmAgFGE8bh5Wc,34883
 cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=etSUIZzDvqW3ScKCbccTPcFaSmxV1T-xAXdk-p3e3wk,3802
-cache_dit/cache_factory/patch_functors/__init__.py,sha256=yK05iONMGILsTZ83ynrUUJtiJKJ_FDjxmVIzRLy416s,252
+cache_dit/cache_factory/patch_functors/__init__.py,sha256=06zdddrjvSCgBzJ0a8niRHd3ucF2qsbzlbL00d4aCvk,451
 cache_dit/cache_factory/patch_functors/functor_base.py,sha256=Ahk0fTfrHgNdEl-9JSkACvfyyv9G-Ei5OSz7XBIlX5o,357
-cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=N3UzszCM55g3GHeVdyXkid2_n72VJrfqBM2gdtD52gw,10042
-cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=rJsbGEIxWPGnZyGI4ZwLLBdg8u6ZItsOeh0UoD_bVwk,9551
+cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=2iLxlsc-1dDHRveqCXaC07E9CeMNOuBNkvpJ1atpK7E,10048
+cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=UMkyuEYjO7UO_zmXi9Djd-nD-XMgCUgE-qkYA3plWSM,9559
+cache_dit/cache_factory/patch_functors/functor_hidream.py,sha256=pi_vvpDy1lsgQHxu3eK3v93rdJL7oNwkt3WakRP8pbw,15375
+cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py,sha256=iSo5dD5uKnjQQeysDUIkKt0wdnK5bzXTc_F_lfHG70w,6401
 cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
 cache_dit/compile/utils.py,sha256=nN2OIrSdwRR5zGxJinKDqb07pXpvTNTF3g_OgLkeeBU,3858
 cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -39,9 +41,9 @@ cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70
 cache_dit/quantize/quantize_ao.py,sha256=mGspqYgQtenl3QnKPtsSYsSD7LbVX93f1M940bhXKLU,6066
 cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
 cache_dit/quantize/quantize_svdq.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit-0.2.30.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
-cache_dit-0.2.30.dist-info/METADATA,sha256=8Ln_X5fw14U3greCM7cSukrei1SRiMDpksFalg5ZBAU,22130
-cache_dit-0.2.30.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-cache_dit-0.2.30.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
-cache_dit-0.2.30.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
-cache_dit-0.2.30.dist-info/RECORD,,
+cache_dit-0.2.32.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.32.dist-info/METADATA,sha256=WQ9GP-Om05j3NBvtifkmbz5t20XBU_-KJQptrK7jQBs,24222
+cache_dit-0.2.32.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.32.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-0.2.32.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.32.dist-info/RECORD,,