cache-dit 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '1.0.1'
-__version_tuple__ = version_tuple = (1, 0, 1)
+__version__ = version = '1.0.3'
+__version_tuple__ = version_tuple = (1, 0, 3)

 __commit_id__ = commit_id = None
@@ -143,14 +143,30 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import QwenImageTransformer2DModel

     assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-    return BlockAdapter(
-        pipe=pipe,
-        transformer=pipe.transformer,
-        blocks=pipe.transformer.transformer_blocks,
-        forward_pattern=ForwardPattern.Pattern_1,
-        has_separate_cfg=True,
-        **kwargs,
-    )
+
+    pipe_cls_name: str = pipe.__class__.__name__
+    if pipe_cls_name.startswith("QwenImageControlNet"):
+        from cache_dit.cache_factory.patch_functors import (
+            QwenImageControlNetPatchFunctor,
+        )
+
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=pipe.transformer.transformer_blocks,
+            forward_pattern=ForwardPattern.Pattern_1,
+            patch_functor=QwenImageControlNetPatchFunctor(),
+            has_separate_cfg=True,
+        )
+    else:
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=pipe.transformer.transformer_blocks,
+            forward_pattern=ForwardPattern.Pattern_1,
+            has_separate_cfg=True,
+            **kwargs,
+        )


 @BlockAdapterRegistry.register("LTX")
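With this change, any pipeline whose class name starts with "QwenImageControlNet" is routed to a BlockAdapter that carries the new QwenImageControlNetPatchFunctor. A minimal usage sketch, assuming `pipe` is an already-constructed QwenImage ControlNet pipeline (construction details are omitted; they are not specified by this diff):

```python
import cache_dit

# `pipe` is assumed to be a prebuilt QwenImage ControlNet pipeline; any class
# whose name starts with "QwenImageControlNet" takes the new branch above.
assert pipe.__class__.__name__.startswith("QwenImageControlNet")

# enable_cache() dispatches through the registered qwenimage_adapter, which now
# attaches QwenImageControlNetPatchFunctor() before applying DBCache.
cache_dit.enable_cache(pipe)
```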
@@ -14,10 +14,6 @@ from cache_dit.cache_factory.cache_contexts import CachedContextManager
 from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
 from cache_dit.cache_factory.cache_contexts import CalibratorConfig
 from cache_dit.cache_factory.cache_blocks import CachedBlocks
-from cache_dit.cache_factory.cache_blocks import (
-    patch_cached_stats,
-    remove_cached_stats,
-)
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)
@@ -167,7 +163,7 @@ class CachedAdapter:
         cls,
         block_adapter: BlockAdapter,
         **cache_context_kwargs,
-    ) -> DiffusionPipeline:
+    ) -> Tuple[List[str], List[Dict[str, Any]]]:

         BlockAdapter.assert_normalized(block_adapter)

@@ -221,7 +217,7 @@ class CachedAdapter:

         cls.apply_params_hooks(block_adapter, contexts_kwargs)

-        return block_adapter.pipe
+        return flatten_contexts, contexts_kwargs

     @classmethod
     def modify_context_params(
@@ -470,6 +466,10 @@ class CachedAdapter:
         cls,
         block_adapter: BlockAdapter,
     ):
+        from cache_dit.cache_factory.cache_blocks import (
+            patch_cached_stats,
+        )
+
         cache_manager = block_adapter.pipe._cache_manager

         for i in range(len(block_adapter.transformer)):
@@ -557,6 +557,10 @@ class CachedAdapter:
         )

         # release stats hooks
+        from cache_dit.cache_factory.cache_blocks import (
+            remove_cached_stats,
+        )
+
         cls.release_hooks(
             pipe_or_adapter,
             remove_cached_stats,
@@ -1,6 +1,9 @@
 import torch

 from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CacheNotExistError,
+)
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
 )
@@ -16,6 +19,70 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         ForwardPattern.Pattern_5,
     ]

+    def call_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Call all blocks to process the hidden states without cache.
+        new_encoder_hidden_states = None
+        for block in self.transformer_blocks:
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
+            )
+
+        return hidden_states, new_encoder_hidden_states
+
+    @torch.compiler.disable
+    def _process_block_outputs(
+        self, hidden_states: torch.Tensor | tuple
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        # Process the outputs for the block.
+        new_encoder_hidden_states = None
+        if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+            if len(hidden_states) == 2:
+                if isinstance(hidden_states[1], torch.Tensor):
+                    hidden_states, new_encoder_hidden_states = hidden_states
+                    if not self.forward_pattern.Return_H_First:
+                        hidden_states, new_encoder_hidden_states = (
+                            new_encoder_hidden_states,
+                            hidden_states,
+                        )
+                elif isinstance(hidden_states[0], torch.Tensor):
+                    hidden_states = hidden_states[0]
+                else:
+                    raise ValueError("Unexpected hidden_states format.")
+            else:
+                assert (
+                    len(hidden_states) == 1
+                ), f"Unexpected output length: {len(hidden_states)}"
+                hidden_states = hidden_states[0]
+        return hidden_states, new_encoder_hidden_states
+
+    @torch.compiler.disable
+    def _process_forward_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        new_encoder_hidden_states: torch.Tensor | None,
+    ) -> (
+        torch.Tensor
+        | tuple[torch.Tensor, torch.Tensor]
+        | tuple[torch.Tensor, None]
+    ):
+        if self.forward_pattern.Return_H_Only:
+            return hidden_states
+        else:
+            if self.forward_pattern.Return_H_First:
+                return (hidden_states, new_encoder_hidden_states)
+            else:
+                return (new_encoder_hidden_states, hidden_states)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -23,8 +90,19 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         **kwargs,
     ):
         # Use it's own cache context.
-        self.cache_manager.set_context(self.cache_context)
-        self._check_cache_params()
+        try:
+            self.cache_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except CacheNotExistError as e:
+            logger.warning(f"Cache context not exist: {e}, skip cache.")
+            hidden_states, new_encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._process_forward_outputs(
+                hidden_states, new_encoder_hidden_states
+            )

         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
@@ -35,7 +113,9 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             **kwargs,
         )

-        Fn_hidden_states_residual = hidden_states - original_hidden_states
+        Fn_hidden_states_residual = hidden_states - original_hidden_states.to(
+            hidden_states.device
+        )
         del original_hidden_states

         self.cache_manager.mark_step_begin()
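Both residual computations in this class now move the saved activation onto the device of the current one before subtracting, which matters when offloading leaves `original_hidden_states` on a different device. A tiny standalone illustration of the same idiom (hypothetical tensors, not the library's code):

```python
import torch

# The saved reference may sit on CPU (for example after an offload hook moved
# it), while the freshly computed activation lives on the accelerator.
original_hidden_states = torch.randn(2, 16)   # assume CPU
hidden_states = torch.randn(2, 16)
if torch.cuda.is_available():
    hidden_states = hidden_states.cuda()

# Casting first avoids "Expected all tensors to be on the same device".
residual = hidden_states - original_hidden_states.to(hidden_states.device)
```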
@@ -147,14 +227,9 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):

         torch._dynamo.graph_break()

-        return (
-            hidden_states
-            if self.forward_pattern.Return_H_Only
-            else (
-                (hidden_states, new_encoder_hidden_states)
-                if self.forward_pattern.Return_H_First
-                else (new_encoder_hidden_states, hidden_states)
-            )
+        return self._process_forward_outputs(
+            hidden_states,
+            new_encoder_hidden_states,
         )

     def call_Fn_blocks(
@@ -170,13 +245,9 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-                hidden_states, new_encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, new_encoder_hidden_states = (
-                        new_encoder_hidden_states,
-                        hidden_states,
-                    )
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
+            )

         return hidden_states, new_encoder_hidden_states

@@ -194,16 +265,16 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-                hidden_states, new_encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, new_encoder_hidden_states = (
-                        new_encoder_hidden_states,
-                        hidden_states,
-                    )
+
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
+            )
+
         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
-        hidden_states_residual = hidden_states - original_hidden_states
+        hidden_states_residual = hidden_states - original_hidden_states.to(
+            hidden_states.device
+        )

         return (
             hidden_states,
@@ -227,12 +298,9 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
-                hidden_states, new_encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, new_encoder_hidden_states = (
-                        new_encoder_hidden_states,
-                        hidden_states,
-                    )
+
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
+            )

         return hidden_states, new_encoder_hidden_states
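The repeated inline unpacking in the block-calling helpers is now centralized in `_process_block_outputs` and `_process_forward_outputs`. A standalone sketch of the same normalization idea (hypothetical helper, not the library's method):

```python
import torch

def normalize_block_output(out, return_h_first: bool = True):
    """Hypothetical helper mirroring _process_block_outputs: a block may return
    a bare tensor, (hidden, encoder_hidden), the reversed order, or a 1-tuple."""
    encoder_hidden = None
    if not isinstance(out, torch.Tensor):
        if len(out) == 2 and isinstance(out[1], torch.Tensor):
            hidden, encoder_hidden = out if return_h_first else (out[1], out[0])
        else:
            hidden = out[0]
    else:
        hidden = out
    return hidden, encoder_hidden

# A block that returns (encoder_hidden_states, hidden_states):
h, e = normalize_block_output(
    (torch.zeros(1, 8), torch.ones(1, 8)), return_h_first=False
)
assert h.sum() == 8 and e.sum() == 0
```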
@@ -1,12 +1,11 @@
 import inspect
-import asyncio
 import torch
 import torch.distributed as dist

-from typing import List
 from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
     CachedContextManager,
+    CacheNotExistError,
 )
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.logger import init_logger
@@ -47,7 +46,6 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         self.cache_prefix = cache_prefix
         self.cache_context = cache_context
         self.cache_manager = cache_manager
-        self.pending_tasks: List[asyncio.Task] = []

         self._check_forward_pattern()
         logger.info(
@@ -111,6 +109,62 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 f"the number of transformer blocks {len(self.transformer_blocks)}"
             )

+    def call_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Call all blocks to process the hidden states without cache.
+        for block in self.transformer_blocks:
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.forward_pattern.Return_H_First:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+
+        return hidden_states, encoder_hidden_states
+
+    @torch.compiler.disable
+    def _process_block_outputs(
+        self,
+        hidden_states: torch.Tensor | tuple,
+        encoder_hidden_states: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        if not isinstance(hidden_states, torch.Tensor):
+            hidden_states, encoder_hidden_states = hidden_states
+            if not self.forward_pattern.Return_H_First:
+                hidden_states, encoder_hidden_states = (
+                    encoder_hidden_states,
+                    hidden_states,
+                )
+        return hidden_states, encoder_hidden_states
+
+    @torch.compiler.disable
+    def _process_forward_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None] | torch.Tensor:
+        return (
+            hidden_states
+            if self.forward_pattern.Return_H_Only
+            else (
+                (hidden_states, encoder_hidden_states)
+                if self.forward_pattern.Return_H_First
+                else (encoder_hidden_states, hidden_states)
+            )
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -119,8 +173,22 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         **kwargs,
     ):
         # Use it's own cache context.
-        self.cache_manager.set_context(self.cache_context)
-        self._check_cache_params()
+        try:
+            self.cache_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except CacheNotExistError as e:
+            logger.warning(f"Cache context not exist: {e}, skip cache.")
+            # Call all blocks to process the hidden states.
+            hidden_states, encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._process_forward_outputs(
+                hidden_states,
+                encoder_hidden_states,
+            )

         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
@@ -239,14 +307,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # patch cached stats for blocks or remove it.
         torch._dynamo.graph_break()

-        return (
-            hidden_states
-            if self.forward_pattern.Return_H_Only
-            else (
-                (hidden_states, encoder_hidden_states)
-                if self.forward_pattern.Return_H_First
-                else (encoder_hidden_states, hidden_states)
-            )
+        return self._process_forward_outputs(
+            hidden_states,
+            encoder_hidden_states,
         )

     @torch.compiler.disable
@@ -322,13 +385,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 *args,
                 **kwargs,
             )
-            if not isinstance(hidden_states, torch.Tensor):
-                hidden_states, encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
-                        hidden_states,
-                    )
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
+                hidden_states, encoder_hidden_states
+            )

         return hidden_states, encoder_hidden_states

@@ -348,13 +407,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 *args,
                 **kwargs,
             )
-            if not isinstance(hidden_states, torch.Tensor):
-                hidden_states, encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
-                        hidden_states,
-                    )
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
+                hidden_states, encoder_hidden_states
+            )

         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
@@ -396,12 +451,8 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 *args,
                 **kwargs,
            )
-            if not isinstance(hidden_states, torch.Tensor):
-                hidden_states, encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
-                        hidden_states,
-                    )
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
+                hidden_states, encoder_hidden_states
+            )

         return hidden_states, encoder_hidden_states
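The net effect of the `try/except CacheNotExistError` guard in both pattern classes is that a forward pass without a registered cache context now degrades to a plain, uncached pass through all blocks instead of raising. A simplified, self-contained sketch of that control flow (stand-in names, not the library's classes):

```python
import torch

class CacheNotExistError(Exception):
    """Stand-in for cache_dit's CacheNotExistError."""

def forward_with_fallback(blocks, contexts, name, hidden_states):
    try:
        if name not in contexts:
            raise CacheNotExistError(name)
        context = contexts[name]  # the cached path would use this
    except CacheNotExistError:
        # Fallback: behave exactly like an uncached model.
        for block in blocks:
            hidden_states = block(hidden_states)
        return hidden_states
    # ... a real implementation continues with Fn/Mn/Bn cached blocks here ...
    for block in blocks:
        hidden_states = block(hidden_states)
    return hidden_states

blocks = [torch.nn.Linear(8, 8) for _ in range(3)]
out = forward_with_fallback(
    blocks, contexts={}, name="missing", hidden_states=torch.randn(1, 8)
)
```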
@@ -11,4 +11,5 @@ from cache_dit.cache_factory.cache_contexts.cache_context import (
 )
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
     CachedContextManager,
+    CacheNotExistError,
 )
@@ -38,6 +38,10 @@ class BasicCacheConfig:
     # DBCache does not apply the caching strategy when the number of running steps is less than
    # or equal to this value, ensuring the model sufficiently learns basic features during warmup.
     max_warmup_steps: int = 8  # DON'T Cache in warmup steps
+    # warmup_interval (`int`, *required*, defaults to 1):
+    # Skip interval in warmup steps, e.g., when warmup_interval is 2, only 0, 2, 4, ... steps
+    # in warmup steps will be computed, others will use dynamic cache.
+    warmup_interval: int = 1  # skip interval in warmup steps
     # max_cached_steps (`int`, *required*, defaults to -1):
     # DBCache disables the caching strategy when the previous cached steps exceed this value to
     # prevent precision degradation.
@@ -71,6 +75,7 @@ class BasicCacheConfig:
             f"DBCACHE_F{self.Fn_compute_blocks}"
             f"B{self.Bn_compute_blocks}_"
             f"W{self.max_warmup_steps}"
+            f"I{self.warmup_interval}"
             f"M{max(0, self.max_cached_steps)}"
             f"MC{max(0, self.max_continuous_cached_steps)}_"
             f"R{self.residual_diff_threshold}"
@@ -346,5 +351,15 @@ class CachedContext:
         # CFG steps: 1, 3, 5, 7, ...
         return self.get_current_transformer_step() % 2 != 0

+    @property
+    def warmup_steps(self) -> List[int]:
+        return list(
+            range(
+                0,
+                self.cache_config.max_warmup_steps,
+                self.cache_config.warmup_interval,
+            )
+        )
+
     def is_in_warmup(self):
-        return self.get_current_step() < self.cache_config.max_warmup_steps
+        return self.get_current_step() in self.warmup_steps
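The effect of `warmup_interval` is easiest to see by enumerating which steps are still fully computed during warmup; it is simply `range(0, max_warmup_steps, warmup_interval)`, as in the property above:

```python
# With max_warmup_steps=8 and warmup_interval=2, only steps 0, 2, 4 and 6 are
# forced to compute during warmup; steps 1, 3, 5 and 7 may already be cached.
max_warmup_steps, warmup_interval = 8, 2
warmup_steps = list(range(0, max_warmup_steps, warmup_interval))
print(warmup_steps)       # [0, 2, 4, 6]
print(3 in warmup_steps)  # False -> step 3 is eligible for dynamic caching
```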
@@ -14,6 +14,10 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)


+class CacheNotExistError(Exception):
+    pass
+
+
 class CachedContextManager:
     # Each Pipeline should have it's own context manager instance.

@@ -27,16 +31,19 @@ class CachedContextManager:
             self._cached_context_manager[_context.name] = _context
         return _context

-    def set_context(self, cached_context: CachedContext | str):
+    def set_context(self, cached_context: CachedContext | str) -> CachedContext:
         if isinstance(cached_context, CachedContext):
             self._current_context = cached_context
         else:
+            if cached_context not in self._cached_context_manager:
+                raise CacheNotExistError("Context not exist!")
             self._current_context = self._cached_context_manager[cached_context]
+        return self._current_context

     def get_context(self, name: str = None) -> CachedContext:
         if name is not None:
             if name not in self._cached_context_manager:
-                raise ValueError("Context not exist!")
+                raise CacheNotExistError("Context not exist!")
             return self._cached_context_manager[name]
         return self._current_context

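Because a missing context now raises the dedicated `CacheNotExistError` instead of a bare `ValueError`, callers can tell "no cache registered under this name" apart from genuine misuse. A hedged usage sketch, assuming `pipe` already has caching enabled so that the `_cache_manager` attribute seen in `CachedAdapter` above exists (the context name is illustrative):

```python
from cache_dit.cache_factory.cache_contexts.cache_manager import CacheNotExistError

manager = pipe._cache_manager  # attached when caching is enabled

try:
    ctx = manager.set_context("transformer_0")  # hypothetical context name
except CacheNotExistError:
    # No context was created under this name; fall back to uncached behavior
    # rather than failing the whole pipeline.
    ctx = None
```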
@@ -38,23 +38,43 @@ def enable_cache(
     BlockAdapter,
 ]:
     r"""
-    Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
-    that match the specific Input and Output patterns).
-
-    For a good balance between performance and precision, DBCache is configured by default
-    with F8B0, 8 warmup steps, and unlimited cached steps.
+    The `enable_cache` function serves as a unified caching interface designed to optimize the performance
+    of diffusion transformer models by implementing an intelligent caching mechanism known as `DBCache`.
+    This API is engineered to be compatible with nearly `all` diffusion transformer architectures that
+    feature transformer blocks adhering to standard input-output patterns, eliminating the need for
+    architecture-specific modifications.
+
+    By strategically caching intermediate outputs of transformer blocks during the diffusion process,
+    `DBCache` significantly reduces redundant computations without compromising generation quality.
+    The caching mechanism works by tracking residual differences between consecutive steps, allowing
+    the model to reuse previously computed features when these differences fall below a configurable
+    threshold. This approach maintains a balance between computational efficiency and output precision.
+
+    The default configuration (`F8B0, 8 warmup steps, unlimited cached steps`) is carefully tuned to
+    provide an optimal tradeoff for most common use cases. The "F8B0" configuration indicates that
+    the first 8 transformer blocks are used to compute stable feature differences, while no final
+    blocks are employed for additional fusion. The warmup phase ensures the model establishes
+    sufficient feature representation before caching begins, preventing potential degradation of
+    output quality.
+
+    This function seamlessly integrates with both standard diffusion pipelines and custom block
+    adapters, making it versatile for various deployment scenarios—from research prototyping to
+    production environments where inference speed is critical. By abstracting the complexity of
+    caching logic behind a simple interface, it enables developers to enhance model performance
+    with minimal code changes.

     Args:
         pipe_or_adapter (`DiffusionPipeline` or `BlockAdapter`, *required*):
             The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
             For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
             for the usgae of BlockAdapter.
+
         cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
             Basic DBCache config for cache context, defaults to BasicCacheConfig(). The configurable params listed belows:
             Fn_compute_blocks: (`int`, *required*, defaults to 8):
-                Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
-                at time step t, enabling the calculation of a more stable L1 diff and delivering more
-                accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
+                Specifies that `DBCache` uses the**first n**Transformer blocks to fit the information at time step t,
+                enabling the calculation of a more stable L1 difference and delivering more accurate information
+                to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
                 for more details of DBCache.
             Bn_compute_blocks: (`int`, *required*, defaults to 0):
                 Further fuses approximate information in the **last n** Transformer blocks to enhance
@@ -66,6 +86,9 @@ def enable_cache(
             max_warmup_steps (`int`, *required*, defaults to 8):
                 DBCache does not apply the caching strategy when the number of running steps is less than
                 or equal to this value, ensuring the model sufficiently learns basic features during warmup.
+            warmup_interval (`int`, *required*, defaults to 1):
+                Skip interval in warmup steps, e.g., when warmup_interval is 2, only 0, 2, 4, ... steps
+                in warmup steps will be computed, others will use dynamic cache.
             max_cached_steps (`int`, *required*, defaults to -1):
                 DBCache disables the caching strategy when the previous cached steps exceed this value to
                 prevent precision degradation.
@@ -77,14 +100,18 @@ def enable_cache(
                 and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
                 CogVideoX, HunyuanVideo, Mochi, etc.
             cfg_compute_first (`bool`, *required*, defaults to False):
-                Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
+                Whether to compute cfg forward first, default is False, meaning:
+                0, 2, 4, ..., -> non-CFG step;
                 1, 3, 5, ... -> CFG step.
             cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-                Compute separate diff values for CFG and non-CFG step, default True. If False, we will
-                use the computed diff from current non-CFG transformer step for current CFG step.
+                Whether to compute separate difference values for CFG and non-CFG steps, default is True.
+                If False, we will use the computed difference from the current non-CFG transformer step
+                for the current CFG step.
+
         calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
-            Config for calibrator, if calibrator_config is not None, means that user want to use DBCache
-            with specific calibrator, such as taylorseer, foca, and so on.
+            Config for calibrator. If calibrator_config is not None, it means the user wants to use DBCache
+            with a specific calibrator, such as taylorseer, foca, and so on.
+
         params_modifiers ('ParamsModifier', *optional*, defaults to None):
             Modify cache context params for specific blocks. The configurable params listed belows:
                 cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
@@ -93,6 +120,7 @@ def enable_cache(
                     The same as 'calibrator_config' param in cache_dit.enable_cache() interface.
                 **kwargs: (`dict`, *optional*, defaults to {}):
                     The same as 'kwargs' param in cache_dit.enable_cache() interface.
+
         kwargs (`dict`, *optional*, defaults to {})
             Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
             for more details.
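Taken together, the docstring changes describe the typical call pattern. A hedged usage sketch built only from parameters named above (the model, dtype, and exact values are illustrative, not prescribed by this diff):

```python
import torch
import cache_dit
from diffusers import FluxPipeline
from cache_dit.cache_factory.cache_contexts import BasicCacheConfig

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# DBCache with the documented F8B0 defaults, but only every other warmup step
# is forced to compute thanks to the new `warmup_interval` knob.
cache_dit.enable_cache(
    pipe,
    cache_config=BasicCacheConfig(
        Fn_compute_blocks=8,
        Bn_compute_blocks=0,
        max_warmup_steps=8,
        warmup_interval=2,
        max_cached_steps=-1,
    ),
)

image = pipe("a cinematic photo of a lighthouse at dusk").images[0]
```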
@@ -10,3 +10,6 @@ from cache_dit.cache_factory.patch_functors.functor_hidream import (
 from cache_dit.cache_factory.patch_functors.functor_hunyuan_dit import (
     HunyuanDiTPatchFunctor,
 )
+from cache_dit.cache_factory.patch_functors.functor_qwen_image_controlnet import (
+    QwenImageControlNetPatchFunctor,
+)