cache-dit 1.0.0__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of cache-dit has been flagged as potentially problematic by the registry.

cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '1.0.0'
-__version_tuple__ = version_tuple = (1, 0, 0)
+__version__ = version = '1.0.2'
+__version_tuple__ = version_tuple = (1, 0, 2)

 __commit_id__ = commit_id = None
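The only change in this file is the version bump from 1.0.0 to 1.0.2. As a quick sanity check after upgrading, the installed release can be read back with the standard library (generic packaging usage, not something added by this diff):

    # Confirm the installed cache-dit release; expected output after this upgrade is "1.0.2".
    from importlib.metadata import version

    print(version("cache-dit"))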
@@ -143,14 +143,30 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import QwenImageTransformer2DModel

     assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-    return BlockAdapter(
-        pipe=pipe,
-        transformer=pipe.transformer,
-        blocks=pipe.transformer.transformer_blocks,
-        forward_pattern=ForwardPattern.Pattern_1,
-        has_separate_cfg=True,
-        **kwargs,
-    )
+
+    pipe_cls_name: str = pipe.__class__.__name__
+    if pipe_cls_name.startswith("QwenImageControlNet"):
+        from cache_dit.cache_factory.patch_functors import (
+            QwenImageControlNetPatchFunctor,
+        )
+
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=pipe.transformer.transformer_blocks,
+            forward_pattern=ForwardPattern.Pattern_1,
+            patch_functor=QwenImageControlNetPatchFunctor(),
+            has_separate_cfg=True,
+        )
+    else:
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=pipe.transformer.transformer_blocks,
+            forward_pattern=ForwardPattern.Pattern_1,
+            has_separate_cfg=True,
+            **kwargs,
+        )


 @BlockAdapterRegistry.register("LTX")
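For orientation, the new branch means ControlNet variants of QwenImage pipelines are now routed through a dedicated patch functor automatically. A hedged usage sketch follows; it assumes a diffusers build that ships QwenImageControlNetPipeline, and the checkpoint id is a placeholder, not a value taken from this diff:

    # Hedged sketch: enabling cache-dit on a QwenImage ControlNet pipeline.
    # Assumes diffusers provides QwenImageControlNetPipeline; the checkpoint id
    # below is a placeholder for illustration only.
    import cache_dit
    from diffusers import QwenImageControlNetPipeline

    pipe = QwenImageControlNetPipeline.from_pretrained("<qwen-image-controlnet-checkpoint>")

    # The registered "QwenImage" adapter inspects pipe.__class__.__name__; because
    # it starts with "QwenImageControlNet", the adapter attaches
    # QwenImageControlNetPatchFunctor before enabling caching.
    cache_dit.enable_cache(pipe)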
@@ -14,10 +14,6 @@ from cache_dit.cache_factory.cache_contexts import CachedContextManager
 from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
 from cache_dit.cache_factory.cache_contexts import CalibratorConfig
 from cache_dit.cache_factory.cache_blocks import CachedBlocks
-from cache_dit.cache_factory.cache_blocks import (
-    patch_cached_stats,
-    remove_cached_stats,
-)
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)
@@ -167,7 +163,7 @@ class CachedAdapter:
         cls,
         block_adapter: BlockAdapter,
         **cache_context_kwargs,
-    ) -> DiffusionPipeline:
+    ) -> Tuple[List[str], List[Dict[str, Any]]]:

         BlockAdapter.assert_normalized(block_adapter)

@@ -221,7 +217,7 @@ class CachedAdapter:

         cls.apply_params_hooks(block_adapter, contexts_kwargs)

-        return block_adapter.pipe
+        return flatten_contexts, contexts_kwargs

     @classmethod
     def modify_context_params(
@@ -470,6 +466,10 @@ class CachedAdapter:
         cls,
         block_adapter: BlockAdapter,
     ):
+        from cache_dit.cache_factory.cache_blocks import (
+            patch_cached_stats,
+        )
+
         cache_manager = block_adapter.pipe._cache_manager

         for i in range(len(block_adapter.transformer)):
@@ -557,6 +557,10 @@ class CachedAdapter:
         )

         # release stats hooks
+        from cache_dit.cache_factory.cache_blocks import (
+            remove_cached_stats,
+        )
+
         cls.release_hooks(
             pipe_or_adapter,
             remove_cached_stats,
@@ -1,6 +1,9 @@
 import torch

 from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CacheNotExistError,
+)
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
 )
@@ -16,6 +19,70 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         ForwardPattern.Pattern_5,
     ]

+    def call_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Call all blocks to process the hidden states without cache.
+        new_encoder_hidden_states = None
+        for block in self.transformer_blocks:
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, new_encoder_hidden_states = self._process_outputs(
+                hidden_states
+            )
+
+        return hidden_states, new_encoder_hidden_states
+
+    @torch.compiler.disable
+    def _process_outputs(
+        self, hidden_states: torch.Tensor | tuple
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        # Process the outputs for the block.
+        new_encoder_hidden_states = None
+        if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+            if len(hidden_states) == 2:
+                if isinstance(hidden_states[1], torch.Tensor):
+                    hidden_states, new_encoder_hidden_states = hidden_states
+                    if not self.forward_pattern.Return_H_First:
+                        hidden_states, new_encoder_hidden_states = (
+                            new_encoder_hidden_states,
+                            hidden_states,
+                        )
+                elif isinstance(hidden_states[0], torch.Tensor):
+                    hidden_states = hidden_states[0]
+                else:
+                    raise ValueError("Unexpected hidden_states format.")
+            else:
+                assert (
+                    len(hidden_states) == 1
+                ), f"Unexpected output length: {len(hidden_states)}"
+                hidden_states = hidden_states[0]
+        return hidden_states, new_encoder_hidden_states
+
+    @torch.compiler.disable
+    def _forward_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        new_encoder_hidden_states: torch.Tensor | None,
+    ) -> (
+        torch.Tensor
+        | tuple[torch.Tensor, torch.Tensor]
+        | tuple[torch.Tensor, None]
+    ):
+        if self.forward_pattern.Return_H_Only:
+            return hidden_states
+        else:
+            if self.forward_pattern.Return_H_First:
+                return (hidden_states, new_encoder_hidden_states)
+            else:
+                return (new_encoder_hidden_states, hidden_states)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -23,8 +90,19 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         **kwargs,
     ):
         # Use it's own cache context.
-        self.cache_manager.set_context(self.cache_context)
-        self._check_cache_params()
+        try:
+            self.cache_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except CacheNotExistError as e:
+            logger.warning(f"Cache context not exist: {e}, skip cache.")
+            hidden_states, new_encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._forward_outputs(
+                hidden_states, new_encoder_hidden_states
+            )

         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
@@ -35,7 +113,9 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             **kwargs,
         )

-        Fn_hidden_states_residual = hidden_states - original_hidden_states
+        Fn_hidden_states_residual = hidden_states - original_hidden_states.to(
+            hidden_states.device
+        )
         del original_hidden_states

         self.cache_manager.mark_step_begin()
@@ -147,15 +227,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):

         torch._dynamo.graph_break()

-        return (
-            hidden_states
-            if self.forward_pattern.Return_H_Only
-            else (
-                (hidden_states, new_encoder_hidden_states)
-                if self.forward_pattern.Return_H_First
-                else (new_encoder_hidden_states, hidden_states)
-            )
-        )
+        return self._forward_outputs(hidden_states, new_encoder_hidden_states)

     def call_Fn_blocks(
         self,
@@ -170,13 +242,9 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-                hidden_states, new_encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, new_encoder_hidden_states = (
-                        new_encoder_hidden_states,
-                        hidden_states,
-                    )
+            hidden_states, new_encoder_hidden_states = self._process_outputs(
+                hidden_states
+            )

         return hidden_states, new_encoder_hidden_states

@@ -194,16 +262,16 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-                hidden_states, new_encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, new_encoder_hidden_states = (
-                        new_encoder_hidden_states,
-                        hidden_states,
-                    )
+
+        hidden_states, new_encoder_hidden_states = self._process_outputs(
+            hidden_states
+        )
+
         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
-        hidden_states_residual = hidden_states - original_hidden_states
+        hidden_states_residual = hidden_states - original_hidden_states.to(
+            hidden_states.device
+        )

         return (
             hidden_states,
@@ -227,12 +295,9 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
-                hidden_states, new_encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, new_encoder_hidden_states = (
-                        new_encoder_hidden_states,
-                        hidden_states,
-                    )
+
+        hidden_states, new_encoder_hidden_states = self._process_outputs(
+            hidden_states
+        )

         return hidden_states, new_encoder_hidden_states
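The hunks above replace the tuple handling that used to be repeated after every block call with the shared _process_outputs and _forward_outputs helpers, and make the residual computation device-aligned. For readers skimming the diff, here is a standalone sketch of the normalization idea; the function and flag names are simplified stand-ins, not the library classes:

    # Standalone sketch of the output-normalization idea behind _process_outputs /
    # _forward_outputs; return_h_first / return_h_only stand in for the
    # ForwardPattern flags and are illustrative, not the library API.
    import torch

    def process_outputs(out, return_h_first=True):
        # Blocks may return a bare tensor or a (hidden, encoder_hidden) pair in
        # either order; normalize to (hidden_states, encoder_hidden_states).
        if isinstance(out, torch.Tensor):
            return out, None
        hidden, encoder = out if return_h_first else (out[1], out[0])
        return hidden, encoder

    def forward_outputs(hidden, encoder, return_h_only=False, return_h_first=True):
        # Re-pack the pair in whatever shape the surrounding transformer expects.
        if return_h_only:
            return hidden
        return (hidden, encoder) if return_h_first else (encoder, hidden)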
@@ -1,12 +1,11 @@
 import inspect
-import asyncio
 import torch
 import torch.distributed as dist

-from typing import List
 from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
     CachedContextManager,
+    CacheNotExistError,
 )
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.logger import init_logger
@@ -47,7 +46,6 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         self.cache_prefix = cache_prefix
         self.cache_context = cache_context
         self.cache_manager = cache_manager
-        self.pending_tasks: List[asyncio.Task] = []

         self._check_forward_pattern()
         logger.info(
@@ -111,6 +109,62 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )

+    def call_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Call all blocks to process the hidden states without cache.
+        for block in self.transformer_blocks:
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.forward_pattern.Return_H_First:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+
+        return hidden_states, encoder_hidden_states
+
+    @torch.compiler.disable
+    def _process_outputs(
+        self,
+        hidden_states: torch.Tensor | tuple,
+        encoder_hidden_states: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        if not isinstance(hidden_states, torch.Tensor):
+            hidden_states, encoder_hidden_states = hidden_states
+            if not self.forward_pattern.Return_H_First:
+                hidden_states, encoder_hidden_states = (
+                    encoder_hidden_states,
+                    hidden_states,
+                )
+        return hidden_states, encoder_hidden_states
+
+    @torch.compiler.disable
+    def _forward_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None] | torch.Tensor:
+        return (
+            hidden_states
+            if self.forward_pattern.Return_H_Only
+            else (
+                (hidden_states, encoder_hidden_states)
+                if self.forward_pattern.Return_H_First
+                else (encoder_hidden_states, hidden_states)
+            )
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -119,8 +173,19 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         **kwargs,
     ):
         # Use it's own cache context.
-        self.cache_manager.set_context(self.cache_context)
-        self._check_cache_params()
+        try:
+            self.cache_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except CacheNotExistError as e:
+            logger.warning(f"Cache context not exist: {e}, skip cache.")
+            # Call all blocks to process the hidden states.
+            hidden_states, encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._forward_outputs(hidden_states, encoder_hidden_states)

         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
@@ -239,15 +304,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # patch cached stats for blocks or remove it.
         torch._dynamo.graph_break()

-        return (
-            hidden_states
-            if self.forward_pattern.Return_H_Only
-            else (
-                (hidden_states, encoder_hidden_states)
-                if self.forward_pattern.Return_H_First
-                else (encoder_hidden_states, hidden_states)
-            )
-        )
+        return self._forward_outputs(hidden_states, encoder_hidden_states)

     @torch.compiler.disable
     def _is_parallelized(self):
@@ -322,13 +379,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 *args,
                 **kwargs,
             )
-            if not isinstance(hidden_states, torch.Tensor):
-                hidden_states, encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
-                        hidden_states,
-                    )
+            hidden_states, encoder_hidden_states = self._process_outputs(
+                hidden_states, encoder_hidden_states
+            )

         return hidden_states, encoder_hidden_states

@@ -348,13 +401,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 *args,
                 **kwargs,
             )
-            if not isinstance(hidden_states, torch.Tensor):
-                hidden_states, encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
-                        hidden_states,
-                    )
+            hidden_states, encoder_hidden_states = self._process_outputs(
+                hidden_states, encoder_hidden_states
+            )

         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
@@ -396,12 +445,8 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 *args,
                 **kwargs,
             )
-            if not isinstance(hidden_states, torch.Tensor):
-                hidden_states, encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
-                        hidden_states,
-                    )
+            hidden_states, encoder_hidden_states = self._process_outputs(
+                hidden_states, encoder_hidden_states
+            )

         return hidden_states, encoder_hidden_states
@@ -11,4 +11,5 @@ from cache_dit.cache_factory.cache_contexts.cache_context import (
 )
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
     CachedContextManager,
+    CacheNotExistError,
 )
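CacheNotExistError is now re-exported from cache_dit.cache_factory.cache_contexts. Together with the set_context change in the cache_manager.py hunk below, this lets callers fall back to a plain, uncached forward pass when no cache context has been registered, which is exactly what the cached-block forward methods above do. A minimal sketch of that calling pattern, with a hypothetical helper:

    # Illustrative sketch of the fallback enabled by CacheNotExistError; the
    # run_with_optional_cache helper and its arguments are hypothetical, and
    # `manager` is assumed to be an existing CachedContextManager instance.
    from cache_dit.cache_factory.cache_contexts import CacheNotExistError

    def run_with_optional_cache(manager, context_name, cached_fn, uncached_fn):
        try:
            # Raises CacheNotExistError when no context was registered under this name.
            manager.set_context(context_name)
        except CacheNotExistError:
            # No cache context: run the plain, uncached path instead of erroring out.
            return uncached_fn()
        return cached_fn()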
@@ -14,6 +14,10 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)


+class CacheNotExistError(Exception):
+    pass
+
+
 class CachedContextManager:
     # Each Pipeline should have it's own context manager instance.

@@ -27,16 +31,19 @@ class CachedContextManager:
             self._cached_context_manager[_context.name] = _context
         return _context

-    def set_context(self, cached_context: CachedContext | str):
+    def set_context(self, cached_context: CachedContext | str) -> CachedContext:
         if isinstance(cached_context, CachedContext):
             self._current_context = cached_context
         else:
+            if cached_context not in self._cached_context_manager:
+                raise CacheNotExistError("Context not exist!")
             self._current_context = self._cached_context_manager[cached_context]
+        return self._current_context

     def get_context(self, name: str = None) -> CachedContext:
         if name is not None:
             if name not in self._cached_context_manager:
-                raise ValueError("Context not exist!")
+                raise CacheNotExistError("Context not exist!")
             return self._cached_context_manager[name]
         return self._current_context

@@ -38,23 +38,43 @@ def enable_cache(
     BlockAdapter,
 ]:
     r"""
-    Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
-    that match the specific Input and Output patterns).
-
-    For a good balance between performance and precision, DBCache is configured by default
-    with F8B0, 8 warmup steps, and unlimited cached steps.
+    The `enable_cache` function serves as a unified caching interface designed to optimize the performance
+    of diffusion transformer models by implementing an intelligent caching mechanism known as `DBCache`.
+    This API is engineered to be compatible with nearly `all` diffusion transformer architectures that
+    feature transformer blocks adhering to standard input-output patterns, eliminating the need for
+    architecture-specific modifications.
+
+    By strategically caching intermediate outputs of transformer blocks during the diffusion process,
+    `DBCache` significantly reduces redundant computations without compromising generation quality.
+    The caching mechanism works by tracking residual differences between consecutive steps, allowing
+    the model to reuse previously computed features when these differences fall below a configurable
+    threshold. This approach maintains a balance between computational efficiency and output precision.
+
+    The default configuration (`F8B0, 8 warmup steps, unlimited cached steps`) is carefully tuned to
+    provide an optimal tradeoff for most common use cases. The "F8B0" configuration indicates that
+    the first 8 transformer blocks are used to compute stable feature differences, while no final
+    blocks are employed for additional fusion. The warmup phase ensures the model establishes
+    sufficient feature representation before caching begins, preventing potential degradation of
+    output quality.
+
+    This function seamlessly integrates with both standard diffusion pipelines and custom block
+    adapters, making it versatile for various deployment scenarios—from research prototyping to
+    production environments where inference speed is critical. By abstracting the complexity of
+    caching logic behind a simple interface, it enables developers to enhance model performance
+    with minimal code changes.

     Args:
         pipe_or_adapter (`DiffusionPipeline` or `BlockAdapter`, *required*):
             The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
             For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
             for the usgae of BlockAdapter.
+
         cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
             Basic DBCache config for cache context, defaults to BasicCacheConfig(). The configurable params listed belows:
             Fn_compute_blocks: (`int`, *required*, defaults to 8):
-                Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
-                at time step t, enabling the calculation of a more stable L1 diff and delivering more
-                accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
+                Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information at time step t,
+                enabling the calculation of a more stable L1 difference and delivering more accurate information
+                to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
                 for more details of DBCache.
             Bn_compute_blocks: (`int`, *required*, defaults to 0):
                 Further fuses approximate information in the **last n** Transformer blocks to enhance
@@ -77,14 +97,18 @@ def enable_cache(
                 and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
                 CogVideoX, HunyuanVideo, Mochi, etc.
             cfg_compute_first (`bool`, *required*, defaults to False):
-                Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
+                Whether to compute cfg forward first, default is False, meaning:
+                0, 2, 4, ..., -> non-CFG step;
                 1, 3, 5, ... -> CFG step.
             cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-                Compute separate diff values for CFG and non-CFG step, default True. If False, we will
-                use the computed diff from current non-CFG transformer step for current CFG step.
+                Whether to compute separate difference values for CFG and non-CFG steps, default is True.
+                If False, we will use the computed difference from the current non-CFG transformer step
+                for the current CFG step.
+
         calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
-            Config for calibrator, if calibrator_config is not None, means that user want to use DBCache
-            with specific calibrator, such as taylorseer, foca, and so on.
+            Config for calibrator. If calibrator_config is not None, it means the user wants to use DBCache
+            with a specific calibrator, such as taylorseer, foca, and so on.
+
         params_modifiers ('ParamsModifier', *optional*, defaults to None):
             Modify cache context params for specific blocks. The configurable params listed belows:
             cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
@@ -93,6 +117,7 @@ def enable_cache(
                 The same as 'calibrator_config' param in cache_dit.enable_cache() interface.
             **kwargs: (`dict`, *optional*, defaults to {}):
                 The same as 'kwargs' param in cache_dit.enable_cache() interface.
+
         kwargs (`dict`, *optional*, defaults to {})
             Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
            for more details.
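Since the rewritten docstring is mostly prose, a compact usage sketch may be easier to scan. The model id, dtype, and prompt below are placeholders, and the call shown simply enables the default DBCache configuration (F8B0, 8 warmup steps, unlimited cached steps) described above:

    # Hedged usage sketch for cache_dit.enable_cache(); the model id, dtype,
    # prompt, and step count are illustrative placeholders, not values from this diff.
    import torch
    import cache_dit
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Default DBCache behaviour: F8B0, 8 warmup steps, unlimited cached steps.
    cache_dit.enable_cache(pipe)

    image = pipe("a watercolor fox in the snow", num_inference_steps=28).images[0]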
@@ -10,3 +10,6 @@ from cache_dit.cache_factory.patch_functors.functor_hidream import (
 from cache_dit.cache_factory.patch_functors.functor_hunyuan_dit import (
     HunyuanDiTPatchFunctor,
 )
+from cache_dit.cache_factory.patch_functors.functor_qwen_image_controlnet import (
+    QwenImageControlNetPatchFunctor,
+)
@@ -362,9 +362,7 @@ def __patch_transformer_forward__(
         )
         if hidden_states_masks is not None:
             # NOTE: Patched
-            cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[
-                self.double_stream_blocks[-1].block._block_id
-            ]
+            cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[0]
             encoder_attention_mask_ones = torch.ones(
                 (
                     batch_size,