cache-dit 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (32)
  1. cache_dit/__init__.py +9 -4
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +16 -3
  4. cache_dit/cache_factory/block_adapters/__init__.py +538 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +121 -563
  8. cache_dit/cache_factory/cache_blocks/__init__.py +18 -0
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +275 -0
  11. cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +100 -82
  12. cache_dit/cache_factory/cache_blocks/utils.py +23 -0
  13. cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
  14. cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +94 -56
  15. cache_dit/cache_factory/cache_interface.py +24 -16
  16. cache_dit/cache_factory/forward_pattern.py +45 -24
  17. cache_dit/cache_factory/patch_functors/__init__.py +5 -0
  18. cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
  19. cache_dit/cache_factory/patch_functors/functor_chroma.py +276 -0
  20. cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +49 -31
  21. cache_dit/quantize/quantize_ao.py +19 -4
  22. cache_dit/quantize/quantize_interface.py +2 -2
  23. cache_dit/utils.py +19 -15
  24. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/METADATA +76 -19
  25. cache_dit-0.2.27.dist-info/RECORD +47 -0
  26. cache_dit-0.2.25.dist-info/RECORD +0 -36
  27. /cache_dit/cache_factory/{patch/__init__.py → cache_contexts/cache_manager.py} +0 -0
  28. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  29. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
  30. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
  31. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
  32. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
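
Before the hunks, a quick orientation: this release renames the DBCache block wrapper (DBCachedTransformerBlocks → CachedBlocks_Pattern_Base), splits cache_blocks.py into a cache_blocks/ package, and moves cache_context.py into a cache_contexts/ package exposed under a CachedContext alias. A hedged import-migration sketch follows; the 0.2.27 re-export of CachedContext from cache_dit.cache_factory is inferred from the imports visible in pattern_base.py below, the other paths come from the file list above.

# Import-migration sketch (illustrative, not part of the diff).

# cache-dit 0.2.25
# from cache_dit.cache_factory import cache_context
# from cache_dit.cache_factory.cache_blocks import DBCachedTransformerBlocks

# cache-dit 0.2.27
from cache_dit.cache_factory import CachedContext  # module alias, see cache_contexts/__init__.py
from cache_dit.cache_factory.cache_blocks.pattern_base import CachedBlocks_Pattern_Base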
cache_dit/cache_factory/cache_blocks.py → cache_dit/cache_factory/cache_blocks/pattern_base.py

@@ -2,14 +2,14 @@ import inspect
 import torch
 import torch.distributed as dist
 
-from cache_dit.cache_factory import cache_context
+from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-class DBCachedTransformerBlocks(torch.nn.Module):
+class CachedBlocks_Pattern_Base(torch.nn.Module):
     _supported_patterns = [
         ForwardPattern.Pattern_0,
         ForwardPattern.Pattern_1,
@@ -19,28 +19,54 @@ class DBCachedTransformerBlocks(torch.nn.Module):
     def __init__(
         self,
         transformer_blocks: torch.nn.ModuleList,
+        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
+        blocks_name: str,
+        # Usually, blocks_name, etc.
+        cache_context: str,
         *,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_num_outputs: bool = True,
     ):
         super().__init__()
 
         self.transformer = transformer
         self.transformer_blocks = transformer_blocks
+        self.blocks_name = blocks_name
+        self.cache_context = cache_context
         self.forward_pattern = forward_pattern
+        self.check_num_outputs = check_num_outputs
         self._check_forward_pattern()
+        logger.info(
+            f"Match Cached Blocks: {self.__class__.__name__}, for "
+            f"{self.blocks_name}, context: {self.cache_context}"
+        )
 
     def _check_forward_pattern(self):
         assert (
             self.forward_pattern.Supported
             and self.forward_pattern in self._supported_patterns
-        ), f"Pattern {self.forward_pattern} is not support for DBCache now!"
+        ), f"Pattern {self.forward_pattern} is not supported now!"
 
         if self.transformer_blocks is not None:
             for block in self.transformer_blocks:
                 forward_parameters = set(
                     inspect.signature(block.forward).parameters.keys()
                 )
+
+                if self.check_num_outputs:
+                    num_outputs = str(
+                        inspect.signature(block.forward).return_annotation
+                    ).count("torch.Tensor")
+
+                    if num_outputs > 0:
+                        assert len(self.forward_pattern.Out) == num_outputs, (
+                            f"The number of block's outputs is {num_outputs} don't not "
+                            f"match the number of the pattern: {self.forward_pattern}, "
+                            f"Out: {len(self.forward_pattern.Out)}."
+                        )
+
                 for required_param in self.forward_pattern.In:
                     assert (
                         required_param in forward_parameters
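
As an aside, a minimal sketch of how the refactored wrapper could be instantiated; the model attribute, the chosen pattern, and the swap into a ModuleList are illustrative assumptions, only the constructor parameters come from the signature shown above.

# Illustrative wiring sketch (not part of the diff).
import torch
from cache_dit.cache_factory import ForwardPattern
from cache_dit.cache_factory.cache_blocks.pattern_base import (
    CachedBlocks_Pattern_Base,
)

def wrap_blocks(transformer: torch.nn.Module) -> torch.nn.Module:
    cached = CachedBlocks_Pattern_Base(
        transformer.transformer_blocks,      # the model's block ModuleList
        blocks_name="transformer_blocks",    # which attribute is being wrapped
        cache_context="transformer_blocks",  # usually the same as blocks_name
        transformer=transformer,
        forward_pattern=ForwardPattern.Pattern_1,
        check_num_outputs=True,
    )
    # Swap the wrapper in so forward() routes through the cached blocks.
    transformer.transformer_blocks = torch.nn.ModuleList([cached])
    return transformer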
@@ -53,6 +79,10 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         *args,
         **kwargs,
     ):
+        CachedContext.set_cache_context(
+            self.cache_context,
+        )
+
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
@@ -66,39 +96,39 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states
 
-        cache_context.mark_step_begin()
+        CachedContext.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = cache_context.get_can_use_cache(
+        can_use_cache = CachedContext.get_can_use_cache(
             (
                 Fn_hidden_states_residual
-                if not cache_context.is_l1_diff_enabled()
+                if not CachedContext.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
-                "Fn_residual"
-                if not cache_context.is_l1_diff_enabled()
-                else "Fn_hidden_states"
+                f"{self.blocks_name}_Fn_residual"
+                if not CachedContext.is_l1_diff_enabled()
+                else f"{self.blocks_name}_Fn_hidden_states"
            ),
         )
 
         torch._dynamo.graph_break()
         if can_use_cache:
-            cache_context.add_cached_step()
+            CachedContext.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-                cache_context.apply_hidden_states_residual(
+                CachedContext.apply_hidden_states_residual(
                     hidden_states,
                     encoder_hidden_states,
                     prefix=(
-                        "Bn_residual"
-                        if cache_context.is_cache_residual()
-                        else "Bn_hidden_states"
+                        f"{self.blocks_name}_Bn_residual"
+                        if CachedContext.is_cache_residual()
+                        else f"{self.blocks_name}_Bn_hidden_states"
                    ),
                     encoder_prefix=(
-                        "Bn_residual"
-                        if cache_context.is_encoder_cache_residual()
-                        else "Bn_hidden_states"
+                        f"{self.blocks_name}_Bn_residual"
+                        if CachedContext.is_encoder_cache_residual()
+                        else f"{self.blocks_name}_Bn_hidden_states"
                    ),
                 )
             )
@@ -112,12 +142,16 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 **kwargs,
             )
         else:
-            cache_context.set_Fn_buffer(
-                Fn_hidden_states_residual, prefix="Fn_residual"
+            CachedContext.set_Fn_buffer(
+                Fn_hidden_states_residual,
+                prefix=f"{self.blocks_name}_Fn_residual",
             )
-            if cache_context.is_l1_diff_enabled():
+            if CachedContext.is_l1_diff_enabled():
                 # for hidden states L1 diff
-                cache_context.set_Fn_buffer(hidden_states, "Fn_hidden_states")
+                CachedContext.set_Fn_buffer(
+                    hidden_states,
+                    f"{self.blocks_name}_Fn_hidden_states",
+                )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
             (
@@ -132,27 +166,27 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 **kwargs,
             )
             torch._dynamo.graph_break()
-            if cache_context.is_cache_residual():
-                cache_context.set_Bn_buffer(
+            if CachedContext.is_cache_residual():
+                CachedContext.set_Bn_buffer(
                     hidden_states_residual,
-                    prefix="Bn_residual",
+                    prefix=f"{self.blocks_name}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-                cache_context.set_Bn_buffer(
+                CachedContext.set_Bn_buffer(
                     hidden_states,
-                    prefix="Bn_hidden_states",
+                    prefix=f"{self.blocks_name}_Bn_hidden_states",
                 )
-            if cache_context.is_encoder_cache_residual():
-                cache_context.set_Bn_encoder_buffer(
+            if CachedContext.is_encoder_cache_residual():
+                CachedContext.set_Bn_encoder_buffer(
                     encoder_hidden_states_residual,
-                    prefix="Bn_residual",
+                    prefix=f"{self.blocks_name}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-                cache_context.set_Bn_encoder_buffer(
+                CachedContext.set_Bn_encoder_buffer(
                     encoder_hidden_states,
-                    prefix="Bn_hidden_states",
+                    prefix=f"{self.blocks_name}_Bn_hidden_states",
                 )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
@@ -164,7 +198,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
             **kwargs,
         )
 
-        patch_cached_stats(self.transformer)
+        # patch cached stats for blocks or remove it.
         torch._dynamo.graph_break()
 
         return (
@@ -198,10 +232,10 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         # If so, we can skip some Bn blocks and directly
        # use the cached values.
         return (
-            cache_context.get_current_step() in cache_context.get_cached_steps()
+            CachedContext.get_current_step() in CachedContext.get_cached_steps()
        ) or (
-            cache_context.get_current_step()
-            in cache_context.get_cfg_cached_steps()
+            CachedContext.get_current_step()
+            in CachedContext.get_cfg_cached_steps()
         )
 
     @torch.compiler.disable
@@ -210,20 +244,20 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         # more stable diff calculation.
         # Fn: [0,...,n-1]
         selected_Fn_blocks = self.transformer_blocks[
-            : cache_context.Fn_compute_blocks()
+            : CachedContext.Fn_compute_blocks()
         ]
         return selected_Fn_blocks
 
     @torch.compiler.disable
     def _Mn_blocks(self):  # middle blocks
         # M(N-2n): only transformer_blocks [n,...,N-n], middle
-        if cache_context.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
+        if CachedContext.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
             selected_Mn_blocks = self.transformer_blocks[
-                cache_context.Fn_compute_blocks() :
+                CachedContext.Fn_compute_blocks() :
             ]
         else:
             selected_Mn_blocks = self.transformer_blocks[
-                cache_context.Fn_compute_blocks() : -cache_context.Bn_compute_blocks()
+                CachedContext.Fn_compute_blocks() : -CachedContext.Bn_compute_blocks()
             ]
         return selected_Mn_blocks
 
@@ -231,7 +265,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
     def _Bn_blocks(self):
         # Bn: transformer_blocks [N-n+1,...,N-1]
         selected_Bn_blocks = self.transformer_blocks[
-            -cache_context.Bn_compute_blocks() :
+            -CachedContext.Bn_compute_blocks() :
         ]
         return selected_Bn_blocks
 
@@ -242,10 +276,10 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        assert cache_context.Fn_compute_blocks() <= len(
+        assert CachedContext.Fn_compute_blocks() <= len(
            self.transformer_blocks
        ), (
-            f"Fn_compute_blocks {cache_context.Fn_compute_blocks()} must be less than "
+            f"Fn_compute_blocks {CachedContext.Fn_compute_blocks()} must be less than "
            f"the number of transformer blocks {len(self.transformer_blocks)}"
        )
        for block in self._Fn_blocks():
@@ -342,7 +376,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
            )
            # Cache residuals for the non-compute Bn blocks for
            # subsequent cache steps.
-            if block_id not in cache_context.Bn_compute_blocks_ids():
+            if block_id not in CachedContext.Bn_compute_blocks_ids():
                Bn_i_hidden_states_residual = (
                    hidden_states - Bn_i_original_hidden_states
                )
@@ -351,22 +385,22 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                )
 
                # Save original_hidden_states for diff calculation.
-                cache_context.set_Bn_buffer(
+                CachedContext.set_Bn_buffer(
                    Bn_i_original_hidden_states,
-                    prefix=f"Bn_{block_id}_original",
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_original",
                )
-                cache_context.set_Bn_encoder_buffer(
+                CachedContext.set_Bn_encoder_buffer(
                    Bn_i_original_encoder_hidden_states,
-                    prefix=f"Bn_{block_id}_original",
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_original",
                )
 
-                cache_context.set_Bn_buffer(
+                CachedContext.set_Bn_buffer(
                    Bn_i_hidden_states_residual,
-                    prefix=f"Bn_{block_id}_residual",
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_residual",
                )
-                cache_context.set_Bn_encoder_buffer(
+                CachedContext.set_Bn_encoder_buffer(
                    Bn_i_encoder_hidden_states_residual,
-                    prefix=f"Bn_{block_id}_residual",
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_residual",
                )
                del Bn_i_hidden_states_residual
                del Bn_i_encoder_hidden_states_residual
@@ -377,7 +411,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
        else:
            # Cache steps: Reuse the cached residuals.
            # Check if the block is in the Bn_compute_blocks_ids.
-            if block_id in cache_context.Bn_compute_blocks_ids():
+            if block_id in CachedContext.Bn_compute_blocks_ids():
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
@@ -395,25 +429,25 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                # Skip the block if it is not in the Bn_compute_blocks_ids.
                # Use the cached residuals instead.
                # Check if can use the cached residuals.
-                if cache_context.get_can_use_cache(
+                if CachedContext.get_can_use_cache(
                    hidden_states,  # curr step
                    parallelized=self._is_parallelized(),
-                    threshold=cache_context.non_compute_blocks_diff_threshold(),
-                    prefix=f"Bn_{block_id}_original",  # prev step
+                    threshold=CachedContext.non_compute_blocks_diff_threshold(),
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_original",  # prev step
                ):
                    hidden_states, encoder_hidden_states = (
-                        cache_context.apply_hidden_states_residual(
+                        CachedContext.apply_hidden_states_residual(
                            hidden_states,
                            encoder_hidden_states,
                            prefix=(
-                                f"Bn_{block_id}_residual"
-                                if cache_context.is_cache_residual()
-                                else f"Bn_{block_id}_original"
+                                f"{self.blocks_name}_Bn_{block_id}_residual"
+                                if CachedContext.is_cache_residual()
+                                else f"{self.blocks_name}_Bn_{block_id}_original"
                            ),
                            encoder_prefix=(
-                                f"Bn_{block_id}_residual"
-                                if cache_context.is_encoder_cache_residual()
-                                else f"Bn_{block_id}_original"
+                                f"{self.blocks_name}_Bn_{block_id}_residual"
+                                if CachedContext.is_encoder_cache_residual()
+                                else f"{self.blocks_name}_Bn_{block_id}_original"
                            ),
                        )
                    )
@@ -440,16 +474,16 @@ class DBCachedTransformerBlocks(torch.nn.Module):
        *args,
        **kwargs,
    ):
-        if cache_context.Bn_compute_blocks() == 0:
+        if CachedContext.Bn_compute_blocks() == 0:
            return hidden_states, encoder_hidden_states
 
-        assert cache_context.Bn_compute_blocks() <= len(
+        assert CachedContext.Bn_compute_blocks() <= len(
            self.transformer_blocks
        ), (
-            f"Bn_compute_blocks {cache_context.Bn_compute_blocks()} must be less than "
+            f"Bn_compute_blocks {CachedContext.Bn_compute_blocks()} must be less than "
            f"the number of transformer blocks {len(self.transformer_blocks)}"
        )
-        if len(cache_context.Bn_compute_blocks_ids()) > 0:
+        if len(CachedContext.Bn_compute_blocks_ids()) > 0:
            for i, block in enumerate(self._Bn_blocks()):
                hidden_states, encoder_hidden_states = (
                    self._compute_or_cache_block(
@@ -479,19 +513,3 @@ class DBCachedTransformerBlocks(torch.nn.Module):
            )
 
        return hidden_states, encoder_hidden_states
-
-
-@torch.compiler.disable
-def patch_cached_stats(
-    transformer,
-):
-    # Patch the cached stats to the transformer, the cached stats
-    # will be reset for each calling of pipe.__call__(**kwargs).
-    if transformer is None:
-        return
-
-    # TODO: Patch more cached stats to the transformer
-    transformer._cached_steps = cache_context.get_cached_steps()
-    transformer._residual_diffs = cache_context.get_residual_diffs()
-    transformer._cfg_cached_steps = cache_context.get_cfg_cached_steps()
-    transformer._cfg_residual_diffs = cache_context.get_cfg_residual_diffs()
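
The recurring change in the hunks above is that every cache buffer key is now namespaced by blocks_name. A small, purely illustrative sketch of the resulting key scheme; the helper function is hypothetical, only the f-string format comes from the diff.

# Hypothetical helper, for illustration only: in 0.2.25 all cached block lists
# shared keys like "Fn_residual"; in 0.2.27 the keys carry the blocks_name, so
# several block lists (e.g. "transformer_blocks" and "single_transformer_blocks")
# can cache into one context without clobbering each other's buffers.
def buffer_prefix(blocks_name: str, kind: str) -> str:
    return f"{blocks_name}_{kind}"

assert buffer_prefix("transformer_blocks", "Fn_residual") == "transformer_blocks_Fn_residual"
assert buffer_prefix("single_transformer_blocks", "Bn_3_residual") == "single_transformer_blocks_Bn_3_residual"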
cache_dit/cache_factory/cache_blocks/utils.py (new file)

@@ -0,0 +1,23 @@
+import torch
+
+from typing import Any
+from cache_dit.cache_factory import CachedContext
+
+
+@torch.compiler.disable
+def patch_cached_stats(
+    module: torch.nn.Module | Any, cache_context: str = None
+):
+    # Patch the cached stats to the module, the cached stats
+    # will be reset for each calling of pipe.__call__(**kwargs).
+    if module is None:
+        return
+
+    if cache_context is not None:
+        CachedContext.set_cache_context(cache_context)
+
+    # TODO: Patch more cached stats to the module
+    module._cached_steps = CachedContext.get_cached_steps()
+    module._residual_diffs = CachedContext.get_residual_diffs()
+    module._cfg_cached_steps = CachedContext.get_cfg_cached_steps()
+    module._cfg_residual_diffs = CachedContext.get_cfg_residual_diffs()
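
patch_cached_stats now lives in cache_blocks/utils.py, accepts any module, and can switch to a named cache context before reading stats. A hedged usage sketch; the pipe object and the "transformer_blocks" context name are assumptions, the attribute names come from the function above.

# Illustrative usage (not part of the diff): copy cached-step statistics onto
# a transformer after a pipeline call, then inspect them.
from cache_dit.cache_factory.cache_blocks.utils import patch_cached_stats

def report_cache_stats(pipe, cache_context: str = "transformer_blocks"):
    patch_cached_stats(pipe.transformer, cache_context=cache_context)
    print("cached steps:", pipe.transformer._cached_steps)
    print("residual diffs:", pipe.transformer._residual_diffs)
    print("CFG cached steps:", pipe.transformer._cfg_cached_steps)
    print("CFG residual diffs:", pipe.transformer._cfg_residual_diffs)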
cache_dit/cache_factory/cache_contexts/__init__.py (new file)

@@ -0,0 +1,2 @@
+# namespace alias: for _CachedContext and many others' cache context funcs.
+import cache_dit.cache_factory.cache_contexts.cache_context as CachedContext
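
The two added lines make the cache_context module importable under a class-like name. A consumption sketch, assuming the alias is re-exported from cache_dit.cache_factory as the pattern_base.py hunk above suggests, and assuming these helpers are only meaningful once a cache context has been created by the cache adapter.

# Illustrative only: the alias lets callers treat the module of free functions
# as a namespace, e.g. CachedContext.mark_step_begin() instead of
# cache_context.mark_step_begin(). All three calls appear in the hunks above.
from cache_dit.cache_factory import CachedContext

def begin_step(context_name: str) -> int:
    # Select the named context a CachedBlocks_* instance writes into, then
    # advance its step counter, mirroring the calls at the top of forward().
    CachedContext.set_cache_context(context_name)
    CachedContext.mark_step_begin()
    return CachedContext.get_current_step()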