cache-dit 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (32)
  1. cache_dit/__init__.py +9 -4
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +16 -3
  4. cache_dit/cache_factory/block_adapters/__init__.py +538 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +121 -563
  8. cache_dit/cache_factory/cache_blocks/__init__.py +18 -0
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +275 -0
  11. cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +100 -82
  12. cache_dit/cache_factory/cache_blocks/utils.py +23 -0
  13. cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
  14. cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +94 -56
  15. cache_dit/cache_factory/cache_interface.py +24 -16
  16. cache_dit/cache_factory/forward_pattern.py +45 -24
  17. cache_dit/cache_factory/patch_functors/__init__.py +5 -0
  18. cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
  19. cache_dit/cache_factory/patch_functors/functor_chroma.py +276 -0
  20. cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +49 -31
  21. cache_dit/quantize/quantize_ao.py +19 -4
  22. cache_dit/quantize/quantize_interface.py +2 -2
  23. cache_dit/utils.py +19 -15
  24. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/METADATA +76 -19
  25. cache_dit-0.2.27.dist-info/RECORD +47 -0
  26. cache_dit-0.2.25.dist-info/RECORD +0 -36
  27. /cache_dit/cache_factory/{patch/__init__.py → cache_contexts/cache_manager.py} +0 -0
  28. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  29. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
  30. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
  31. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
  32. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_blocks/__init__.py
@@ -0,0 +1,18 @@
+ from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
+     CachedBlocks_Pattern_0_1_2,
+ )
+ from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
+     CachedBlocks_Pattern_3_4_5,
+ )
+
+
+ class CachedBlocks:
+     def __new__(cls, *args, **kwargs):
+         forward_pattern = kwargs.get("forward_pattern", None)
+         assert forward_pattern is not None, "forward_pattern can't be None."
+         if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
+             return CachedBlocks_Pattern_0_1_2(*args, **kwargs)
+         elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
+             return CachedBlocks_Pattern_3_4_5(*args, **kwargs)
+         else:
+             raise ValueError(f"Pattern {forward_pattern} is not supported now!")
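The new CachedBlocks class above is a thin factory: __new__ looks only at the forward_pattern keyword and hands construction off to whichever pattern group lists that pattern in its _supported_patterns. A minimal sketch of the dispatch rule, assuming cache-dit 0.2.27 is installed (the helper resolve_blocks_cls is hypothetical, introduced here only for illustration):

    from cache_dit.cache_factory import ForwardPattern
    from cache_dit.cache_factory.cache_blocks import (
        CachedBlocks_Pattern_0_1_2,
        CachedBlocks_Pattern_3_4_5,
    )

    def resolve_blocks_cls(forward_pattern: ForwardPattern):
        # Same membership test CachedBlocks.__new__ performs before
        # constructing the concrete pattern-group class.
        for cls in (CachedBlocks_Pattern_0_1_2, CachedBlocks_Pattern_3_4_5):
            if forward_pattern in cls._supported_patterns:
                return cls
        raise ValueError(f"Pattern {forward_pattern} is not supported now!")

    assert resolve_blocks_cls(ForwardPattern.Pattern_1) is CachedBlocks_Pattern_0_1_2
    assert resolve_blocks_cls(ForwardPattern.Pattern_4) is CachedBlocks_Pattern_3_4_5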
cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py
@@ -0,0 +1,16 @@
+ from cache_dit.cache_factory import ForwardPattern
+ from cache_dit.cache_factory.cache_blocks.pattern_base import (
+     CachedBlocks_Pattern_Base,
+ )
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class CachedBlocks_Pattern_0_1_2(CachedBlocks_Pattern_Base):
+     _supported_patterns = [
+         ForwardPattern.Pattern_0,
+         ForwardPattern.Pattern_1,
+         ForwardPattern.Pattern_2,
+     ]
+     ...
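The Pattern 0/1/2 group adds no behavior of its own; the `...` body means everything, including forward(), is inherited from CachedBlocks_Pattern_Base (the renamed cache_blocks.py). Under that structure, adding another group would only require declaring its supported patterns, e.g. (hypothetical class, for illustration only):

    from cache_dit.cache_factory import ForwardPattern
    from cache_dit.cache_factory.cache_blocks.pattern_base import (
        CachedBlocks_Pattern_Base,
    )

    class CachedBlocks_Pattern_Custom(CachedBlocks_Pattern_Base):
        # Hypothetical group: reuse the base forward() unchanged and only
        # declare which forward patterns this group is willing to handle.
        _supported_patterns = [
            ForwardPattern.Pattern_0,
        ]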
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py
@@ -0,0 +1,275 @@
+ import torch
+
+ from cache_dit.cache_factory import CachedContext
+ from cache_dit.cache_factory import ForwardPattern
+ from cache_dit.cache_factory.cache_blocks.pattern_base import (
+     CachedBlocks_Pattern_Base,
+ )
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
+     _supported_patterns = [
+         ForwardPattern.Pattern_3,
+         ForwardPattern.Pattern_4,
+         ForwardPattern.Pattern_5,
+     ]
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         # Use its own cache context.
+         CachedContext.set_cache_context(
+             self.cache_context,
+         )
+
+         original_hidden_states = hidden_states
+         # Call first `n` blocks to process the hidden states for
+         # more stable diff calculation.
+         # encoder_hidden_states: None Pattern 3, else 4, 5
+         hidden_states, encoder_hidden_states = self.call_Fn_blocks(
+             hidden_states,
+             *args,
+             **kwargs,
+         )
+
+         Fn_hidden_states_residual = hidden_states - original_hidden_states
+         del original_hidden_states
+
+         CachedContext.mark_step_begin()
+         # Residual L1 diff or Hidden States L1 diff
+         can_use_cache = CachedContext.get_can_use_cache(
+             (
+                 Fn_hidden_states_residual
+                 if not CachedContext.is_l1_diff_enabled()
+                 else hidden_states
+             ),
+             parallelized=self._is_parallelized(),
+             prefix=(
+                 f"{self.blocks_name}_Fn_residual"
+                 if not CachedContext.is_l1_diff_enabled()
+                 else f"{self.blocks_name}_Fn_hidden_states"
+             ),
+         )
+
+         torch._dynamo.graph_break()
+         if can_use_cache:
+             CachedContext.add_cached_step()
+             del Fn_hidden_states_residual
+             hidden_states, encoder_hidden_states = (
+                 CachedContext.apply_hidden_states_residual(
+                     hidden_states,
+                     # None Pattern 3, else 4, 5
+                     encoder_hidden_states,
+                     prefix=(
+                         f"{self.blocks_name}_Bn_residual"
+                         if CachedContext.is_cache_residual()
+                         else f"{self.blocks_name}_Bn_hidden_states"
+                     ),
+                     encoder_prefix=(
+                         f"{self.blocks_name}_Bn_residual"
+                         if CachedContext.is_encoder_cache_residual()
+                         else f"{self.blocks_name}_Bn_hidden_states"
+                     ),
+                 )
+             )
+             torch._dynamo.graph_break()
+             # Call last `n` blocks to further process the hidden states
+             # for higher precision.
+             hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+         else:
+             CachedContext.set_Fn_buffer(
+                 Fn_hidden_states_residual,
+                 prefix=f"{self.blocks_name}_Fn_residual",
+             )
+             if CachedContext.is_l1_diff_enabled():
+                 # for hidden states L1 diff
+                 CachedContext.set_Fn_buffer(
+                     hidden_states,
+                     f"{self.blocks_name}_Fn_hidden_states",
+                 )
+             del Fn_hidden_states_residual
+             torch._dynamo.graph_break()
+             (
+                 hidden_states,
+                 encoder_hidden_states,
+                 hidden_states_residual,
+                 # None Pattern 3, else 4, 5
+                 encoder_hidden_states_residual,
+             ) = self.call_Mn_blocks(  # middle
+                 hidden_states,
+                 # None Pattern 3, else 4, 5
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             torch._dynamo.graph_break()
+             if CachedContext.is_cache_residual():
+                 CachedContext.set_Bn_buffer(
+                     hidden_states_residual,
+                     prefix=f"{self.blocks_name}_Bn_residual",
+                 )
+             else:
+                 # TaylorSeer
+                 CachedContext.set_Bn_buffer(
+                     hidden_states,
+                     prefix=f"{self.blocks_name}_Bn_hidden_states",
+                 )
+             if CachedContext.is_encoder_cache_residual():
+                 CachedContext.set_Bn_encoder_buffer(
+                     # None Pattern 3, else 4, 5
+                     encoder_hidden_states_residual,
+                     prefix=f"{self.blocks_name}_Bn_residual",
+                 )
+             else:
+                 # TaylorSeer
+                 CachedContext.set_Bn_encoder_buffer(
+                     # None Pattern 3, else 4, 5
+                     encoder_hidden_states,
+                     prefix=f"{self.blocks_name}_Bn_hidden_states",
+                 )
+             torch._dynamo.graph_break()
+             # Call last `n` blocks to further process the hidden states
+             # for higher precision.
+             hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                 hidden_states,
+                 # None Pattern 3, else 4, 5
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+
+         torch._dynamo.graph_break()
+
+         return (
+             hidden_states
+             if self.forward_pattern.Return_H_Only
+             else (
+                 (hidden_states, encoder_hidden_states)
+                 if self.forward_pattern.Return_H_First
+                 else (encoder_hidden_states, hidden_states)
+             )
+         )
+
+     def call_Fn_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         assert CachedContext.Fn_compute_blocks() <= len(
+             self.transformer_blocks
+         ), (
+             f"Fn_compute_blocks {CachedContext.Fn_compute_blocks()} must be less than or equal to "
+             f"the number of transformer blocks {len(self.transformer_blocks)}"
+         )
+         encoder_hidden_states = None  # Pattern 3
+         for block in self._Fn_blocks():
+             hidden_states = block(
+                 hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+                 hidden_states, encoder_hidden_states = hidden_states
+                 if not self.forward_pattern.Return_H_First:
+                     hidden_states, encoder_hidden_states = (
+                         encoder_hidden_states,
+                         hidden_states,
+                     )
+
+         return hidden_states, encoder_hidden_states
+
+     def call_Mn_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         # None Pattern 3, else 4, 5
+         encoder_hidden_states: torch.Tensor | None,
+         *args,
+         **kwargs,
+     ):
+         original_hidden_states = hidden_states
+         original_encoder_hidden_states = encoder_hidden_states
+         for block in self._Mn_blocks():
+             hidden_states = block(
+                 hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+                 hidden_states, encoder_hidden_states = hidden_states
+                 if not self.forward_pattern.Return_H_First:
+                     hidden_states, encoder_hidden_states = (
+                         encoder_hidden_states,
+                         hidden_states,
+                     )
+
+         # compute hidden_states residual
+         hidden_states = hidden_states.contiguous()
+         hidden_states_residual = hidden_states - original_hidden_states
+         if (
+             original_encoder_hidden_states is not None
+             and encoder_hidden_states is not None
+         ):  # Pattern 4, 5
+             encoder_hidden_states_residual = (
+                 encoder_hidden_states - original_encoder_hidden_states
+             )
+         else:
+             encoder_hidden_states_residual = None  # Pattern 3
+
+         return (
+             hidden_states,
+             encoder_hidden_states,
+             hidden_states_residual,
+             encoder_hidden_states_residual,
+         )
+
+     def call_Bn_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         # None Pattern 3, else 4, 5
+         encoder_hidden_states: torch.Tensor | None,
+         *args,
+         **kwargs,
+     ):
+         if CachedContext.Bn_compute_blocks() == 0:
+             return hidden_states, encoder_hidden_states
+
+         assert CachedContext.Bn_compute_blocks() <= len(
+             self.transformer_blocks
+         ), (
+             f"Bn_compute_blocks {CachedContext.Bn_compute_blocks()} must be less than or equal to "
+             f"the number of transformer blocks {len(self.transformer_blocks)}"
+         )
+         if len(CachedContext.Bn_compute_blocks_ids()) > 0:
+             raise ValueError(
+                 f"Bn_compute_blocks_ids is not supported for "
+                 f"patterns: {self._supported_patterns}."
+             )
+         else:
+             # Compute all Bn blocks if no specific Bn compute blocks ids are set.
+             for block in self._Bn_blocks():
+                 hidden_states = block(
+                     hidden_states,
+                     *args,
+                     **kwargs,
+                 )
+                 if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+                     hidden_states, encoder_hidden_states = hidden_states
+                     if not self.forward_pattern.Return_H_First:
+                         hidden_states, encoder_hidden_states = (
+                             encoder_hidden_states,
+                             hidden_states,
+                         )
+
+         return hidden_states, encoder_hidden_states
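The forward() above follows the usual residual-caching control flow: the residual produced by the first Fn blocks is compared against the one recorded on the last fully computed step, and if it has barely changed, the middle Mn blocks are skipped and the cached Bn residual is re-applied before running only the last Bn blocks. A standalone sketch of that decision, written against plain torch rather than the CachedContext API (the relative-L1 metric and the threshold value are illustrative assumptions, not cache-dit's exact internals):

    import torch

    def can_use_cache(fn_residual: torch.Tensor,
                      cached_fn_residual: torch.Tensor | None,
                      threshold: float = 0.08) -> bool:
        # First step (nothing cached yet): must compute the full model.
        if cached_fn_residual is None:
            return False
        # Relative L1 distance between the current Fn residual and the one
        # recorded on the last fully computed step.
        diff = (fn_residual - cached_fn_residual).abs().mean()
        norm = cached_fn_residual.abs().mean()
        return (diff / (norm + 1e-8)).item() < threshold

    # Usage sketch: when this returns True, the pattern-3/4/5 forward above
    # re-applies the cached Bn residual and only runs the last Bn blocks;
    # otherwise it runs the middle blocks and refreshes the cached buffers.
    r_prev = torch.randn(1, 16, 64)
    r_now = r_prev + 0.01 * torch.randn_like(r_prev)
    print(can_use_cache(r_now, r_prev))  # True for such a small change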