cache-dit 0.2.32__py3-none-any.whl → 0.2.34__py3-none-any.whl

This diff shows the contents of publicly released versions of this package as they appear in the supported public registries. It is provided for informational purposes only and reflects the changes between those published versions.
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.32'
-__version_tuple__ = version_tuple = (0, 2, 32)
+__version__ = version = '0.2.34'
+__version_tuple__ = version_tuple = (0, 2, 34)
 
 __commit_id__ = commit_id = None
@@ -153,7 +153,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("LTXVideo")
+@BlockAdapterRegistry.register("LTX")
 def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import LTXVideoTransformer3DModel
 
@@ -248,7 +248,10 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
         pipe=pipe,
         transformer=pipe.transformer,
         blocks=pipe.transformer.blocks,
-        forward_pattern=ForwardPattern.Pattern_2,
+        # NOTE: Use Pattern_3 instead of Pattern_2 because the
+        # encoder_hidden_states will never change in the blocks
+        # forward loop.
+        forward_pattern=ForwardPattern.Pattern_3,
         has_separate_cfg=True,
         **kwargs,
     )
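
For context on the Pattern_2 → Pattern_3 switch above, a hedged reading based on the cache_blocks code later in this diff: blocks handled by the base patterns are called with both hidden_states and encoder_hidden_states, while Pattern_3/4/5 blocks are called with hidden_states only, which fits a transformer whose encoder_hidden_states never change inside the blocks loop. A minimal, purely illustrative sketch of the two calling conventions (these block classes are hypothetical, not cache-dit or diffusers APIs):

    import torch

    class Pattern2LikeBlock(torch.nn.Module):
        def forward(self, hidden_states, encoder_hidden_states):
            # Both streams are updated and returned on every call.
            return hidden_states + 1.0, encoder_hidden_states + 1.0

    class Pattern3LikeBlock(torch.nn.Module):
        def forward(self, hidden_states):
            # Only hidden_states is threaded through the blocks loop;
            # encoder_hidden_states stays constant, as the NOTE above says.
            return hidden_states + 1.0

    h, e = torch.zeros(2, 4), torch.zeros(2, 4)
    h, e = Pattern2LikeBlock()(h, e)
    h = Pattern3LikeBlock()(h)  # e is left untouched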
@@ -285,6 +288,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("DiT")
 def dit_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import DiTTransformer2DModel
+    from cache_dit.cache_factory.patch_functors import DiTPatchFunctor
 
     assert isinstance(pipe.transformer, DiTTransformer2DModel)
     return BlockAdapter(
@@ -292,6 +296,7 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        patch_functor=DiTPatchFunctor(),
         **kwargs,
     )
 
@@ -331,24 +336,13 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
 
 
 @BlockAdapterRegistry.register("Lumina")
-def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
-    from diffusers import LuminaNextDiT2DModel
-
-    assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
-    return BlockAdapter(
-        pipe=pipe,
-        transformer=pipe.transformer,
-        blocks=pipe.transformer.layers,
-        forward_pattern=ForwardPattern.Pattern_3,
-        **kwargs,
-    )
-
-
-@BlockAdapterRegistry.register("Lumina2")
 def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import Lumina2Transformer2DModel
+    from diffusers import LuminaNextDiT2DModel
 
-    assert isinstance(pipe.transformer, Lumina2Transformer2DModel)
+    assert isinstance(
+        pipe.transformer, (Lumina2Transformer2DModel, LuminaNextDiT2DModel)
+    )
     return BlockAdapter(
         pipe=pipe,
         transformer=pipe.transformer,
@@ -386,12 +380,10 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("Sana", supported=False)
+@BlockAdapterRegistry.register("Sana")
 def sana_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import SanaTransformer2DModel
 
-    # TODO: fix -> got multiple values for argument 'encoder_hidden_states'
-
     assert isinstance(pipe.transformer, SanaTransformer2DModel)
     return BlockAdapter(
         pipe=pipe,
@@ -469,6 +461,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("Chroma")
 def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import ChromaTransformer2DModel
+    from cache_dit.cache_factory.patch_functors import ChromaPatchFunctor
 
     assert isinstance(pipe.transformer, ChromaTransformer2DModel)
     return BlockAdapter(
@@ -482,6 +475,7 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_1,
             ForwardPattern.Pattern_3,
         ],
+        patch_functor=ChromaPatchFunctor(),
         has_separate_cfg=True,
         **kwargs,
     )
@@ -16,8 +16,52 @@ logger = init_logger(__name__)
 
 
 class ParamsModifier:
-    def __init__(self, **kwargs):
-        self._context_kwargs = kwargs.copy()
+    def __init__(
+        self,
+        # Cache context kwargs
+        Fn_compute_blocks: Optional[int] = None,
+        Bn_compute_blocks: Optional[int] = None,
+        max_warmup_steps: Optional[int] = None,
+        max_cached_steps: Optional[int] = None,
+        max_continuous_cached_steps: Optional[int] = None,
+        residual_diff_threshold: Optional[float] = None,
+        # Cache CFG or not
+        enable_separate_cfg: Optional[bool] = None,
+        cfg_compute_first: Optional[bool] = None,
+        cfg_diff_compute_separate: Optional[bool] = None,
+        # Hybird TaylorSeer
+        enable_taylorseer: Optional[bool] = None,
+        enable_encoder_taylorseer: Optional[bool] = None,
+        taylorseer_cache_type: Optional[str] = None,
+        taylorseer_order: Optional[int] = None,
+        **other_cache_context_kwargs,
+    ):
+        self._context_kwargs = other_cache_context_kwargs.copy()
+        self._maybe_update_param("Fn_compute_blocks", Fn_compute_blocks)
+        self._maybe_update_param("Bn_compute_blocks", Bn_compute_blocks)
+        self._maybe_update_param("max_warmup_steps", max_warmup_steps)
+        self._maybe_update_param("max_cached_steps", max_cached_steps)
+        self._maybe_update_param(
+            "max_continuous_cached_steps", max_continuous_cached_steps
+        )
+        self._maybe_update_param(
+            "residual_diff_threshold", residual_diff_threshold
+        )
+        self._maybe_update_param("enable_separate_cfg", enable_separate_cfg)
+        self._maybe_update_param("cfg_compute_first", cfg_compute_first)
+        self._maybe_update_param(
+            "cfg_diff_compute_separate", cfg_diff_compute_separate
+        )
+        self._maybe_update_param("enable_taylorseer", enable_taylorseer)
+        self._maybe_update_param(
+            "enable_encoder_taylorseer", enable_encoder_taylorseer
+        )
+        self._maybe_update_param("taylorseer_cache_type", taylorseer_cache_type)
+        self._maybe_update_param("taylorseer_order", taylorseer_order)
+
+    def _maybe_update_param(self, key: str, value: Any):
+        if value is not None:
+            self._context_kwargs[key] = value
 
 
 @dataclasses.dataclass
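
A small usage sketch of the new explicit ParamsModifier signature (the keyword names come from the hunk above; the import path and values are illustrative assumptions). Arguments left as None are simply not recorded in the collected context kwargs:

    from cache_dit import ParamsModifier  # import path assumed

    modifier = ParamsModifier(
        Fn_compute_blocks=8,
        residual_diff_threshold=0.08,
        enable_taylorseer=True,
    )
    # Only the explicitly set values end up in the collected kwargs.
    assert modifier._context_kwargs == {
        "Fn_compute_blocks": 8,
        "residual_diff_threshold": 0.08,
        "enable_taylorseer": True,
    }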
@@ -579,7 +623,7 @@ class BlockAdapter:
             assert isinstance(adapter[0], torch.nn.Module)
             return getattr(adapter[0], "_is_cached", False)
         else:
-            raise TypeError(f"Can't check this type: {type(adapter)}!")
+            return getattr(adapter, "_is_cached", False)
 
     @classmethod
     def nested_depth(cls, obj: Any):
@@ -10,13 +10,14 @@ logger = init_logger(__name__)
 
 class BlockAdapterRegistry:
     _adapters: Dict[str, Callable[..., BlockAdapter]] = {}
-    _predefined_adapters_has_spearate_cfg: List[str] = [
+    _predefined_adapters_has_separate_cfg: List[str] = [
         "QwenImage",
         "Wan",
         "CogView4",
         "Cosmos",
        "SkyReelsV2",
         "Chroma",
+        "Lumina2",
     ]
 
     @classmethod
@@ -68,7 +69,7 @@ class BlockAdapterRegistry:
             return True
 
         pipe_cls_name = pipe_or_adapter.__class__.__name__
-        for name in cls._predefined_adapters_has_spearate_cfg:
+        for name in cls._predefined_adapters_has_separate_cfg:
             if pipe_cls_name.startswith(name):
                 return True
 
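The registry change also adds "Lumina2" to the list that drives the default CFG behaviour. A rough illustration of the prefix matching in has_separate_cfg (the pipeline class below is a hypothetical stand-in, and the import path is an assumption):

    from cache_dit.cache_factory import BlockAdapterRegistry  # import path assumed

    class Lumina2Pipeline:  # hypothetical stand-in for the diffusers pipeline class
        pass

    # has_separate_cfg() falls back to comparing the pipeline class name against
    # the _predefined_adapters_has_separate_cfg prefixes, so any class whose name
    # starts with "Lumina2" now defaults to separate-CFG handling.
    print(BlockAdapterRegistry.has_separate_cfg(Lumina2Pipeline()))  # -> True
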
@@ -114,27 +114,27 @@ class CachedAdapter:
         **cache_context_kwargs,
     ):
         # Check cache_context_kwargs
-        if cache_context_kwargs["enable_spearate_cfg"] is None:
+        if cache_context_kwargs["enable_separate_cfg"] is None:
             # Check cfg for some specific case if users don't set it as True
             if BlockAdapterRegistry.has_separate_cfg(block_adapter):
-                cache_context_kwargs["enable_spearate_cfg"] = True
+                cache_context_kwargs["enable_separate_cfg"] = True
                 logger.info(
-                    f"Use custom 'enable_spearate_cfg' from BlockAdapter: True. "
+                    f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
                     f"Pipeline: {block_adapter.pipe.__class__.__name__}."
                 )
             else:
-                cache_context_kwargs["enable_spearate_cfg"] = (
+                cache_context_kwargs["enable_separate_cfg"] = (
                     BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
                 )
                 logger.info(
-                    f"Use default 'enable_spearate_cfg' from block adapter "
-                    f"register: {cache_context_kwargs['enable_spearate_cfg']}, "
+                    f"Use default 'enable_separate_cfg' from block adapter "
+                    f"register: {cache_context_kwargs['enable_separate_cfg']}, "
                     f"Pipeline: {block_adapter.pipe.__class__.__name__}."
                 )
         else:
             logger.info(
-                f"Use custom 'enable_spearate_cfg' from cache context "
-                f"kwargs: {cache_context_kwargs['enable_spearate_cfg']}. "
+                f"Use custom 'enable_separate_cfg' from cache context "
+                f"kwargs: {cache_context_kwargs['enable_separate_cfg']}. "
                 f"Pipeline: {block_adapter.pipe.__class__.__name__}."
             )
 
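Note that this is effectively a rename of a user-facing cache-context kwarg ('enable_spearate_cfg' → 'enable_separate_cfg'), since CachedAdapter now indexes cache_context_kwargs by the corrected spelling. A hedged migration sketch, assuming the top-level cache_dit.enable_cache entry point is where these cache-context kwargs are passed:

    import cache_dit
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(...)  # model id elided

    # 0.2.34: use the corrected keyword.
    cache_dit.enable_cache(pipe, enable_separate_cfg=True)

    # 0.2.32 and earlier spelled the same option 'enable_spearate_cfg'.
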
@@ -1,6 +1,5 @@
 import torch
 
-from typing import Dict, Any
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
@@ -24,14 +23,12 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         **kwargs,
     ):
         # Use it's own cache context.
-        self.cache_manager.set_context(
-            self.cache_context,
-        )
+        self.cache_manager.set_context(self.cache_context)
+        self._check_cache_params()
 
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
-        # encoder_hidden_states: None Pattern 3, else 4, 5
         hidden_states, new_encoder_hidden_states = self.call_Fn_blocks(
             hidden_states,
             *args,
@@ -109,10 +106,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-            if new_encoder_hidden_states is not None:
-                new_encoder_hidden_states_residual = (
-                    new_encoder_hidden_states - old_encoder_hidden_states
-                )
+
             torch._dynamo.graph_break()
             if self.cache_manager.is_cache_residual():
                 self.cache_manager.set_Bn_buffer(
@@ -125,6 +119,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
 
+            if new_encoder_hidden_states is not None:
+                new_encoder_hidden_states_residual = (
+                    new_encoder_hidden_states - old_encoder_hidden_states
+                )
             if self.cache_manager.is_encoder_cache_residual():
                 if new_encoder_hidden_states is not None:
                     self.cache_manager.set_Bn_encoder_buffer(
@@ -159,27 +157,12 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 )
             )
 
-    @torch.compiler.disable
-    def maybe_update_kwargs(
-        self, encoder_hidden_states, kwargs: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        # if "encoder_hidden_states" in kwargs:
-        #     kwargs["encoder_hidden_states"] = encoder_hidden_states
-        #     return kwargs
-        return kwargs
-
     def call_Fn_blocks(
         self,
         hidden_states: torch.Tensor,
         *args,
         **kwargs,
     ):
-        assert self.cache_manager.Fn_compute_blocks() <= len(
-            self.transformer_blocks
-        ), (
-            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
-            f"the number of transformer blocks {len(self.transformer_blocks)}"
-        )
         new_encoder_hidden_states = None
         for block in self._Fn_blocks():
             hidden_states = block(
@@ -194,10 +177,6 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                         new_encoder_hidden_states,
                         hidden_states,
                     )
-            kwargs = self.maybe_update_kwargs(
-                new_encoder_hidden_states,
-                kwargs,
-            )
 
         return hidden_states, new_encoder_hidden_states
 
@@ -222,11 +201,6 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                         new_encoder_hidden_states,
                         hidden_states,
                     )
-            kwargs = self.maybe_update_kwargs(
-                new_encoder_hidden_states,
-                kwargs,
-            )
-
         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
         hidden_states_residual = hidden_states - original_hidden_states
@@ -243,35 +217,22 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-        assert self.cache_manager.Bn_compute_blocks() <= len(
-            self.transformer_blocks
-        ), (
-            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
-            f"the number of transformer blocks {len(self.transformer_blocks)}"
-        )
-        if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
-            raise ValueError(
-                f"Bn_compute_blocks_ids is not support for "
-                f"patterns: {self._supported_patterns}."
+        new_encoder_hidden_states = None
+        if self.cache_manager.Bn_compute_blocks() == 0:
+            return hidden_states, new_encoder_hidden_states
+
+        for block in self._Bn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
             )
-        else:
-            # Compute all Bn blocks if no specific Bn compute blocks ids are set.
-            for block in self._Bn_blocks():
-                hidden_states = block(
-                    hidden_states,
-                    *args,
-                    **kwargs,
-                )
-                if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
-                    hidden_states, new_encoder_hidden_states = hidden_states
-                    if not self.forward_pattern.Return_H_First:
-                        hidden_states, new_encoder_hidden_states = (
-                            new_encoder_hidden_states,
-                            hidden_states,
-                        )
-                kwargs = self.maybe_update_kwargs(
-                    new_encoder_hidden_states,
-                    kwargs,
-                )
+            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
+                hidden_states, new_encoder_hidden_states = hidden_states
+                if not self.forward_pattern.Return_H_First:
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
+                        hidden_states,
+                    )
 
         return hidden_states, new_encoder_hidden_states
@@ -93,6 +93,21 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 required_param in forward_parameters
             ), f"The input parameters must contains: {required_param}."
 
+    @torch.compiler.disable
+    def _check_cache_params(self):
+        assert self.cache_manager.Fn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+        assert self.cache_manager.Bn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -100,7 +115,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
+        # Use it's own cache context.
         self.cache_manager.set_context(self.cache_context)
+        self._check_cache_params()
 
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
@@ -191,18 +208,17 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                # TaylorSeer
                 self.cache_manager.set_Bn_buffer(
                     hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
+
             if self.cache_manager.is_encoder_cache_residual():
                 self.cache_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states_residual,
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                # TaylorSeer
                 self.cache_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
296
312
  *args,
297
313
  **kwargs,
298
314
  ):
299
- assert self.cache_manager.Fn_compute_blocks() <= len(
300
- self.transformer_blocks
301
- ), (
302
- f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
303
- f"the number of transformer blocks {len(self.transformer_blocks)}"
304
- )
305
315
  for block in self._Fn_blocks():
306
316
  hidden_states = block(
307
317
  hidden_states,
@@ -366,28 +376,17 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
             encoder_hidden_states_residual,
         )
 
-    def _compute_or_cache_block(
+    def call_Bn_blocks(
         self,
-        # Block index in the transformer blocks
-        # Bn: 8, block_id should be in [0, 8)
-        block_id: int,
-        # Below are the inputs to the block
-        block,  # The transformer block to be executed
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         *args,
         **kwargs,
     ):
-        # Helper function for `call_Bn_blocks`
-        # Skip the blocks by reuse residual cache if they are not
-        # in the Bn_compute_blocks_ids. NOTE: We should only skip
-        # the specific Bn blocks in cache steps. Compute the block
-        # and cache the residuals in non-cache steps.
-
-        # Normal steps: Compute the block and cache the residuals.
-        if not self._is_in_cache_step():
-            Bn_i_original_hidden_states = hidden_states
-            Bn_i_original_encoder_hidden_states = encoder_hidden_states
+        if self.cache_manager.Bn_compute_blocks() == 0:
+            return hidden_states, encoder_hidden_states
+
+        for block in self._Bn_blocks():
             hidden_states = block(
                 hidden_states,
                 encoder_hidden_states,
401
400
  encoder_hidden_states,
402
401
  hidden_states,
403
402
  )
404
- # Cache residuals for the non-compute Bn blocks for
405
- # subsequent cache steps.
406
- if block_id not in self.cache_manager.Bn_compute_blocks_ids():
407
- Bn_i_hidden_states_residual = (
408
- hidden_states - Bn_i_original_hidden_states
409
- )
410
- if (
411
- encoder_hidden_states is not None
412
- and Bn_i_original_encoder_hidden_states is not None
413
- ):
414
- Bn_i_encoder_hidden_states_residual = (
415
- encoder_hidden_states
416
- - Bn_i_original_encoder_hidden_states
417
- )
418
- else:
419
- Bn_i_encoder_hidden_states_residual = None
420
-
421
- # Save original_hidden_states for diff calculation.
422
- self.cache_manager.set_Bn_buffer(
423
- Bn_i_original_hidden_states,
424
- prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
425
- )
426
- self.cache_manager.set_Bn_encoder_buffer(
427
- Bn_i_original_encoder_hidden_states,
428
- prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
429
- )
430
-
431
- self.cache_manager.set_Bn_buffer(
432
- Bn_i_hidden_states_residual,
433
- prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
434
- )
435
- self.cache_manager.set_Bn_encoder_buffer(
436
- Bn_i_encoder_hidden_states_residual,
437
- prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
438
- )
439
- del Bn_i_hidden_states_residual
440
- del Bn_i_encoder_hidden_states_residual
441
-
442
- del Bn_i_original_hidden_states
443
- del Bn_i_original_encoder_hidden_states
444
-
445
- else:
446
- # Cache steps: Reuse the cached residuals.
447
- # Check if the block is in the Bn_compute_blocks_ids.
448
- if block_id in self.cache_manager.Bn_compute_blocks_ids():
449
- hidden_states = block(
450
- hidden_states,
451
- encoder_hidden_states,
452
- *args,
453
- **kwargs,
454
- )
455
- if not isinstance(hidden_states, torch.Tensor):
456
- hidden_states, encoder_hidden_states = hidden_states
457
- if not self.forward_pattern.Return_H_First:
458
- hidden_states, encoder_hidden_states = (
459
- encoder_hidden_states,
460
- hidden_states,
461
- )
462
- else:
463
- # Skip the block if it is not in the Bn_compute_blocks_ids.
464
- # Use the cached residuals instead.
465
- # Check if can use the cached residuals.
466
- if self.cache_manager.can_cache(
467
- hidden_states, # curr step
468
- parallelized=self._is_parallelized(),
469
- threshold=self.cache_manager.non_compute_blocks_diff_threshold(),
470
- prefix=f"{self.cache_prefix}_Bn_{block_id}_original", # prev step
471
- ):
472
- hidden_states, encoder_hidden_states = (
473
- self.cache_manager.apply_cache(
474
- hidden_states,
475
- encoder_hidden_states,
476
- prefix=(
477
- f"{self.cache_prefix}_Bn_{block_id}_residual"
478
- if self.cache_manager.is_cache_residual()
479
- else f"{self.cache_prefix}_Bn_{block_id}_original"
480
- ),
481
- encoder_prefix=(
482
- f"{self.cache_prefix}_Bn_{block_id}_residual"
483
- if self.cache_manager.is_encoder_cache_residual()
484
- else f"{self.cache_prefix}_Bn_{block_id}_original"
485
- ),
486
- )
487
- )
488
- else:
489
- hidden_states = block(
490
- hidden_states,
491
- encoder_hidden_states,
492
- *args,
493
- **kwargs,
494
- )
495
- if not isinstance(hidden_states, torch.Tensor):
496
- hidden_states, encoder_hidden_states = hidden_states
497
- if not self.forward_pattern.Return_H_First:
498
- hidden_states, encoder_hidden_states = (
499
- encoder_hidden_states,
500
- hidden_states,
501
- )
502
- return hidden_states, encoder_hidden_states
503
-
504
- def call_Bn_blocks(
505
- self,
506
- hidden_states: torch.Tensor,
507
- encoder_hidden_states: torch.Tensor,
508
- *args,
509
- **kwargs,
510
- ):
511
- if self.cache_manager.Bn_compute_blocks() == 0:
512
- return hidden_states, encoder_hidden_states
513
-
514
- assert self.cache_manager.Bn_compute_blocks() <= len(
515
- self.transformer_blocks
516
- ), (
517
- f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
518
- f"the number of transformer blocks {len(self.transformer_blocks)}"
519
- )
520
- if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
521
- for i, block in enumerate(self._Bn_blocks()):
522
- hidden_states, encoder_hidden_states = (
523
- self._compute_or_cache_block(
524
- i,
525
- block,
526
- hidden_states,
527
- encoder_hidden_states,
528
- *args,
529
- **kwargs,
530
- )
531
- )
532
- else:
533
- # Compute all Bn blocks if no specific Bn compute blocks ids are set.
534
- for block in self._Bn_blocks():
535
- hidden_states = block(
536
- hidden_states,
537
- encoder_hidden_states,
538
- *args,
539
- **kwargs,
540
- )
541
- if not isinstance(hidden_states, torch.Tensor):
542
- hidden_states, encoder_hidden_states = hidden_states
543
- if not self.forward_pattern.Return_H_First:
544
- hidden_states, encoder_hidden_states = (
545
- encoder_hidden_states,
546
- hidden_states,
547
- )
548
403
 
549
404
  return hidden_states, encoder_hidden_states