cache-dit 0.2.32__py3-none-any.whl → 0.2.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID
 
- __version__ = version = '0.2.32'
- __version_tuple__ = version_tuple = (0, 2, 32)
+ __version__ = version = '0.2.33'
+ __version_tuple__ = version_tuple = (0, 2, 33)
 
  __commit_id__ = commit_id = None
@@ -579,7 +579,7 @@ class BlockAdapter:
  assert isinstance(adapter[0], torch.nn.Module)
  return getattr(adapter[0], "_is_cached", False)
  else:
- raise TypeError(f"Can't check this type: {type(adapter)}!")
+ return getattr(adapter, "_is_cached", False)
 
  @classmethod
  def nested_depth(cls, obj: Any):
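Note: `BlockAdapter.is_cached` now degrades gracefully instead of raising `TypeError` for adapter types it does not recognize. A minimal sketch of the two branches visible in this hunk (the full classmethod in `block_adapters.py` handles additional adapter shapes, so treat the reduced signature here as illustrative only):

```python
import torch

def is_cached_sketch(adapter) -> bool:
    # Hedged sketch: mirrors only the branches shown in the hunk above.
    if isinstance(adapter, (list, tuple)):
        # Sequence case from the surrounding context: probe the first module.
        assert isinstance(adapter[0], torch.nn.Module)
        return getattr(adapter[0], "_is_cached", False)
    # New fallback in 0.2.33: probe the attribute instead of raising TypeError.
    return getattr(adapter, "_is_cached", False)
```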
@@ -1,6 +1,5 @@
  import torch
 
- from typing import Dict, Any
  from cache_dit.cache_factory import ForwardPattern
  from cache_dit.cache_factory.cache_blocks.pattern_base import (
  CachedBlocks_Pattern_Base,
@@ -24,14 +23,12 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
  **kwargs,
  ):
  # Use it's own cache context.
- self.cache_manager.set_context(
- self.cache_context,
- )
+ self.cache_manager.set_context(self.cache_context)
+ self._check_cache_params()
 
  original_hidden_states = hidden_states
  # Call first `n` blocks to process the hidden states for
  # more stable diff calculation.
- # encoder_hidden_states: None Pattern 3, else 4, 5
  hidden_states, new_encoder_hidden_states = self.call_Fn_blocks(
  hidden_states,
  *args,
@@ -109,10 +106,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
  *args,
  **kwargs,
  )
- if new_encoder_hidden_states is not None:
- new_encoder_hidden_states_residual = (
- new_encoder_hidden_states - old_encoder_hidden_states
- )
+
  torch._dynamo.graph_break()
  if self.cache_manager.is_cache_residual():
  self.cache_manager.set_Bn_buffer(
@@ -125,6 +119,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
  prefix=f"{self.cache_prefix}_Bn_hidden_states",
  )
 
+ if new_encoder_hidden_states is not None:
+ new_encoder_hidden_states_residual = (
+ new_encoder_hidden_states - old_encoder_hidden_states
+ )
  if self.cache_manager.is_encoder_cache_residual():
  if new_encoder_hidden_states is not None:
  self.cache_manager.set_Bn_encoder_buffer(
@@ -159,27 +157,12 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
  )
  )
 
- @torch.compiler.disable
- def maybe_update_kwargs(
- self, encoder_hidden_states, kwargs: Dict[str, Any]
- ) -> Dict[str, Any]:
- # if "encoder_hidden_states" in kwargs:
- # kwargs["encoder_hidden_states"] = encoder_hidden_states
- # return kwargs
- return kwargs
-
  def call_Fn_blocks(
  self,
  hidden_states: torch.Tensor,
  *args,
  **kwargs,
  ):
- assert self.cache_manager.Fn_compute_blocks() <= len(
- self.transformer_blocks
- ), (
- f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
- f"the number of transformer blocks {len(self.transformer_blocks)}"
- )
  new_encoder_hidden_states = None
  for block in self._Fn_blocks():
  hidden_states = block(
@@ -194,10 +177,6 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
  new_encoder_hidden_states,
  hidden_states,
  )
- kwargs = self.maybe_update_kwargs(
- new_encoder_hidden_states,
- kwargs,
- )
 
  return hidden_states, new_encoder_hidden_states
 
@@ -222,11 +201,6 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
  new_encoder_hidden_states,
  hidden_states,
  )
- kwargs = self.maybe_update_kwargs(
- new_encoder_hidden_states,
- kwargs,
- )
-
  # compute hidden_states residual
  hidden_states = hidden_states.contiguous()
  hidden_states_residual = hidden_states - original_hidden_states
@@ -243,35 +217,22 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
  *args,
  **kwargs,
  ):
- assert self.cache_manager.Bn_compute_blocks() <= len(
- self.transformer_blocks
- ), (
- f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
- f"the number of transformer blocks {len(self.transformer_blocks)}"
- )
- if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
- raise ValueError(
- f"Bn_compute_blocks_ids is not support for "
- f"patterns: {self._supported_patterns}."
+ new_encoder_hidden_states = None
+ if self.cache_manager.Bn_compute_blocks() == 0:
+ return hidden_states, new_encoder_hidden_states
+
+ for block in self._Bn_blocks():
+ hidden_states = block(
+ hidden_states,
+ *args,
+ **kwargs,
  )
- else:
- # Compute all Bn blocks if no specific Bn compute blocks ids are set.
- for block in self._Bn_blocks():
- hidden_states = block(
- hidden_states,
- *args,
- **kwargs,
- )
- if not isinstance(hidden_states, torch.Tensor): # Pattern 4,5
- hidden_states, new_encoder_hidden_states = hidden_states
- if not self.forward_pattern.Return_H_First:
- hidden_states, new_encoder_hidden_states = (
- new_encoder_hidden_states,
- hidden_states,
- )
- kwargs = self.maybe_update_kwargs(
- new_encoder_hidden_states,
- kwargs,
- )
+ if not isinstance(hidden_states, torch.Tensor): # Pattern 4,5
+ hidden_states, new_encoder_hidden_states = hidden_states
+ if not self.forward_pattern.Return_H_First:
+ hidden_states, new_encoder_hidden_states = (
+ new_encoder_hidden_states,
+ hidden_states,
+ )
 
  return hidden_states, new_encoder_hidden_states
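Taken together, the hunks above remove `maybe_update_kwargs`, the per-call Fn/Bn assertions (now centralized in `_check_cache_params`, see the next file), and all `Bn_compute_blocks_ids` handling from `CachedBlocks_Pattern_3_4_5`. The tail loop reduces to: return early when `Bn_compute_blocks()` is 0, otherwise compute every Bn block. A condensed sketch of that control flow using the names from the hunk (indentation and the exact tuple handling are inferred, so read it as a sketch rather than the file verbatim):

```python
import torch

def call_Bn_blocks_sketch(self, hidden_states, *args, **kwargs):
    """Hedged sketch of the simplified Bn loop shown above (Patterns 3/4/5)."""
    new_encoder_hidden_states = None
    if self.cache_manager.Bn_compute_blocks() == 0:
        # No tail blocks to recompute: pass states through unchanged.
        return hidden_states, new_encoder_hidden_states

    for block in self._Bn_blocks():
        hidden_states = block(hidden_states, *args, **kwargs)
        if not isinstance(hidden_states, torch.Tensor):  # Patterns 4/5 return a pair
            hidden_states, new_encoder_hidden_states = hidden_states
            if not self.forward_pattern.Return_H_First:
                hidden_states, new_encoder_hidden_states = (
                    new_encoder_hidden_states,
                    hidden_states,
                )
    return hidden_states, new_encoder_hidden_states
```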
@@ -93,6 +93,21 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  required_param in forward_parameters
  ), f"The input parameters must contains: {required_param}."
 
+ @torch.compiler.disable
+ def _check_cache_params(self):
+ assert self.cache_manager.Fn_compute_blocks() <= len(
+ self.transformer_blocks
+ ), (
+ f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
+ f"the number of transformer blocks {len(self.transformer_blocks)}"
+ )
+ assert self.cache_manager.Bn_compute_blocks() <= len(
+ self.transformer_blocks
+ ), (
+ f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
+ f"the number of transformer blocks {len(self.transformer_blocks)}"
+ )
+
  def forward(
  self,
  hidden_states: torch.Tensor,
@@ -100,7 +115,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  *args,
  **kwargs,
  ):
+ # Use it's own cache context.
  self.cache_manager.set_context(self.cache_context)
+ self._check_cache_params()
 
  original_hidden_states = hidden_states
  # Call first `n` blocks to process the hidden states for
@@ -191,18 +208,17 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  prefix=f"{self.cache_prefix}_Bn_residual",
  )
  else:
- # TaylorSeer
  self.cache_manager.set_Bn_buffer(
  hidden_states,
  prefix=f"{self.cache_prefix}_Bn_hidden_states",
  )
+
  if self.cache_manager.is_encoder_cache_residual():
  self.cache_manager.set_Bn_encoder_buffer(
  encoder_hidden_states_residual,
  prefix=f"{self.cache_prefix}_Bn_residual",
  )
  else:
- # TaylorSeer
  self.cache_manager.set_Bn_encoder_buffer(
  encoder_hidden_states,
  prefix=f"{self.cache_prefix}_Bn_hidden_states",
@@ -296,12 +312,6 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  *args,
  **kwargs,
  ):
- assert self.cache_manager.Fn_compute_blocks() <= len(
- self.transformer_blocks
- ), (
- f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
- f"the number of transformer blocks {len(self.transformer_blocks)}"
- )
  for block in self._Fn_blocks():
  hidden_states = block(
  hidden_states,
@@ -366,28 +376,17 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  encoder_hidden_states_residual,
  )
 
- def _compute_or_cache_block(
+ def call_Bn_blocks(
  self,
- # Block index in the transformer blocks
- # Bn: 8, block_id should be in [0, 8)
- block_id: int,
- # Below are the inputs to the block
- block, # The transformer block to be executed
  hidden_states: torch.Tensor,
  encoder_hidden_states: torch.Tensor,
  *args,
  **kwargs,
  ):
- # Helper function for `call_Bn_blocks`
- # Skip the blocks by reuse residual cache if they are not
- # in the Bn_compute_blocks_ids. NOTE: We should only skip
- # the specific Bn blocks in cache steps. Compute the block
- # and cache the residuals in non-cache steps.
-
- # Normal steps: Compute the block and cache the residuals.
- if not self._is_in_cache_step():
- Bn_i_original_hidden_states = hidden_states
- Bn_i_original_encoder_hidden_states = encoder_hidden_states
+ if self.cache_manager.Bn_compute_blocks() == 0:
+ return hidden_states, encoder_hidden_states
+
+ for block in self._Bn_blocks():
  hidden_states = block(
  hidden_states,
  encoder_hidden_states,
@@ -401,149 +400,5 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  encoder_hidden_states,
  hidden_states,
  )
- # Cache residuals for the non-compute Bn blocks for
- # subsequent cache steps.
- if block_id not in self.cache_manager.Bn_compute_blocks_ids():
- Bn_i_hidden_states_residual = (
- hidden_states - Bn_i_original_hidden_states
- )
- if (
- encoder_hidden_states is not None
- and Bn_i_original_encoder_hidden_states is not None
- ):
- Bn_i_encoder_hidden_states_residual = (
- encoder_hidden_states
- - Bn_i_original_encoder_hidden_states
- )
- else:
- Bn_i_encoder_hidden_states_residual = None
-
- # Save original_hidden_states for diff calculation.
- self.cache_manager.set_Bn_buffer(
- Bn_i_original_hidden_states,
- prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
- )
- self.cache_manager.set_Bn_encoder_buffer(
- Bn_i_original_encoder_hidden_states,
- prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
- )
-
- self.cache_manager.set_Bn_buffer(
- Bn_i_hidden_states_residual,
- prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
- )
- self.cache_manager.set_Bn_encoder_buffer(
- Bn_i_encoder_hidden_states_residual,
- prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
- )
- del Bn_i_hidden_states_residual
- del Bn_i_encoder_hidden_states_residual
-
- del Bn_i_original_hidden_states
- del Bn_i_original_encoder_hidden_states
-
- else:
- # Cache steps: Reuse the cached residuals.
- # Check if the block is in the Bn_compute_blocks_ids.
- if block_id in self.cache_manager.Bn_compute_blocks_ids():
- hidden_states = block(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- if not isinstance(hidden_states, torch.Tensor):
- hidden_states, encoder_hidden_states = hidden_states
- if not self.forward_pattern.Return_H_First:
- hidden_states, encoder_hidden_states = (
- encoder_hidden_states,
- hidden_states,
- )
- else:
- # Skip the block if it is not in the Bn_compute_blocks_ids.
- # Use the cached residuals instead.
- # Check if can use the cached residuals.
- if self.cache_manager.can_cache(
- hidden_states, # curr step
- parallelized=self._is_parallelized(),
- threshold=self.cache_manager.non_compute_blocks_diff_threshold(),
- prefix=f"{self.cache_prefix}_Bn_{block_id}_original", # prev step
- ):
- hidden_states, encoder_hidden_states = (
- self.cache_manager.apply_cache(
- hidden_states,
- encoder_hidden_states,
- prefix=(
- f"{self.cache_prefix}_Bn_{block_id}_residual"
- if self.cache_manager.is_cache_residual()
- else f"{self.cache_prefix}_Bn_{block_id}_original"
- ),
- encoder_prefix=(
- f"{self.cache_prefix}_Bn_{block_id}_residual"
- if self.cache_manager.is_encoder_cache_residual()
- else f"{self.cache_prefix}_Bn_{block_id}_original"
- ),
- )
- )
- else:
- hidden_states = block(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- if not isinstance(hidden_states, torch.Tensor):
- hidden_states, encoder_hidden_states = hidden_states
- if not self.forward_pattern.Return_H_First:
- hidden_states, encoder_hidden_states = (
- encoder_hidden_states,
- hidden_states,
- )
- return hidden_states, encoder_hidden_states
-
- def call_Bn_blocks(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- *args,
- **kwargs,
- ):
- if self.cache_manager.Bn_compute_blocks() == 0:
- return hidden_states, encoder_hidden_states
-
- assert self.cache_manager.Bn_compute_blocks() <= len(
- self.transformer_blocks
- ), (
- f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
- f"the number of transformer blocks {len(self.transformer_blocks)}"
- )
- if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
- for i, block in enumerate(self._Bn_blocks()):
- hidden_states, encoder_hidden_states = (
- self._compute_or_cache_block(
- i,
- block,
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- )
- else:
- # Compute all Bn blocks if no specific Bn compute blocks ids are set.
- for block in self._Bn_blocks():
- hidden_states = block(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- if not isinstance(hidden_states, torch.Tensor):
- hidden_states, encoder_hidden_states = hidden_states
- if not self.forward_pattern.Return_H_First:
- hidden_states, encoder_hidden_states = (
- encoder_hidden_states,
- hidden_states,
- )
 
  return hidden_states, encoder_hidden_states
@@ -14,13 +14,9 @@ logger = init_logger(__name__)
  @dataclasses.dataclass
  class CachedContext: # Internal CachedContext Impl class
  name: str = "default"
- # Dual Block Cache
- # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
+ # Dual Block Cache with flexible FnBn configuration.
  Fn_compute_blocks: int = 1
  Bn_compute_blocks: int = 0
- # We have added residual cache pattern for selected compute blocks
- Fn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
- Bn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
  # non compute blocks diff threshold, we don't skip the non
  # compute blocks if the diff >= threshold
  non_compute_blocks_diff_threshold: float = 0.08
@@ -31,13 +27,6 @@ class CachedContext: # Internal CachedContext Impl class
  l1_hidden_states_diff_threshold: float = None
  important_condition_threshold: float = 0.0
 
- # Alter Cache Settings
- # Pattern: 0 F 1 T 2 F 3 T 4 F 5 T ...
- enable_alter_cache: bool = False
- is_alter_cache: bool = True
- # 1.0 means we always cache the residuals if alter_cache is enabled.
- alter_residual_diff_threshold: Optional[Union[torch.Tensor, float]] = 1.0
-
  # Buffer for storing the residuals and other tensors
  buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
  incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
@@ -63,7 +52,6 @@ class CachedContext: # Internal CachedContext Impl class
  # Url: https://arxiv.org/pdf/2503.06923
  enable_taylorseer: bool = False
  enable_encoder_taylorseer: bool = False
- # NOTE: use residual cache for taylorseer may incur precision loss
  taylorseer_cache_type: str = "hidden_states" # residual or hidden_states
  taylorseer_order: int = 2 # The order for TaylorSeer
  taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
@@ -97,16 +85,11 @@ class CachedContext: # Internal CachedContext Impl class
  )
  cfg_continuous_cached_steps: int = 0
 
- @torch.compiler.disable
  def __post_init__(self):
  if logger.isEnabledFor(logging.DEBUG):
  logger.info(f"Created _CacheContext: {self.name}")
  # Some checks for settings
  if self.enable_spearate_cfg:
- assert self.enable_alter_cache is False, (
- "enable_alter_cache must set as False if "
- "enable_spearate_cfg is enabled."
- )
  if self.cfg_diff_compute_separate:
  assert self.cfg_compute_first is False, (
  "cfg_compute_first must set as False if "
@@ -135,47 +118,32 @@ class CachedContext: # Internal CachedContext Impl class
  **self.taylorseer_kwargs
  )
 
- @torch.compiler.disable
  def get_residual_diff_threshold(self):
- if self.enable_alter_cache:
- residual_diff_threshold = self.alter_residual_diff_threshold
- else:
- residual_diff_threshold = self.residual_diff_threshold
- if self.l1_hidden_states_diff_threshold is not None:
- # Use the L1 hidden states diff threshold if set
- residual_diff_threshold = self.l1_hidden_states_diff_threshold
+ residual_diff_threshold = self.residual_diff_threshold
+ if self.l1_hidden_states_diff_threshold is not None:
+ # Use the L1 hidden states diff threshold if set
+ residual_diff_threshold = self.l1_hidden_states_diff_threshold
  if isinstance(residual_diff_threshold, torch.Tensor):
  residual_diff_threshold = residual_diff_threshold.item()
  return residual_diff_threshold
 
- @torch.compiler.disable
  def get_buffer(self, name):
- if self.enable_alter_cache and self.is_alter_cache:
- name = f"{name}_alter"
  return self.buffers.get(name)
 
- @torch.compiler.disable
  def set_buffer(self, name, buffer):
- if self.enable_alter_cache and self.is_alter_cache:
- name = f"{name}_alter"
  self.buffers[name] = buffer
 
- @torch.compiler.disable
  def remove_buffer(self, name):
- if self.enable_alter_cache and self.is_alter_cache:
- name = f"{name}_alter"
  if name in self.buffers:
  del self.buffers[name]
 
- @torch.compiler.disable
  def clear_buffers(self):
  self.buffers.clear()
 
- @torch.compiler.disable
  def mark_step_begin(self):
  # Always increase transformer executed steps
- # incr step: prev 0 -> 1; prev 1 -> 2
- # current step: incr step - 1
+ # incr step: prev 0 -> 1; prev 1 -> 2
+ # current step: incr step - 1
  self.transformer_executed_steps += 1
  if not self.enable_spearate_cfg:
  self.executed_steps += 1
@@ -190,10 +158,6 @@ class CachedContext: # Internal CachedContext Impl class
  # transformer step: 0,2,4,...
  self.executed_steps += 1
 
- if not self.enable_alter_cache:
- # 0 F 1 T 2 F 3 T 4 F 5 T ...
- self.is_alter_cache = not self.is_alter_cache
-
  # Reset the cached steps and residual diffs at the beginning
  # of each inference.
  if self.get_current_transformer_step() == 0:
@@ -248,7 +212,6 @@ class CachedContext: # Internal CachedContext Impl class
  def get_cfg_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
  return self.cfg_taylorseer, self.cfg_encoder_taylorseer
 
- @torch.compiler.disable
  def add_residual_diff(self, diff):
  # step: executed_steps - 1, not transformer_steps - 1
  step = str(self.get_current_step())
@@ -260,15 +223,12 @@ class CachedContext: # Internal CachedContext Impl class
  if step not in self.cfg_residual_diffs:
  self.cfg_residual_diffs[step] = diff
 
- @torch.compiler.disable
  def get_residual_diffs(self):
  return self.residual_diffs.copy()
 
- @torch.compiler.disable
  def get_cfg_residual_diffs(self):
  return self.cfg_residual_diffs.copy()
 
- @torch.compiler.disable
  def add_cached_step(self):
  curr_cached_step = self.get_current_step()
  if not self.is_separate_cfg_step():
@@ -296,23 +256,18 @@ class CachedContext: # Internal CachedContext Impl class
 
  self.cfg_cached_steps.append(curr_cached_step)
 
- @torch.compiler.disable
  def get_cached_steps(self):
  return self.cached_steps.copy()
 
- @torch.compiler.disable
  def get_cfg_cached_steps(self):
  return self.cfg_cached_steps.copy()
 
- @torch.compiler.disable
  def get_current_step(self):
  return self.executed_steps - 1
 
- @torch.compiler.disable
  def get_current_transformer_step(self):
  return self.transformer_executed_steps - 1
 
- @torch.compiler.disable
  def is_separate_cfg_step(self):
  if not self.enable_spearate_cfg:
  return False
@@ -322,6 +277,5 @@ class CachedContext: # Internal CachedContext Impl class
  # CFG steps: 1, 3, 5, 7, ...
  return self.get_current_transformer_step() % 2 != 0
 
- @torch.compiler.disable
  def is_in_warmup(self):
  return self.get_current_step() < self.max_warmup_steps
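With the alter-cache fields and `*_compute_blocks_ids` lists removed, the FnBn configuration that `CachedContext` still carries is the whole-block counts plus the diff thresholds. A hedged usage sketch using only field names that appear in the hunks above (in the package these objects are normally created through `CachedContextManager`, not by hand):

```python
# Illustrative kwargs only; unlisted fields keep their dataclass defaults.
ctx_kwargs = dict(
    name="default",
    Fn_compute_blocks=8,           # F8: first n blocks always computed
    Bn_compute_blocks=0,           # B0: no tail blocks recomputed
    residual_diff_threshold=0.12,  # dynamic-cache threshold
    enable_taylorseer=False,
)
# CachedContext(**ctx_kwargs)  # fields dropped in 0.2.33: Fn/Bn_compute_blocks_ids,
#                              # enable_alter_cache, is_alter_cache, alter_residual_diff_threshold
```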
@@ -122,10 +122,7 @@ class CachedContextManager:
  default_value,
  )
 
- # Manually set sequence fields, namely, Fn_compute_blocks_ids
- # and Bn_compute_blocks_ids, which are lists or sets.
- _safe_set_sequence_field("Fn_compute_blocks_ids", [])
- _safe_set_sequence_field("Bn_compute_blocks_ids", [])
+ # Manually set sequence fields
  _safe_set_sequence_field("taylorseer_kwargs", {})
 
  for attr in cache_attrs:
@@ -301,18 +298,6 @@ class CachedContextManager:
  return self.is_taylorseer_cache_residual()
  return True
 
- @torch.compiler.disable
- def is_alter_cache_enabled(self) -> bool:
- cached_context = self.get_context()
- assert cached_context is not None, "cached_context must be set before"
- return cached_context.enable_alter_cache
-
- @torch.compiler.disable
- def is_alter_cache(self) -> bool:
- cached_context = self.get_context()
- assert cached_context is not None, "cached_context must be set before"
- return cached_context.is_alter_cache
-
  @torch.compiler.disable
  def is_in_warmup(self) -> bool:
  cached_context = self.get_context()
@@ -359,20 +344,6 @@ class CachedContextManager:
  )
  return cached_context.Fn_compute_blocks
 
- @torch.compiler.disable
- def Fn_compute_blocks_ids(self) -> List[int]:
- cached_context = self.get_context()
- assert cached_context is not None, "cached_context must be set before"
- assert (
- len(cached_context.Fn_compute_blocks_ids)
- <= cached_context.Fn_compute_blocks
- ), (
- "The num of Fn_compute_blocks_ids must be <= Fn_compute_blocks "
- f"{cached_context.Fn_compute_blocks}, but got "
- f"{len(cached_context.Fn_compute_blocks_ids)}"
- )
- return cached_context.Fn_compute_blocks_ids
-
  @torch.compiler.disable
  def Bn_compute_blocks(self) -> int:
  cached_context = self.get_context()
@@ -392,20 +363,6 @@ class CachedContextManager:
  )
  return cached_context.Bn_compute_blocks
 
- @torch.compiler.disable
- def Bn_compute_blocks_ids(self) -> List[int]:
- cached_context = self.get_context()
- assert cached_context is not None, "cached_context must be set before"
- assert (
- len(cached_context.Bn_compute_blocks_ids)
- <= cached_context.Bn_compute_blocks
- ), (
- "The num of Bn_compute_blocks_ids must be <= Bn_compute_blocks "
- f"{cached_context.Bn_compute_blocks}, but got "
- f"{len(cached_context.Bn_compute_blocks_ids)}"
- )
- return cached_context.Bn_compute_blocks_ids
-
  @torch.compiler.disable
  def enable_spearate_cfg(self) -> bool:
  cached_context = self.get_context()
@@ -525,6 +482,9 @@ class CachedContextManager:
  # Fn buffers
  @torch.compiler.disable
  def set_Fn_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
+ # DON'T set None Buffer
+ if buffer is None:
+ return
  # Set hidden_states or residual for Fn blocks.
  # This buffer is only use for L1 diff calculation.
  downsample_factor = self.get_downsample_factor()
@@ -548,6 +508,9 @@ class CachedContextManager:
 
  @torch.compiler.disable
  def set_Fn_encoder_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
+ # DON'T set None Buffer
+ if buffer is None:
+ return
  if self.is_separate_cfg_step():
  self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
  self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
@@ -566,6 +529,9 @@ class CachedContextManager:
  # Bn buffers
  @torch.compiler.disable
  def set_Bn_buffer(self, buffer: torch.Tensor, prefix: str = "Bn"):
+ # DON'T set None Buffer
+ if buffer is None:
+ return
  # Set hidden_states or residual for Bn blocks.
  # This buffer is use for hidden states approximation.
  if self.is_taylorseer_enabled():
@@ -820,26 +786,12 @@ class CachedContextManager:
  else:
  prev_states_tensor = self.get_Fn_buffer(prefix)
 
- if not self.is_alter_cache_enabled():
- # Dynamic cache according to the residual diff
- can_cache = prev_states_tensor is not None and self.similarity(
- prev_states_tensor,
- states_tensor,
- threshold=threshold,
- parallelized=parallelized,
- prefix=prefix,
- )
- else:
- # Only cache in the alter cache steps
- can_cache = (
- prev_states_tensor is not None
- and self.similarity(
- prev_states_tensor,
- states_tensor,
- threshold=threshold,
- parallelized=parallelized,
- prefix=prefix,
- )
- and self.is_alter_cache()
- )
+ # Dynamic cache according to the residual diff
+ can_cache = prev_states_tensor is not None and self.similarity(
+ prev_states_tensor,
+ states_tensor,
+ threshold=threshold,
+ parallelized=parallelized,
+ prefix=prefix,
+ )
  return can_cache
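After removing the alter-cache branch, `can_cache` is a single dynamic decision: reuse the cache only when a previous-step tensor exists and `similarity(...)` judges it close enough to the current step under `threshold`. The `similarity` implementation is not part of this diff; residual-diff checks of this kind are commonly a relative L1 metric, sketched below purely as an assumption, not as cache-dit's actual code:

```python
import torch

def relative_l1_diff(prev: torch.Tensor, curr: torch.Tensor) -> float:
    # Assumed metric shape: mean absolute residual normalized by the previous magnitude.
    return ((curr - prev).abs().mean() / (prev.abs().mean() + 1e-8)).item()

def can_cache_sketch(prev, curr, threshold: float) -> bool:
    # Mirrors the simplified branch above: no alter-cache special case remains.
    return prev is not None and relative_l1_diff(prev, curr) < threshold
```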
@@ -1,5 +1,4 @@
  import math
- import torch
 
 
  class TaylorSeer:
@@ -17,7 +16,6 @@ class TaylorSeer:
  self.compute_step_map = compute_step_map
  self.reset_cache()
 
- @torch.compiler.disable
  def reset_cache(self):
  self.state = {
  "dY_prev": [None] * self.ORDER,
@@ -26,7 +24,6 @@ class TaylorSeer:
  self.current_step = -1
  self.last_non_approximated_step = -1
 
- @torch.compiler.disable
  def should_compute_full(self, step=None):
  step = self.current_step if step is None else step
  if self.compute_step_map is not None:
@@ -39,7 +36,6 @@ class TaylorSeer:
  return True
  return False
 
- @torch.compiler.disable
  def approximate_derivative(self, Y):
  # n-th order Taylor expansion:
  # Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
@@ -58,7 +54,6 @@ class TaylorSeer:
  break
  return dY_current
 
- @torch.compiler.disable
  def approximate_value(self):
  # TODO: Custom Triton/CUDA kernel for better performance,
  # especially for large n_derivatives.
@@ -71,11 +66,9 @@ class TaylorSeer:
  break
  return output
 
- @torch.compiler.disable
  def mark_step_begin(self):
  self.current_step += 1
 
- @torch.compiler.disable
  def update(self, Y):
  # Directly call this method will ingnore the warmup
  # policy and force full computation.
@@ -94,7 +87,6 @@ class TaylorSeer:
  self.state["dY_current"] = self.approximate_derivative(Y)
  self.last_non_approximated_step = self.current_step
 
- @torch.compiler.disable
  def step(self, Y):
  self.mark_step_begin()
  if self.should_compute_full():
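The `@torch.compiler.disable` decorators come off every `TaylorSeer` method, leaving plain Python that torch.compile can trace through. For reference, the comment in `approximate_derivative` states the model being used: at the last fully computed step the cache stores finite-difference derivatives dY, and skipped steps extrapolate Y(t) ≈ Σ_i dY^(i) · t^i / i!, with t the distance from that step. A standalone sketch of that extrapolation (independent of the class's internal state layout):

```python
import math

def taylor_extrapolate(derivatives, elapsed_steps):
    """Hedged sketch of the TaylorSeer-style update; derivatives[0] is the cached
    value, derivatives[i] the i-th finite-difference derivative from the last full step."""
    output = 0.0
    for i, dY in enumerate(derivatives):
        if dY is None:
            break
        output += (elapsed_steps**i / math.factorial(i)) * dY
    return output

# Value 1.0 with first derivative 0.5, two steps after the last full compute -> 2.0
print(taylor_extrapolate([1.0, 0.5], 2))
```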
@@ -89,6 +89,9 @@ def quantize_ao(
  PerRow,
  )
 
+ if per_row: # Ensure bfloat16
+ module.to(torch.bfloat16)
+
  quantization_fn = float8_dynamic_activation_float8_weight(
  weight_dtype=kwargs.get(
  "weight_dtype",
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: cache_dit
- Version: 0.2.32
+ Version: 0.2.33
  Summary: 🤗 A Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
  Author: DefTruth, vipshop.com, etc.
  Maintainer: DefTruth, vipshop.com, etc
@@ -61,39 +61,57 @@ Dynamic: requires-python
  <p align="center">
  🎉Now, <b>cache-dit</b> covers <b>most</b> mainstream Diffusers' <b>DiT</b> Pipelines🎉<br>
  🔥<a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Qwen-Image-Lightning</a> | <a href="#supported"> Wan 2.1/2.2 </a>🔥<br>
- 🔥<a href="#supported">HunyuanVideo</a> | <a href="#supported">HunyuanDiT</a> | <a href="#supported">HiDream</a> | <a href="#supported">Mochi</a> | <a href="#supported">CogVideoX 1/1.5</a>🔥<br>
- 🔥<a href="#supported">CogView3Plus</a> | <a href="#supported">CogView4</a> | <a href="#supported">Chroma</a> | <a href="#supported"> LTXVideo </a> | <a href="#supported">PixArt</a>🔥<br>
+ 🔥<a href="#supported">HunyuanImage-2.1</a> | <a href="#supported">HunyuanVideo</a> | <a href="#supported">HunyuanDiT</a> | <a href="#supported">HiDream</a> | <a href="#supported">Mochi</a>🔥<br>
+ 🔥<a href="#supported">CogView3Plus</a> | <a href="#supported">CogView4</a> | <a href="#supported">Chroma</a> | <a href="#supported"> LTXVideo </a> | <a href="#supported">CogVideoX 1/1.5</a>🔥<br>
  🔥<a href="#supported">Cosmos</a> | <a href="#supported">SkyReelsV2</a> | <a href="#supported">VisualCloze</a> | <a href="#supported"> OmniGen </a> | <a href="#supported">Lumina 1/2</a>🔥<br>
- 🔥<a href="#supported">Allegro</a> | <a href="#supported">EasyAnimate</a> | <a href="#supported">SD 3/3.5</a> | <a href="#supported"> ... </a> | <a href="#supported">DiT-XL</a>🔥
+ 🔥<a href="#supported">Allegro</a> | <a href="#supported">EasyAnimate</a> | <a href="#supported">SD 3/3.5</a> | <a href="#supported"> ... </a> | <a href="#supported">PixArt</a>🔥
  </p>
  </div>
  <div align='center'>
- <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C0_Q0_NONE.gif width=160px>
- <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q0_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
- <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q1_fp8_w8a8_dq_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
- <p><b>🔥Wan2.2 MoE</b> | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:~2.0x↑🎉</b> | +FP8 DQ:<b>~2.4x↑🎉</b></p>
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C0_Q0_NONE.gif width=124px>
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q0_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=124px>
+ <img src=./assets/gifs/hunyuan_video.C0_L0_Q0_NONE.gif width=126px>
+ <img src=./assets/gifs/hunyuan_video.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.12_S27.gif width=126px>
+ <p><b>🔥Wan2.2 MoE</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.0x↑🎉 | <b>HunyuanVideo</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.1x↑🎉</p>
  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C0_Q0_NONE.png width=160px>
  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C1_Q0_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S23.png width=160px>
- <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C1_Q1_fp8_w8a8_dq_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S18.png width=160px>
- <p><b>🔥Qwen-Image</b> | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:~1.8x↑🎉</b> | +FP8 DQ:<b>~2.2x↑🎉</b></p>
- <img src=./assets/qwen-image-lightning.4steps.C0_L1_Q0_NONE.png width=200px>
- <img src=./assets/qwen-image-lightning.4steps.C0_L1_Q0_DBCACHE_F16B16_W2M1MC1_T0O2_R0.9_S1.png width=200px>
- <p><b>🔥Qwen-Image-Lightning</b> 4 steps | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a></b> 3.5 steps:<b>~1.14x↑🎉</b>
+ <img src=./assets/flux.C0_Q0_NONE_T23.69s.png width=90px>
+ <img src=./assets/flux.C0_Q0_DBCACHE_F1B0_W4M0MC0_T1O2_R0.15_S16_T11.39s.png width=90px>
+ <p><b>🔥Qwen-Image</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.8x↑🎉 | <b>FLUX.1-dev</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.1x↑🎉</p>
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-lightning.4steps.C0_L1_Q0_NONE.png width=160px>
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-lightning.4steps.C0_L1_Q0_DBCACHE_F16B16_W2M1MC1_T0O2_R0.9_S1.png width=160px>
+ <img src=./assets/sd_3_5.C0_L0_Q0_NONE.png width=90px>
+ <img src=./assets/sd_3_5.C0_L0_Q0_DBCACHE_F1B0_W8M0MC3_T0O2_R0.12_S30.png width=90px>
+ <p><b>🔥Qwen-Image-Lightning</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.14x↑🎉 | <b>SD 3.5</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.5x↑🎉</p>
+ <img src=./assets/hidream.C0_L0_Q0_NONE.png width=100px>
+ <img src=./assets/hidream.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T0O2_R0.08_S24.png width=100px>
+ <img src=./assets/cogview4.C0_L0_Q0_NONE.png width=100px>
+ <img src=./assets/cogview4.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S15.png width=100px>
+ <img src=./assets/cogview4.C0_L0_Q0_DBCACHE_F1B0_W4M0MC4_T0O2_R0.2_S22.png width=100px>
+ <p><b>🔥HiDream-I1</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.9x↑🎉 | <b>CogView4</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.4x↑🎉 | 1.7x↑🎉</p>
+ <img src=./assets/gifs/mochi.C0_L0_Q0_NONE.gif width=160px>
+ <img src=./assets/gifs/mochi.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S34.gif width=160px>
+ <img src=./assets/hunyuan-image-2.1.C0_L0_Q1_fp8_w8a16_wo_NONE.png width=91px>
+ <img src=./assets/hunyuan-image-2.1.C0_L0_Q1_fp8_w8a16_wo_DBCACHE_F8B0_W8M0MC2_T1O2_R0.12_S25.png width=91px>
+ <p><b>🔥Mochi-1</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.8x↑🎉 | <b>HunyuanImage-2.1</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.7x↑🎉
  <br>♥️ Please consider to leave a <b>⭐️ Star</b> to support us ~ ♥️</p>
  </div>
 
  ## 🔥News
 
+ - [2025-09-10] 🎉Day 1 support [**HunyuanImage-2.1**](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) with **1.7x↑🎉** speedup! Check this [example](./examples/pipeline/run_hunyuan_image_2.1.py).
  - [2025-09-08] 🔥[**Qwen-Image-Lightning**](./examples/pipeline/run_qwen_image_lightning.py) **7.1/3.5 steps🎉** inference with **[DBCache: F16B16](https://github.com/vipshop/cache-dit)**.
  - [2025-09-03] 🎉[**Wan2.2-MoE**](https://github.com/Wan-Video) **2.4x↑🎉** speedup! Please refer to [run_wan_2.2.py](./examples/pipeline/run_wan_2.2.py) as an example.
  - [2025-08-19] 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x↑🎉** speedup! Check the example: [run_qwen_image_edit.py](./examples/pipeline/run_qwen_image_edit.py).
- - [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
  - [2025-08-11] 🔥[**Qwen-Image**](https://github.com/QwenLM/Qwen-Image) **1.8x↑🎉** speedup! Please refer to [run_qwen_image.py](./examples/pipeline/run_qwen_image.py) as an example.
  - [2025-07-13] 🎉[**FLUX.1-dev**](https://github.com/xlite-dev/flux-faster) **3.3x↑🎉** speedup! NVIDIA L20 with **[cache-dit](https://github.com/vipshop/cache-dit)** + **compile + FP8 DQ**.
 
  <details>
  <summary> Previous News </summary>
 
+ - [2025-09-08] 🎉First caching mechanism in [Qwen-Image-Lightning](https://github.com/ModelTC/Qwen-Image-Lightning) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/ModelTC/Qwen-Image-Lightning/pull/35).
+ - [2025-09-08] 🎉First caching mechanism in [Wan2.2](https://github.com/Wan-Video/Wan2.2) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/Wan-Video/Wan2.2/pull/127) for more details.
+ - [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
  - [2025-09-01] 📚[**Hybird Forward Pattern**](#unified) is supported! Please check [FLUX.1-dev](./examples/run_flux_adapter.py) as an example.
  - [2025-08-10] 🔥[**FLUX.1-Kontext-dev**](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) is supported! Please refer [run_flux_kontext.py](./examples/pipeline/run_flux_kontext.py) as an example.
  - [2025-07-18] 🎉First caching mechanism in [🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/huggingface/flux-fast/pull/13).
@@ -139,6 +157,7 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 
  Currently, **cache-dit** library supports almost **Any** Diffusion Transformers (with **Transformer Blocks** that match the specific Input and Output **patterns**). Please check [🎉Unified Cache APIs](#unified) for more details. Here are just some of the tested models listed:
 
+ - [🚀HunyuanImage-2.1](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀Qwen-Image-Lightning](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀Qwen-Image-Edit](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀Qwen-Image](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -147,6 +166,8 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
  - [🚀FLUX.1-Kontext-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀CogVideoX1.5](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀CogView3-Plus](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀CogView4](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀Wan2.2-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -294,23 +315,7 @@ cache_dit.enable_cache(
  Bn_compute_blocks=8, # Bn, B8, etc.
  residual_diff_threshold=0.12,
  )
- ```
- Moreover, users configuring higher **Bn** values (e.g., **F8B16**) while aiming to maintain good performance can specify **Bn_compute_blocks_ids** to work with Bn. DBCache will only compute the specified blocks, with the remaining estimated using the previous step's residual cache.
-
- ```python
- # Custom options, F8B16, higher precision with good performance.
- cache_dit.enable_cache(
- pipe,
- Fn_compute_blocks=8, # Fn, F8, etc.
- Bn_compute_blocks=16, # Bn, B16, etc.
- # 0, 2, 4, ..., 14, 15, etc. [0,16)
- Bn_compute_blocks_ids=cache_dit.block_range(0, 16, 2),
- # If the L1 difference is below this threshold, skip Bn blocks
- # not in `Bn_compute_blocks_ids`(1, 3,..., etc), Otherwise,
- # compute these blocks.
- non_compute_blocks_diff_threshold=0.08,
- )
- ```
+ ```
 
  <div align="center">
  <p align="center">
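Accordingly, the README's DBCache section now stops at whole-block counts plus a threshold; the paragraph and F8B16 example built around `Bn_compute_blocks_ids`/`non_compute_blocks_diff_threshold` are dropped along with the underlying code. The surviving configuration, repeated from the kept context above (`pipe` is assumed to be an already-loaded Diffusers pipeline):

```python
import cache_dit

cache_dit.enable_cache(
    pipe,                          # assumed: a loaded Diffusers DiT pipeline
    Fn_compute_blocks=8,           # Fn, F8, etc.
    Bn_compute_blocks=8,           # Bn, B8, etc.
    residual_diff_threshold=0.12,
)
```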
@@ -1,5 +1,5 @@
  cache_dit/__init__.py,sha256=kX9V-FegZG4c8LMwI4PTmMqH794MEW0pzDArdhC0cJw,1241
- cache_dit/_version.py,sha256=J0YTFDgdG9rY1Xk5pUbWWGgbT2rbSasvUHcntxayVtA,706
+ cache_dit/_version.py,sha256=gTEHTWtuqv38KTvjBsXd5hC019b6d7AyfC8gLMY7KAo,706
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
  cache_dit/utils.py,sha256=WK7eqgH6gCYNHXNLmWyxBDU0XSHTPg7CfOcyXlGXBqE,10510
  cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
@@ -10,17 +10,17 @@ cache_dit/cache_factory/cache_types.py,sha256=ooukxQRG55uTLmaZ0SKw6gIeY6SQHhMxkb
  cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
  cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
  cache_dit/cache_factory/block_adapters/__init__.py,sha256=OZM5vJwmQIkoIwVmMxKXiHqKvs31NyAva1Z91C_ko3w,17547
- cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=EQBiJYyoInKU1ND69wTm7M0n5Ja4I8QW01SgRpBjSn8,21671
+ cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=IqHV10aK2qA8kEVDi7EEoUSBt0GzwCUM4GpLNf8Jgww,21656
  cache_dit/cache_factory/block_adapters/block_registers.py,sha256=ZeN2wGPmuf2u3puSsBx8x-rl3wRo8-cWcuWNcrssVfA,2553
  cache_dit/cache_factory/cache_blocks/__init__.py,sha256=08Ox7kD05lkRKCOsVTdEZeKAWBheqpxfrAT1Nz7eclI,2916
  cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=ElMps6_7uI74tSF9GDR_dEI0bZEhdzcepM29xFWnYo8,428
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=nf2f5wdxp6tfq9AhFyMyBeKiZfxh63WG1g8q-c2BBSg,10182
- cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=f1ojREQcDoBtDG3dzl8t1g_Vru8140LVDRPWlY-kAXw,21311
+ cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=Bv56qETXhsREvCrNvnZpSqDIIHsi6Ze3FJW4Yk2x3uI,8597
+ cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=d4H9kEB0AgnVMT8aF0Y54SUMUQUxw5HQ8gRkoCuTQ_A,14577
  cache_dit/cache_factory/cache_blocks/utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
  cache_dit/cache_factory/cache_contexts/__init__.py,sha256=rqnJ5__zqnpVHK5A1OqWILpNh5Ss-0ZDTGgtxZMKGGo,250
- cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=N88WLdd4KE9DuMWmpX8URcF55E2zWNwcKMxgVYkxMJY,13691
- cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=_NUXcMYYEIVfDHpc4HJr9RUjU5RUEkZmAgFGE8bh5Wc,34883
- cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=etSUIZzDvqW3ScKCbccTPcFaSmxV1T-xAXdk-p3e3wk,3802
+ cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=zqixcxV_LjnyoYDZ6q3HAC-hqYyVV6g0MWKBI2hA1nQ,11855
+ cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=Mcj1upIpXT_CwO4AdY4ZNJSWoOXn3Lx2mBZRi_QuLbU,32710
+ cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=hgLmgIkQgwbFTjxqtLUCJ3mgDGEcJK09B7RK8sBdPiI,3593
  cache_dit/cache_factory/patch_functors/__init__.py,sha256=06zdddrjvSCgBzJ0a8niRHd3ucF2qsbzlbL00d4aCvk,451
  cache_dit/cache_factory/patch_functors/functor_base.py,sha256=Ahk0fTfrHgNdEl-9JSkACvfyyv9G-Ei5OSz7XBIlX5o,357
  cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=2iLxlsc-1dDHRveqCXaC07E9CeMNOuBNkvpJ1atpK7E,10048
@@ -38,12 +38,11 @@ cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR
  cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
  cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
  cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70A,59
- cache_dit/quantize/quantize_ao.py,sha256=mGspqYgQtenl3QnKPtsSYsSD7LbVX93f1M940bhXKLU,6066
+ cache_dit/quantize/quantize_ao.py,sha256=Fx1KW4l3gdEkdrcAYtPoDW7WKBJWrs3glOHiEwW_TgE,6160
  cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
- cache_dit/quantize/quantize_svdq.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit-0.2.32.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
- cache_dit-0.2.32.dist-info/METADATA,sha256=WQ9GP-Om05j3NBvtifkmbz5t20XBU_-KJQptrK7jQBs,24222
- cache_dit-0.2.32.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- cache_dit-0.2.32.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
- cache_dit-0.2.32.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
- cache_dit-0.2.32.dist-info/RECORD,,
+ cache_dit-0.2.33.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+ cache_dit-0.2.33.dist-info/METADATA,sha256=GQBvDzKLXL3tABguCRqLNc-Z39h0AcMK_J37demDTu8,25977
+ cache_dit-0.2.33.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ cache_dit-0.2.33.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+ cache_dit-0.2.33.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+ cache_dit-0.2.33.dist-info/RECORD,,