cache-dit 0.2.25__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (27)
  1. cache_dit/__init__.py +4 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +2 -0
  4. cache_dit/cache_factory/cache_adapters.py +375 -26
  5. cache_dit/cache_factory/cache_blocks/__init__.py +20 -0
  6. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
  7. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +270 -0
  8. cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +17 -18
  9. cache_dit/cache_factory/cache_blocks/utils.py +19 -0
  10. cache_dit/cache_factory/cache_context.py +5 -1
  11. cache_dit/cache_factory/cache_interface.py +7 -2
  12. cache_dit/cache_factory/forward_pattern.py +45 -24
  13. cache_dit/cache_factory/patch_functors/__init__.py +5 -0
  14. cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
  15. cache_dit/cache_factory/patch_functors/functor_chroma.py +273 -0
  16. cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +45 -31
  17. cache_dit/quantize/quantize_ao.py +18 -4
  18. cache_dit/quantize/quantize_interface.py +2 -2
  19. cache_dit/utils.py +3 -2
  20. {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/METADATA +35 -8
  21. cache_dit-0.2.26.dist-info/RECORD +42 -0
  22. cache_dit/cache_factory/patch/__init__.py +0 -0
  23. cache_dit-0.2.25.dist-info/RECORD +0 -36
  24. {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/WHEEL +0 -0
  25. {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/entry_points.txt +0 -0
  26. {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/licenses/LICENSE +0 -0
  27. {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,270 @@
+ import torch
+
+ from cache_dit.cache_factory import cache_context
+ from cache_dit.cache_factory import ForwardPattern
+ from cache_dit.cache_factory.cache_blocks.utils import (
+     patch_cached_stats,
+ )
+ from cache_dit.cache_factory.cache_blocks.pattern_base import (
+     DBCachedBlocks_Pattern_Base,
+ )
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
+     _supported_patterns = [
+         ForwardPattern.Pattern_3,
+         ForwardPattern.Pattern_4,
+         ForwardPattern.Pattern_5,
+     ]
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         original_hidden_states = hidden_states
+         # Call first `n` blocks to process the hidden states for
+         # more stable diff calculation.
+         # encoder_hidden_states: None Pattern 3, else 4, 5
+         hidden_states, encoder_hidden_states = self.call_Fn_blocks(
+             hidden_states,
+             *args,
+             **kwargs,
+         )
+
+         Fn_hidden_states_residual = hidden_states - original_hidden_states
+         del original_hidden_states
+
+         cache_context.mark_step_begin()
+         # Residual L1 diff or Hidden States L1 diff
+         can_use_cache = cache_context.get_can_use_cache(
+             (
+                 Fn_hidden_states_residual
+                 if not cache_context.is_l1_diff_enabled()
+                 else hidden_states
+             ),
+             parallelized=self._is_parallelized(),
+             prefix=(
+                 "Fn_residual"
+                 if not cache_context.is_l1_diff_enabled()
+                 else "Fn_hidden_states"
+             ),
+         )
+
+         torch._dynamo.graph_break()
+         if can_use_cache:
+             cache_context.add_cached_step()
+             del Fn_hidden_states_residual
+             hidden_states, encoder_hidden_states = (
+                 cache_context.apply_hidden_states_residual(
+                     hidden_states,
+                     # None Pattern 3, else 4, 5
+                     encoder_hidden_states,
+                     prefix=(
+                         "Bn_residual"
+                         if cache_context.is_cache_residual()
+                         else "Bn_hidden_states"
+                     ),
+                     encoder_prefix=(
+                         "Bn_residual"
+                         if cache_context.is_encoder_cache_residual()
+                         else "Bn_hidden_states"
+                     ),
+                 )
+             )
+             torch._dynamo.graph_break()
+             # Call last `n` blocks to further process the hidden states
+             # for higher precision.
+             hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+         else:
+             cache_context.set_Fn_buffer(
+                 Fn_hidden_states_residual, prefix="Fn_residual"
+             )
+             if cache_context.is_l1_diff_enabled():
+                 # for hidden states L1 diff
+                 cache_context.set_Fn_buffer(hidden_states, "Fn_hidden_states")
+             del Fn_hidden_states_residual
+             torch._dynamo.graph_break()
+             (
+                 hidden_states,
+                 encoder_hidden_states,
+                 hidden_states_residual,
+                 # None Pattern 3, else 4, 5
+                 encoder_hidden_states_residual,
+             ) = self.call_Mn_blocks(  # middle
+                 hidden_states,
+                 # None Pattern 3, else 4, 5
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             torch._dynamo.graph_break()
+             if cache_context.is_cache_residual():
+                 cache_context.set_Bn_buffer(
+                     hidden_states_residual,
+                     prefix="Bn_residual",
+                 )
+             else:
+                 # TaylorSeer
+                 cache_context.set_Bn_buffer(
+                     hidden_states,
+                     prefix="Bn_hidden_states",
+                 )
+             if cache_context.is_encoder_cache_residual():
+                 cache_context.set_Bn_encoder_buffer(
+                     # None Pattern 3, else 4, 5
+                     encoder_hidden_states_residual,
+                     prefix="Bn_residual",
+                 )
+             else:
+                 # TaylorSeer
+                 cache_context.set_Bn_encoder_buffer(
+                     # None Pattern 3, else 4, 5
+                     encoder_hidden_states,
+                     prefix="Bn_hidden_states",
+                 )
+             torch._dynamo.graph_break()
+             # Call last `n` blocks to further process the hidden states
+             # for higher precision.
+             hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                 hidden_states,
+                 # None Pattern 3, else 4, 5
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+
+         patch_cached_stats(self.transformer)
+         torch._dynamo.graph_break()
+
+         return (
+             hidden_states
+             if self.forward_pattern.Return_H_Only
+             else (
+                 (hidden_states, encoder_hidden_states)
+                 if self.forward_pattern.Return_H_First
+                 else (encoder_hidden_states, hidden_states)
+             )
+         )
+
+     def call_Fn_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         assert cache_context.Fn_compute_blocks() <= len(
+             self.transformer_blocks
+         ), (
+             f"Fn_compute_blocks {cache_context.Fn_compute_blocks()} must be less than "
+             f"the number of transformer blocks {len(self.transformer_blocks)}"
+         )
+         encoder_hidden_states = None  # Pattern 3
+         for block in self._Fn_blocks():
+             hidden_states = block(
+                 hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+                 hidden_states, encoder_hidden_states = hidden_states
+                 if not self.forward_pattern.Return_H_First:
+                     hidden_states, encoder_hidden_states = (
+                         encoder_hidden_states,
+                         hidden_states,
+                     )
+
+         return hidden_states, encoder_hidden_states
+
+     def call_Mn_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         # None Pattern 3, else 4, 5
+         encoder_hidden_states: torch.Tensor | None,
+         *args,
+         **kwargs,
+     ):
+         original_hidden_states = hidden_states
+         original_encoder_hidden_states = encoder_hidden_states
+         for block in self._Mn_blocks():
+             hidden_states = block(
+                 hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+                 hidden_states, encoder_hidden_states = hidden_states
+                 if not self.forward_pattern.Return_H_First:
+                     hidden_states, encoder_hidden_states = (
+                         encoder_hidden_states,
+                         hidden_states,
+                     )
+
+         # compute hidden_states residual
+         hidden_states = hidden_states.contiguous()
+         hidden_states_residual = hidden_states - original_hidden_states
+         if (
+             original_encoder_hidden_states is not None
+             and encoder_hidden_states is not None
+         ):  # Pattern 4, 5
+             encoder_hidden_states_residual = (
+                 encoder_hidden_states - original_encoder_hidden_states
+             )
+         else:
+             encoder_hidden_states_residual = None  # Pattern 3
+
+         return (
+             hidden_states,
+             encoder_hidden_states,
+             hidden_states_residual,
+             encoder_hidden_states_residual,
+         )
+
+     def call_Bn_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         # None Pattern 3, else 4, 5
+         encoder_hidden_states: torch.Tensor | None,
+         *args,
+         **kwargs,
+     ):
+         if cache_context.Bn_compute_blocks() == 0:
+             return hidden_states, encoder_hidden_states
+
+         assert cache_context.Bn_compute_blocks() <= len(
+             self.transformer_blocks
+         ), (
+             f"Bn_compute_blocks {cache_context.Bn_compute_blocks()} must be less than "
+             f"the number of transformer blocks {len(self.transformer_blocks)}"
+         )
+         if len(cache_context.Bn_compute_blocks_ids()) > 0:
+             raise ValueError(
+                 f"Bn_compute_blocks_ids is not support for "
+                 f"patterns: {self._supported_patterns}."
+             )
+         else:
+             # Compute all Bn blocks if no specific Bn compute blocks ids are set.
+             for block in self._Bn_blocks():
+                 hidden_states = block(
+                     hidden_states,
+                     *args,
+                     **kwargs,
+                 )
+                 if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
+                     hidden_states, encoder_hidden_states = hidden_states
+                     if not self.forward_pattern.Return_H_First:
+                         hidden_states, encoder_hidden_states = (
+                             encoder_hidden_states,
+                             hidden_states,
+                         )
+
+         return hidden_states, encoder_hidden_states
@@ -4,12 +4,15 @@ import torch.distributed as dist
 
  from cache_dit.cache_factory import cache_context
  from cache_dit.cache_factory import ForwardPattern
+ from cache_dit.cache_factory.cache_blocks.utils import (
+     patch_cached_stats,
+ )
  from cache_dit.logger import init_logger
 
  logger = init_logger(__name__)
 
 
- class DBCachedTransformerBlocks(torch.nn.Module):
+ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
      _supported_patterns = [
          ForwardPattern.Pattern_0,
          ForwardPattern.Pattern_1,
@@ -29,18 +32,30 @@ class DBCachedTransformerBlocks(torch.nn.Module):
          self.transformer_blocks = transformer_blocks
          self.forward_pattern = forward_pattern
          self._check_forward_pattern()
+         logger.info(f"Match Cached Blocks: {self.__class__.__name__}")
 
      def _check_forward_pattern(self):
          assert (
              self.forward_pattern.Supported
              and self.forward_pattern in self._supported_patterns
-         ), f"Pattern {self.forward_pattern} is not support for DBCache now!"
+         ), f"Pattern {self.forward_pattern} is not supported now!"
 
          if self.transformer_blocks is not None:
              for block in self.transformer_blocks:
                  forward_parameters = set(
                      inspect.signature(block.forward).parameters.keys()
                  )
+                 num_outputs = str(
+                     inspect.signature(block.forward).return_annotation
+                 ).count("torch.Tensor")
+
+                 if num_outputs > 0:
+                     assert len(self.forward_pattern.Out) == num_outputs, (
+                         f"The number of block's outputs is {num_outputs} don't not "
+                         f"match the number of the pattern: {self.forward_pattern}, "
+                         f"Out: {len(self.forward_pattern.Out)}."
+                     )
+
                  for required_param in self.forward_pattern.In:
                      assert (
                          required_param in forward_parameters
@@ -479,19 +494,3 @@ class DBCachedTransformerBlocks(torch.nn.Module):
          )
 
          return hidden_states, encoder_hidden_states
-
-
- @torch.compiler.disable
- def patch_cached_stats(
-     transformer,
- ):
-     # Patch the cached stats to the transformer, the cached stats
-     # will be reset for each calling of pipe.__call__(**kwargs).
-     if transformer is None:
-         return
-
-     # TODO: Patch more cached stats to the transformer
-     transformer._cached_steps = cache_context.get_cached_steps()
-     transformer._residual_diffs = cache_context.get_residual_diffs()
-     transformer._cfg_cached_steps = cache_context.get_cfg_cached_steps()
-     transformer._cfg_residual_diffs = cache_context.get_cfg_residual_diffs()
@@ -0,0 +1,19 @@
+ import torch
+
+ from cache_dit.cache_factory import cache_context
+
+
+ @torch.compiler.disable
+ def patch_cached_stats(
+     transformer,
+ ):
+     # Patch the cached stats to the transformer, the cached stats
+     # will be reset for each calling of pipe.__call__(**kwargs).
+     if transformer is None:
+         return
+
+     # TODO: Patch more cached stats to the transformer
+     transformer._cached_steps = cache_context.get_cached_steps()
+     transformer._residual_diffs = cache_context.get_residual_diffs()
+     transformer._cfg_cached_steps = cache_context.get_cfg_cached_steps()
+     transformer._cfg_residual_diffs = cache_context.get_cfg_residual_diffs()
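
Illustrative only, not part of the diff: a minimal sketch of reading these patched stats back after a cache-enabled pipeline run (see the enable_cache sketch further down). Whether these private attributes are intended for user inspection is an assumption; the pipe.transformer access in the usage comment is hypothetical.

    def report_cached_stats(transformer) -> None:
        # Reads the attributes written by patch_cached_stats(); they are reset
        # and re-populated on every pipe.__call__(**kwargs).
        print("cached steps:      ", transformer._cached_steps)
        print("residual diffs:    ", transformer._residual_diffs)
        print("cfg cached steps:  ", transformer._cfg_cached_steps)
        print("cfg residual diffs:", transformer._cfg_residual_diffs)

    # Usage after a cache-enabled run, e.g.: report_cached_stats(pipe.transformer)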
@@ -941,7 +941,11 @@ def get_Bn_buffer(prefix: str = "Bn"):
 
 
  @torch.compiler.disable
- def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
+ def set_Bn_encoder_buffer(buffer: torch.Tensor | None, prefix: str = "Bn"):
+     # DON'T set None Buffer
+     if buffer is None:
+         return
+
      # This buffer is use for encoder hidden states approximation.
      if is_encoder_taylorseer_enabled():
          # taylorseer, encoder_taylorseer
@@ -1,3 +1,4 @@
+ from typing import Any, Tuple, List
  from diffusers import DiffusionPipeline
  from cache_dit.cache_factory.forward_pattern import ForwardPattern
  from cache_dit.cache_factory.cache_types import CacheType
@@ -9,9 +10,13 @@ from cache_dit.logger import init_logger
  logger = init_logger(__name__)
 
 
+ def supported_pipelines() -> Tuple[int, List[str]]:
+     return UnifiedCacheAdapter.supported_pipelines()
+
+
  def enable_cache(
      # BlockAdapter & forward pattern
-     pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+     pipe_or_adapter: DiffusionPipeline | BlockAdapter | Any,
      forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
      # Cache context kwargs
      Fn_compute_blocks: int = 8,
@@ -30,7 +35,7 @@ def enable_cache(
      taylorseer_cache_type: str = "residual",
      taylorseer_order: int = 2,
      **other_cache_kwargs,
- ) -> DiffusionPipeline:
+ ) -> DiffusionPipeline | Any:
      r"""
      Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
      that match the specific Input and Output patterns).
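
Illustrative only, not part of the diff: a minimal sketch of calling the extended API above with one of the newly supported patterns. The parameter names are taken from this hunk; the checkpoint id is a placeholder, and the import paths are assumptions based on the module layout shown in this diff.

    from diffusers import DiffusionPipeline

    from cache_dit.cache_factory import ForwardPattern
    from cache_dit.cache_factory.cache_interface import enable_cache

    # "hypothetical/dit-model" is a placeholder, not a real checkpoint.
    pipe = DiffusionPipeline.from_pretrained("hypothetical/dit-model")

    # Pattern_3/4/5 transformers (blocks whose forward takes only hidden_states)
    # are now accepted in addition to Pattern_0/1/2.
    enable_cache(
        pipe,
        forward_pattern=ForwardPattern.Pattern_3,
        Fn_compute_blocks=8,
        taylorseer_cache_type="residual",
    )

    images = pipe("a photo of a cat", num_inference_steps=28)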
@@ -19,39 +19,57 @@ class ForwardPattern(Enum):
          self.Supported = Supported
 
      Pattern_0 = (
-         True,
-         False,
-         False,
-         ("hidden_states", "encoder_hidden_states"),
-         ("hidden_states", "encoder_hidden_states"),
-         True,
+         True,  # Return_H_First
+         False,  # Return_H_Only
+         False,  # Forward_H_only
+         ("hidden_states", "encoder_hidden_states"),  # In
+         ("hidden_states", "encoder_hidden_states"),  # Out
+         True,  # Supported
      )
 
      Pattern_1 = (
-         False,
-         False,
-         False,
-         ("hidden_states", "encoder_hidden_states"),
-         ("encoder_hidden_states", "hidden_states"),
-         True,
+         False,  # Return_H_First
+         False,  # Return_H_Only
+         False,  # Forward_H_only
+         ("hidden_states", "encoder_hidden_states"),  # In
+         ("encoder_hidden_states", "hidden_states"),  # Out
+         True,  # Supported
      )
 
      Pattern_2 = (
-         False,
-         True,
-         False,
-         ("hidden_states", "encoder_hidden_states"),
-         ("hidden_states",),
-         True,
+         False,  # Return_H_First
+         True,  # Return_H_Only
+         False,  # Forward_H_only
+         ("hidden_states", "encoder_hidden_states"),  # In
+         ("hidden_states",),  # Out
+         True,  # Supported
      )
 
      Pattern_3 = (
-         False,
-         True,
-         False,
-         ("hidden_states",),
-         ("hidden_states",),
-         False,
+         False,  # Return_H_First
+         True,  # Return_H_Only
+         True,  # Forward_H_only
+         ("hidden_states",),  # In
+         ("hidden_states",),  # Out
+         True,  # Supported
+     )
+
+     Pattern_4 = (
+         True,  # Return_H_First
+         False,  # Return_H_Only
+         True,  # Forward_H_only
+         ("hidden_states",),  # In
+         ("hidden_states", "encoder_hidden_states"),  # Out
+         True,  # Supported
+     )
+
+     Pattern_5 = (
+         False,  # Return_H_First
+         False,  # Return_H_Only
+         True,  # Forward_H_only
+         ("hidden_states",),  # In
+         ("encoder_hidden_states", "hidden_states"),  # Out
+         True,  # Supported
      )
 
      @staticmethod
@@ -60,4 +78,7 @@ class ForwardPattern(Enum):
              ForwardPattern.Pattern_0,
              ForwardPattern.Pattern_1,
              ForwardPattern.Pattern_2,
+             ForwardPattern.Pattern_3,
+             ForwardPattern.Pattern_4,
+             ForwardPattern.Pattern_5,
          ]
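
Illustrative only, not part of the diff: hypothetical block forward signatures that correspond to the In/Out tuples of the newly added patterns, to make the flag comments above concrete. These classes are stand-ins, not cache-dit code, and the zeros_like outputs are placeholders.

    import torch

    class Pattern3Block(torch.nn.Module):
        # In: ("hidden_states",)  Out: ("hidden_states",)
        def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
            return hidden_states

    class Pattern4Block(torch.nn.Module):
        # In: ("hidden_states",)  Out: ("hidden_states", "encoder_hidden_states")
        def forward(self, hidden_states: torch.Tensor, **kwargs):
            encoder_hidden_states = torch.zeros_like(hidden_states)  # placeholder
            return hidden_states, encoder_hidden_states

    class Pattern5Block(torch.nn.Module):
        # In: ("hidden_states",)  Out: ("encoder_hidden_states", "hidden_states")
        def forward(self, hidden_states: torch.Tensor, **kwargs):
            encoder_hidden_states = torch.zeros_like(hidden_states)  # placeholder
            return encoder_hidden_states, hidden_states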
@@ -0,0 +1,5 @@
+ from cache_dit.cache_factory.patch_functors.functor_base import PatchFunctor
+ from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
+ from cache_dit.cache_factory.patch_functors.functor_chroma import (
+     ChromaPatchFunctor,
+ )
@@ -0,0 +1,18 @@
+ import torch
+ from abc import abstractmethod
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class PatchFunctor:
+
+     @abstractmethod
+     def apply(
+         self,
+         transformer: torch.nn.Module,
+         *args,
+         **kwargs,
+     ) -> torch.nn.Module:
+         raise NotImplementedError("apply method is not implemented.")
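
Illustrative only, not part of the diff: a minimal sketch of a concrete functor implementing the PatchFunctor interface defined above. FluxPatchFunctor and ChromaPatchFunctor in this release provide the real implementations for their respective transformers; the NoOpPatchFunctor below is hypothetical.

    import torch

    from cache_dit.cache_factory.patch_functors import PatchFunctor


    class NoOpPatchFunctor(PatchFunctor):
        # A functor receives a transformer, may rewrite its forward / blocks
        # in place, and must return a torch.nn.Module.
        def apply(
            self,
            transformer: torch.nn.Module,
            *args,
            **kwargs,
        ) -> torch.nn.Module:
            return transformer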