cache-dit 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. cache_dit/__init__.py +8 -6
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +17 -4
  4. cache_dit/cache_factory/block_adapters/__init__.py +555 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +538 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +262 -938
  8. cache_dit/cache_factory/cache_blocks/__init__.py +60 -11
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +45 -41
  11. cache_dit/cache_factory/cache_blocks/pattern_base.py +106 -80
  12. cache_dit/cache_factory/cache_blocks/utils.py +16 -10
  13. cache_dit/cache_factory/cache_contexts/__init__.py +5 -0
  14. cache_dit/cache_factory/cache_contexts/cache_context.py +327 -0
  15. cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
  16. cache_dit/cache_factory/cache_interface.py +31 -31
  17. cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
  18. cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
  19. cache_dit/quantize/quantize_ao.py +1 -0
  20. cache_dit/utils.py +26 -26
  21. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/METADATA +59 -23
  22. cache_dit-0.2.28.dist-info/RECORD +47 -0
  23. cache_dit/cache_factory/cache_context.py +0 -1155
  24. cache_dit-0.2.26.dist-info/RECORD +0 -42
  25. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  26. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
  27. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0

cache_dit/cache_factory/cache_blocks/pattern_base.py

@@ -2,17 +2,17 @@ import inspect
 import torch
 import torch.distributed as dist
 
-from cache_dit.cache_factory import cache_context
-from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory.cache_blocks.utils import (
-    patch_cached_stats,
+from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
 )
+from cache_dit.cache_factory import ForwardPattern
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-class DBCachedBlocks_Pattern_Base(torch.nn.Module):
+class CachedBlocks_Pattern_Base(torch.nn.Module):
     _supported_patterns = [
         ForwardPattern.Pattern_0,
         ForwardPattern.Pattern_1,
@@ -21,18 +21,35 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
 
     def __init__(
         self,
+        # 0. Transformer blocks configuration
         transformer_blocks: torch.nn.ModuleList,
-        *,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_num_outputs: bool = True,
+        # 1. Cache context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: CachedContext | str = None,
+        cache_manager: CachedContextManager = None,
+        **kwargs,
     ):
         super().__init__()
 
+        # 0. Transformer blocks configuration
         self.transformer = transformer
         self.transformer_blocks = transformer_blocks
         self.forward_pattern = forward_pattern
+        self.check_num_outputs = check_num_outputs
+        # 1. Cache context configuration
+        self.cache_prefix = cache_prefix
+        self.cache_context = cache_context
+        self.cache_manager = cache_manager
+
         self._check_forward_pattern()
-        logger.info(f"Match Cached Blocks: {self.__class__.__name__}")
+        logger.info(
+            f"Match Cached Blocks: {self.__class__.__name__}, for "
+            f"{self.cache_prefix}, cache_context: {self.cache_context}, "
+            f"cache_manager: {self.cache_manager.name}."
+        )
 
     def _check_forward_pattern(self):
         assert (
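
Note: the constructor now receives the cache-context wiring explicitly instead of relying on module-level state. Below is a minimal sketch of how the new parameters might be wired together; the helper and its argument names are hypothetical, constructing a CachedContext/CachedContextManager is not shown in this diff (both are taken as pre-built arguments), and running it assumes cache-dit 0.2.28 is installed.

import torch

from cache_dit.cache_factory import ForwardPattern
from cache_dit.cache_factory.cache_blocks.pattern_base import (
    CachedBlocks_Pattern_Base,
)
from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
from cache_dit.cache_factory.cache_contexts.cache_manager import (
    CachedContextManager,
)


def build_cached_blocks(
    blocks: torch.nn.ModuleList,
    transformer: torch.nn.Module,
    context: CachedContext | str,
    manager: CachedContextManager,
) -> CachedBlocks_Pattern_Base:
    # Hypothetical helper: `context` and `manager` are assumed to be created
    # elsewhere; their constructors are not part of this diff.
    return CachedBlocks_Pattern_Base(
        blocks,
        transformer=transformer,
        forward_pattern=ForwardPattern.Pattern_0,
        check_num_outputs=True,
        cache_prefix="transformer_blocks",  # namespaces this group's buffers
        cache_context=context,
        cache_manager=manager,
    )
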
@@ -45,16 +62,18 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
             forward_parameters = set(
                 inspect.signature(block.forward).parameters.keys()
             )
-            num_outputs = str(
-                inspect.signature(block.forward).return_annotation
-            ).count("torch.Tensor")
-
-            if num_outputs > 0:
-                assert len(self.forward_pattern.Out) == num_outputs, (
-                    f"The number of block's outputs is {num_outputs} don't not "
-                    f"match the number of the pattern: {self.forward_pattern}, "
-                    f"Out: {len(self.forward_pattern.Out)}."
-                )
+
+            if self.check_num_outputs:
+                num_outputs = str(
+                    inspect.signature(block.forward).return_annotation
+                ).count("torch.Tensor")
+
+                if num_outputs > 0:
+                    assert len(self.forward_pattern.Out) == num_outputs, (
+                        f"The number of block's outputs is {num_outputs} don't not "
+                        f"match the number of the pattern: {self.forward_pattern}, "
+                        f"Out: {len(self.forward_pattern.Out)}."
+                    )
 
             for required_param in self.forward_pattern.In:
                 assert (
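
Note: the new check_num_outputs flag gates the return-annotation check. The check itself counts occurrences of "torch.Tensor" in the stringified return annotation of block.forward; a self-contained illustration of that mechanism follows (the toy block below is not part of cache-dit).

import inspect
from typing import Tuple

import torch


class ToyBlock(torch.nn.Module):
    # Toy block, only to show how the return annotation is inspected.
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return hidden_states, encoder_hidden_states


num_outputs = str(
    inspect.signature(ToyBlock.forward).return_annotation
).count("torch.Tensor")
print(num_outputs)  # 2 -> compared against len(forward_pattern.Out)
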
@@ -68,6 +87,8 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
+        self.cache_manager.set_context(self.cache_context)
+
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
@@ -81,39 +102,39 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states
 
-        cache_context.mark_step_begin()
+        self.cache_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = cache_context.get_can_use_cache(
+        can_use_cache = self.cache_manager.can_cache(
             (
                 Fn_hidden_states_residual
-                if not cache_context.is_l1_diff_enabled()
+                if not self.cache_manager.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
-                "Fn_residual"
-                if not cache_context.is_l1_diff_enabled()
-                else "Fn_hidden_states"
+                f"{self.cache_prefix}_Fn_residual"
+                if not self.cache_manager.is_l1_diff_enabled()
+                else f"{self.cache_prefix}_Fn_hidden_states"
             ),
         )
 
         torch._dynamo.graph_break()
         if can_use_cache:
-            cache_context.add_cached_step()
+            self.cache_manager.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-                cache_context.apply_hidden_states_residual(
+                self.cache_manager.apply_cache(
                     hidden_states,
                     encoder_hidden_states,
                     prefix=(
-                        "Bn_residual"
-                        if cache_context.is_cache_residual()
-                        else "Bn_hidden_states"
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
-                        "Bn_residual"
-                        if cache_context.is_encoder_cache_residual()
-                        else "Bn_hidden_states"
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                 )
             )
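
Note: the renamed manager methods (get_can_use_cache -> can_cache, apply_hidden_states_residual -> apply_cache) carry the same idea as before: compare this step's first-n-blocks residual against the previous step's, and if the change is small, reuse the cached last-n-blocks residual instead of recomputing. A simplified, self-contained sketch of that decision follows; it illustrates the idea only and is not cache-dit's implementation, and the relative-L1 metric and threshold value are assumptions.

import torch


def can_cache(prev: torch.Tensor, curr: torch.Tensor, threshold: float = 0.08) -> bool:
    # Relative L1 distance between this step's Fn residual and the previous one.
    # A small change suggests the cached Bn residual is still valid.
    diff = (curr - prev).abs().mean() / (prev.abs().mean() + 1e-8)
    return diff.item() < threshold


def apply_cache(hidden_states: torch.Tensor, cached_residual: torch.Tensor) -> torch.Tensor:
    # Skip the middle/last blocks and re-add the residual cached at the
    # last fully computed step.
    return hidden_states + cached_residual


# Toy usage with random tensors standing in for Fn residuals.
prev_res = torch.randn(2, 16, 64)
curr_res = prev_res + 0.01 * torch.randn_like(prev_res)
if can_cache(prev_res, curr_res):
    out = apply_cache(curr_res, torch.zeros_like(curr_res))
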
@@ -127,12 +148,16 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
             )
         else:
-            cache_context.set_Fn_buffer(
-                Fn_hidden_states_residual, prefix="Fn_residual"
+            self.cache_manager.set_Fn_buffer(
+                Fn_hidden_states_residual,
+                prefix=f"{self.cache_prefix}_Fn_residual",
             )
-            if cache_context.is_l1_diff_enabled():
+            if self.cache_manager.is_l1_diff_enabled():
                 # for hidden states L1 diff
-                cache_context.set_Fn_buffer(hidden_states, "Fn_hidden_states")
+                self.cache_manager.set_Fn_buffer(
+                    hidden_states,
+                    f"{self.cache_prefix}_Fn_hidden_states",
+                )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
             (
@@ -147,27 +172,27 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
             )
             torch._dynamo.graph_break()
-            if cache_context.is_cache_residual():
-                cache_context.set_Bn_buffer(
+            if self.cache_manager.is_cache_residual():
+                self.cache_manager.set_Bn_buffer(
                     hidden_states_residual,
-                    prefix="Bn_residual",
+                    prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-                cache_context.set_Bn_buffer(
+                self.cache_manager.set_Bn_buffer(
                     hidden_states,
-                    prefix="Bn_hidden_states",
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
-            if cache_context.is_encoder_cache_residual():
-                cache_context.set_Bn_encoder_buffer(
+            if self.cache_manager.is_encoder_cache_residual():
+                self.cache_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states_residual,
-                    prefix="Bn_residual",
+                    prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-                cache_context.set_Bn_encoder_buffer(
+                self.cache_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states,
-                    prefix="Bn_hidden_states",
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
@@ -179,7 +204,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
             **kwargs,
         )
 
-        patch_cached_stats(self.transformer)
+        # patch cached stats for blocks or remove it.
         torch._dynamo.graph_break()
 
         return (
@@ -213,10 +238,11 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         # If so, we can skip some Bn blocks and directly
        # use the cached values.
         return (
-            cache_context.get_current_step() in cache_context.get_cached_steps()
+            self.cache_manager.get_current_step()
+            in self.cache_manager.get_cached_steps()
         ) or (
-            cache_context.get_current_step()
-            in cache_context.get_cfg_cached_steps()
+            self.cache_manager.get_current_step()
+            in self.cache_manager.get_cfg_cached_steps()
         )
 
     @torch.compiler.disable
@@ -225,20 +251,20 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         # more stable diff calculation.
         # Fn: [0,...,n-1]
         selected_Fn_blocks = self.transformer_blocks[
-            : cache_context.Fn_compute_blocks()
+            : self.cache_manager.Fn_compute_blocks()
         ]
         return selected_Fn_blocks
 
     @torch.compiler.disable
     def _Mn_blocks(self):  # middle blocks
         # M(N-2n): only transformer_blocks [n,...,N-n], middle
-        if cache_context.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
+        if self.cache_manager.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
             selected_Mn_blocks = self.transformer_blocks[
-                cache_context.Fn_compute_blocks() :
+                self.cache_manager.Fn_compute_blocks() :
             ]
         else:
             selected_Mn_blocks = self.transformer_blocks[
-                cache_context.Fn_compute_blocks() : -cache_context.Bn_compute_blocks()
+                self.cache_manager.Fn_compute_blocks() : -self.cache_manager.Bn_compute_blocks()
             ]
         return selected_Mn_blocks
 
@@ -246,7 +272,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
     def _Bn_blocks(self):
         # Bn: transformer_blocks [N-n+1,...,N-1]
         selected_Bn_blocks = self.transformer_blocks[
-            -cache_context.Bn_compute_blocks() :
+            -self.cache_manager.Bn_compute_blocks() :
         ]
         return selected_Bn_blocks
 
@@ -257,10 +283,10 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        assert cache_context.Fn_compute_blocks() <= len(
+        assert self.cache_manager.Fn_compute_blocks() <= len(
            self.transformer_blocks
         ), (
-            f"Fn_compute_blocks {cache_context.Fn_compute_blocks()} must be less than "
+            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
         for block in self._Fn_blocks():
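
Note: the Fn/Mn/Bn split above relies on Python list slicing, and the Bn_compute_blocks() == 0 special case exists because x[:-0] is an empty list rather than the whole list (the "WARN: x[:-0] = []" comment). A small self-contained illustration, with plain lists standing in for transformer blocks:

blocks = list(range(10))  # stand-in for a 10-block transformer
Fn, Bn = 2, 3             # first-n and last-n compute blocks

Fn_blocks = blocks[:Fn]                                 # [0, 1]
Mn_blocks = blocks[Fn:-Bn] if Bn > 0 else blocks[Fn:]   # [2, 3, 4, 5, 6]
Bn_blocks = blocks[-Bn:] if Bn > 0 else []              # [7, 8, 9]

# The pitfall the WARN comment refers to:
assert blocks[:-0] == []   # not the full list!
assert blocks[2:-0] == []  # so Bn == 0 must be special-cased
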
@@ -357,7 +383,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
             )
             # Cache residuals for the non-compute Bn blocks for
             # subsequent cache steps.
-            if block_id not in cache_context.Bn_compute_blocks_ids():
+            if block_id not in self.cache_manager.Bn_compute_blocks_ids():
                 Bn_i_hidden_states_residual = (
                     hidden_states - Bn_i_original_hidden_states
                 )
@@ -366,22 +392,22 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                 )
 
                 # Save original_hidden_states for diff calculation.
-                cache_context.set_Bn_buffer(
+                self.cache_manager.set_Bn_buffer(
                     Bn_i_original_hidden_states,
-                    prefix=f"Bn_{block_id}_original",
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
                 )
-                cache_context.set_Bn_encoder_buffer(
+                self.cache_manager.set_Bn_encoder_buffer(
                     Bn_i_original_encoder_hidden_states,
-                    prefix=f"Bn_{block_id}_original",
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
                 )
 
-                cache_context.set_Bn_buffer(
+                self.cache_manager.set_Bn_buffer(
                     Bn_i_hidden_states_residual,
-                    prefix=f"Bn_{block_id}_residual",
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
                 )
-                cache_context.set_Bn_encoder_buffer(
+                self.cache_manager.set_Bn_encoder_buffer(
                     Bn_i_encoder_hidden_states_residual,
-                    prefix=f"Bn_{block_id}_residual",
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
                 )
                 del Bn_i_hidden_states_residual
                 del Bn_i_encoder_hidden_states_residual
@@ -392,7 +418,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         else:
             # Cache steps: Reuse the cached residuals.
             # Check if the block is in the Bn_compute_blocks_ids.
-            if block_id in cache_context.Bn_compute_blocks_ids():
+            if block_id in self.cache_manager.Bn_compute_blocks_ids():
                 hidden_states = block(
                     hidden_states,
                     encoder_hidden_states,
@@ -410,25 +436,25 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                 # Skip the block if it is not in the Bn_compute_blocks_ids.
                 # Use the cached residuals instead.
                 # Check if can use the cached residuals.
-                if cache_context.get_can_use_cache(
+                if self.cache_manager.can_cache(
                     hidden_states,  # curr step
                     parallelized=self._is_parallelized(),
-                    threshold=cache_context.non_compute_blocks_diff_threshold(),
-                    prefix=f"Bn_{block_id}_original",  # prev step
+                    threshold=self.cache_manager.non_compute_blocks_diff_threshold(),
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",  # prev step
                 ):
                     hidden_states, encoder_hidden_states = (
-                        cache_context.apply_hidden_states_residual(
+                        self.cache_manager.apply_cache(
                             hidden_states,
                             encoder_hidden_states,
                             prefix=(
-                                f"Bn_{block_id}_residual"
-                                if cache_context.is_cache_residual()
-                                else f"Bn_{block_id}_original"
+                                f"{self.cache_prefix}_Bn_{block_id}_residual"
+                                if self.cache_manager.is_cache_residual()
+                                else f"{self.cache_prefix}_Bn_{block_id}_original"
                             ),
                             encoder_prefix=(
-                                f"Bn_{block_id}_residual"
-                                if cache_context.is_encoder_cache_residual()
-                                else f"Bn_{block_id}_original"
+                                f"{self.cache_prefix}_Bn_{block_id}_residual"
+                                if self.cache_manager.is_encoder_cache_residual()
+                                else f"{self.cache_prefix}_Bn_{block_id}_original"
                             ),
                         )
                     )
@@ -455,16 +481,16 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        if cache_context.Bn_compute_blocks() == 0:
+        if self.cache_manager.Bn_compute_blocks() == 0:
            return hidden_states, encoder_hidden_states
 
-        assert cache_context.Bn_compute_blocks() <= len(
+        assert self.cache_manager.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Bn_compute_blocks {cache_context.Bn_compute_blocks()} must be less than "
+            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        if len(cache_context.Bn_compute_blocks_ids()) > 0:
+        if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
             for i, block in enumerate(self._Bn_blocks()):
                 hidden_states, encoder_hidden_states = (
                     self._compute_or_cache_block(
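
Note: every buffer key in this file is now namespaced with cache_prefix (e.g. "Bn_residual" becomes f"{cache_prefix}_Bn_residual"), which lets several block groups share one CachedContextManager without overwriting each other's buffers. A dict-based sketch of why the prefix matters; this is an illustration only, not the manager's real storage, and the group names are made up.

buffers: dict[str, object] = {}


def set_buffer(prefix: str, name: str, value: object) -> None:
    # Namespaced key, mirroring the f"{cache_prefix}_{name}" pattern above.
    buffers[f"{prefix}_{name}"] = value


set_buffer("transformer_blocks", "Bn_residual", "residual-A")
set_buffer("single_transformer_blocks", "Bn_residual", "residual-B")

# Without the prefix both groups would have written to the same
# "Bn_residual" key; with it, each group keeps its own entry.
assert len(buffers) == 2
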

cache_dit/cache_factory/cache_blocks/utils.py

@@ -1,19 +1,25 @@
 import torch
 
-from cache_dit.cache_factory import cache_context
+from typing import Any
+from cache_dit.cache_factory import CachedContext
+from cache_dit.cache_factory import CachedContextManager
 
 
-@torch.compiler.disable
 def patch_cached_stats(
-    transformer,
+    module: torch.nn.Module | Any,
+    cache_context: CachedContext | str = None,
+    cache_manager: CachedContextManager = None,
 ):
-    # Patch the cached stats to the transformer, the cached stats
+    # Patch the cached stats to the module, the cached stats
     # will be reset for each calling of pipe.__call__(**kwargs).
-    if transformer is None:
+    if module is None or cache_manager is None:
         return
 
-    # TODO: Patch more cached stats to the transformer
-    transformer._cached_steps = cache_context.get_cached_steps()
-    transformer._residual_diffs = cache_context.get_residual_diffs()
-    transformer._cfg_cached_steps = cache_context.get_cfg_cached_steps()
-    transformer._cfg_residual_diffs = cache_context.get_cfg_residual_diffs()
+    if cache_context is not None:
+        cache_manager.set_context(cache_context)
+
+    # TODO: Patch more cached stats to the module
+    module._cached_steps = cache_manager.get_cached_steps()
+    module._residual_diffs = cache_manager.get_residual_diffs()
+    module._cfg_cached_steps = cache_manager.get_cfg_cached_steps()
+    module._cfg_residual_diffs = cache_manager.get_cfg_residual_diffs()
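
Note: after patch_cached_stats runs, the per-call cache statistics live as private attributes on the patched module. A small sketch of reading them back for diagnostics; the helper below is not part of cache-dit, and only the attribute names are taken from the diff above.

import torch


def read_cached_stats(module: torch.nn.Module) -> dict:
    # Attribute names come from patch_cached_stats above; values default to
    # None when the module has not been patched yet.
    return {
        "cached_steps": getattr(module, "_cached_steps", None),
        "residual_diffs": getattr(module, "_residual_diffs", None),
        "cfg_cached_steps": getattr(module, "_cfg_cached_steps", None),
        "cfg_residual_diffs": getattr(module, "_cfg_residual_diffs", None),
    }
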

cache_dit/cache_factory/cache_contexts/__init__.py

@@ -0,0 +1,5 @@
+# namespace alias: for _CachedContext and many others' cache context funcs.
+from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
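
Note: these re-exports (together with the updated cache_dit/cache_factory/__init__.py in this release) mean the classes should also resolve from the shorter package paths used elsewhere in this diff (e.g. in cache_blocks/utils.py). A quick check, assuming cache-dit 0.2.28 is installed:

# The short path used in cache_blocks/utils.py above...
from cache_dit.cache_factory import CachedContext, CachedContextManager

# ...should resolve to the same classes as the full module paths re-exported here.
from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext as _Ctx
from cache_dit.cache_factory.cache_contexts.cache_manager import (
    CachedContextManager as _Mgr,
)

assert CachedContext is _Ctx
assert CachedContextManager is _Mgr
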