cache-dit 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

@@ -4,7 +4,7 @@ import unittest
 import functools
 
 from contextlib import ExitStack
-from typing import Dict, List, Tuple, Any
+from typing import Dict, List, Tuple, Any, Union, Callable
 
 from diffusers import DiffusionPipeline
 
@@ -14,7 +14,10 @@ from cache_dit.cache_factory import ParamsModifier
 from cache_dit.cache_factory import BlockAdapterRegistry
 from cache_dit.cache_factory import CachedContextManager
 from cache_dit.cache_factory import CachedBlocks
-
+from cache_dit.cache_factory.cache_blocks.utils import (
+    patch_cached_stats,
+    remove_cached_stats,
+)
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
@@ -29,36 +32,45 @@ class CachedAdapter:
     @classmethod
     def apply(
         cls,
-        pipe: DiffusionPipeline = None,
-        block_adapter: BlockAdapter = None,
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            BlockAdapter,
+        ],
         **cache_context_kwargs,
-    ) -> DiffusionPipeline:
+    ) -> Union[
+        DiffusionPipeline,
+        BlockAdapter,
+    ]:
         assert (
-            pipe is not None or block_adapter is not None
+            pipe_or_adapter is not None
         ), "pipe or block_adapter can not both None!"
 
-        if pipe is not None:
-            if BlockAdapterRegistry.is_supported(pipe):
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            if BlockAdapterRegistry.is_supported(pipe_or_adapter):
                 logger.info(
-                    f"{pipe.__class__.__name__} is officially supported by cache-dit. "
-                    "Use it's pre-defined BlockAdapter directly!"
+                    f"{pipe_or_adapter.__class__.__name__} is officially "
+                    "supported by cache-dit. Use it's pre-defined BlockAdapter "
+                    "directly!"
+                )
+                block_adapter = BlockAdapterRegistry.get_adapter(
+                    pipe_or_adapter
                 )
-                block_adapter = BlockAdapterRegistry.get_adapter(pipe)
                 return cls.cachify(
                     block_adapter,
                     **cache_context_kwargs,
-                )
+                ).pipe
             else:
                 raise ValueError(
-                    f"{pipe.__class__.__name__} is not officially supported "
+                    f"{pipe_or_adapter.__class__.__name__} is not officially supported "
                     "by cache-dit, please set BlockAdapter instead!"
                 )
         else:
+            assert isinstance(pipe_or_adapter, BlockAdapter)
             logger.info(
-                "Adapting cache acceleration using custom BlockAdapter!"
+                "Adapting Cache Acceleration using custom BlockAdapter!"
             )
             return cls.cachify(
-                block_adapter,
+                pipe_or_adapter,
                 **cache_context_kwargs,
             )
 
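
For orientation, a minimal usage sketch of the reworked apply() entry point: a DiffusionPipeline now goes in and a cached pipeline comes back (note the trailing ".pipe" in the hunk above), while a custom BlockAdapter goes in and the cached BlockAdapter comes back. The import path and model id below are assumptions for illustration only, not taken from this diff.

    # Hedged sketch; module path and model id are assumed, not confirmed by the diff.
    from diffusers import DiffusionPipeline
    from cache_dit.cache_factory.cache_adapters import CachedAdapter  # assumed path

    pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")  # illustrative

    # Pipeline in -> cached pipeline out.
    pipe = CachedAdapter.apply(pipe, enable_spearate_cfg=False)

    # BlockAdapter in -> cached BlockAdapter out (for pipelines without a registered adapter).
    # adapter = CachedAdapter.apply(my_block_adapter, enable_spearate_cfg=False)
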
@@ -67,7 +79,7 @@ class CachedAdapter:
         cls,
         block_adapter: BlockAdapter,
         **cache_context_kwargs,
-    ) -> DiffusionPipeline:
+    ) -> BlockAdapter:
 
         if block_adapter.auto:
             block_adapter = BlockAdapter.auto_block_adapter(
@@ -79,7 +91,7 @@ class CachedAdapter:
         # 0. Must normalize block_adapter before apply cache
         block_adapter = BlockAdapter.normalize(block_adapter)
         if BlockAdapter.is_cached(block_adapter):
-            return block_adapter.pipe
+            return block_adapter
 
         # 1. Apply cache on pipeline: wrap cache context, must
         # call create_context before mock_blocks.
@@ -93,53 +105,36 @@ class CachedAdapter:
             block_adapter,
         )
 
-        return block_adapter.pipe
+        return block_adapter
 
     @classmethod
-    def patch_params(
+    def check_context_kwargs(
         cls,
         block_adapter: BlockAdapter,
-        contexts_kwargs: List[Dict],
+        **cache_context_kwargs,
     ):
-        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
-
-        params_shift = 0
-        for i in range(len(block_adapter.transformer)):
-
-            block_adapter.transformer[i]._forward_pattern = (
-                block_adapter.forward_pattern
-            )
-            block_adapter.transformer[i]._has_separate_cfg = (
-                block_adapter.has_separate_cfg
-            )
-            block_adapter.transformer[i]._cache_context_kwargs = (
-                contexts_kwargs[params_shift]
-            )
-
-            blocks = block_adapter.blocks[i]
-            for j in range(len(blocks)):
-                blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
-                blocks[j]._cache_context_kwargs = contexts_kwargs[
-                    params_shift + j
-                ]
-
-            params_shift += len(blocks)
-
-    @classmethod
-    def check_context_kwargs(cls, pipe, **cache_context_kwargs):
         # Check cache_context_kwargs
         if not cache_context_kwargs["enable_spearate_cfg"]:
             # Check cfg for some specific case if users don't set it as True
-            cache_context_kwargs["enable_spearate_cfg"] = (
-                BlockAdapterRegistry.has_separate_cfg(pipe)
-            )
-            logger.info(
-                f"Use default 'enable_spearate_cfg': "
-                f"{cache_context_kwargs['enable_spearate_cfg']}, "
-                f"Pipeline: {pipe.__class__.__name__}."
-            )
+            if BlockAdapterRegistry.has_separate_cfg(block_adapter):
+                cache_context_kwargs["enable_spearate_cfg"] = True
+                logger.info(
+                    f"Use custom 'enable_spearate_cfg' from BlockAdapter: True. "
+                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+                )
+            else:
+                cache_context_kwargs["enable_spearate_cfg"] = (
+                    BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
+                )
+                logger.info(
+                    f"Use default 'enable_spearate_cfg' from block adapter "
+                    f"register: {cache_context_kwargs['enable_spearate_cfg']}, "
+                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+                )
 
-        if cache_type := cache_context_kwargs.pop("cache_type", None):
+        if (
+            cache_type := cache_context_kwargs.pop("cache_type", None)
+        ) is not None:
             assert (
                 cache_type == CacheType.DBCache
             ), "Custom cache setting only support for DBCache now!"
@@ -160,8 +155,7 @@ class CachedAdapter:
 
         # Check cache_context_kwargs
         cache_context_kwargs = cls.check_context_kwargs(
-            block_adapter.pipe,
-            **cache_context_kwargs,
+            block_adapter, **cache_context_kwargs
         )
         # Apply cache on pipeline: wrap cache context
         pipe_cls_name = block_adapter.pipe.__class__.__name__
@@ -197,14 +191,14 @@ class CachedAdapter:
                    )
                )
                outputs = original_call(self, *args, **kwargs)
-               cls.patch_stats(block_adapter)
+               cls.apply_stats_hooks(block_adapter)
                return outputs
 
        block_adapter.pipe.__class__.__call__ = new_call
        block_adapter.pipe.__class__._original_call = original_call
        block_adapter.pipe.__class__._is_cached = True
 
-       cls.patch_params(block_adapter, contexts_kwargs)
+       cls.apply_params_hooks(block_adapter, contexts_kwargs)
 
        return block_adapter.pipe
 
@@ -248,33 +242,6 @@ class CachedAdapter:
 
         return flatten_contexts, contexts_kwargs
 
-    @classmethod
-    def patch_stats(
-        cls,
-        block_adapter: BlockAdapter,
-    ):
-        from cache_dit.cache_factory.cache_blocks.utils import (
-            patch_cached_stats,
-        )
-
-        cache_manager = block_adapter.pipe._cache_manager
-
-        for i in range(len(block_adapter.transformer)):
-            patch_cached_stats(
-                block_adapter.transformer[i],
-                cache_context=block_adapter.unique_blocks_name[i][-1],
-                cache_manager=cache_manager,
-            )
-            for blocks, unique_name in zip(
-                block_adapter.blocks[i],
-                block_adapter.unique_blocks_name[i],
-            ):
-                patch_cached_stats(
-                    blocks,
-                    cache_context=unique_name,
-                    cache_manager=cache_manager,
-                )
-
     @classmethod
     def mock_blocks(
         cls,
@@ -392,3 +359,159 @@ class CachedAdapter:
            total_cached_blocks.append(cached_blocks_bind_context)
 
        return total_cached_blocks
+
+    @classmethod
+    def apply_params_hooks(
+        cls,
+        block_adapter: BlockAdapter,
+        contexts_kwargs: List[Dict],
+    ):
+        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
+
+        params_shift = 0
+        for i in range(len(block_adapter.transformer)):
+
+            block_adapter.transformer[i]._forward_pattern = (
+                block_adapter.forward_pattern
+            )
+            block_adapter.transformer[i]._has_separate_cfg = (
+                block_adapter.has_separate_cfg
+            )
+            block_adapter.transformer[i]._cache_context_kwargs = (
+                contexts_kwargs[params_shift]
+            )
+
+            blocks = block_adapter.blocks[i]
+            for j in range(len(blocks)):
+                blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
+                blocks[j]._cache_context_kwargs = contexts_kwargs[
+                    params_shift + j
+                ]
+
+            params_shift += len(blocks)
+
+    @classmethod
+    def apply_stats_hooks(
+        cls,
+        block_adapter: BlockAdapter,
+    ):
+        cache_manager = block_adapter.pipe._cache_manager
+
+        for i in range(len(block_adapter.transformer)):
+            patch_cached_stats(
+                block_adapter.transformer[i],
+                cache_context=block_adapter.unique_blocks_name[i][-1],
+                cache_manager=cache_manager,
+            )
+            for blocks, unique_name in zip(
+                block_adapter.blocks[i],
+                block_adapter.unique_blocks_name[i],
+            ):
+                patch_cached_stats(
+                    blocks,
+                    cache_context=unique_name,
+                    cache_manager=cache_manager,
+                )
+
+    @classmethod
+    def maybe_release_hooks(
+        cls,
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            BlockAdapter,
+        ],
+    ):
+        # release model hooks
+        def _release_blocks_hooks(blocks):
+            return
+
+        def _release_transformer_hooks(transformer):
+            if hasattr(transformer, "_original_forward"):
+                original_forward = transformer._original_forward
+                transformer.forward = original_forward.__get__(transformer)
+                del transformer._original_forward
+            if hasattr(transformer, "_is_cached"):
+                del transformer._is_cached
+
+        def _release_pipeline_hooks(pipe):
+            if hasattr(pipe, "_original_call"):
+                original_call = pipe.__class__._original_call
+                pipe.__class__.__call__ = original_call
+                del pipe.__class__._original_call
+            if hasattr(pipe, "_cache_manager"):
+                cache_manager = pipe._cache_manager
+                if isinstance(cache_manager, CachedContextManager):
+                    cache_manager.clear_contexts()
+                del pipe._cache_manager
+            if hasattr(pipe, "_is_cached"):
+                del pipe.__class__._is_cached
+
+        cls.release_hooks(
+            pipe_or_adapter,
+            _release_blocks_hooks,
+            _release_transformer_hooks,
+            _release_pipeline_hooks,
+        )
+
+        # release params hooks
+        def _release_blocks_params(blocks):
+            if hasattr(blocks, "_forward_pattern"):
+                del blocks._forward_pattern
+            if hasattr(blocks, "_cache_context_kwargs"):
+                del blocks._cache_context_kwargs
+
+        def _release_transformer_params(transformer):
+            if hasattr(transformer, "_forward_pattern"):
+                del transformer._forward_pattern
+            if hasattr(transformer, "_has_separate_cfg"):
+                del transformer._has_separate_cfg
+            if hasattr(transformer, "_cache_context_kwargs"):
+                del transformer._cache_context_kwargs
+            for blocks in BlockAdapter.find_blocks(transformer):
+                _release_blocks_params(blocks)
+
+        def _release_pipeline_params(pipe):
+            if hasattr(pipe, "_cache_context_kwargs"):
+                del pipe._cache_context_kwargs
+
+        cls.release_hooks(
+            pipe_or_adapter,
+            _release_blocks_params,
+            _release_transformer_params,
+            _release_pipeline_params,
+        )
+
+        # release stats hooks
+        cls.release_hooks(
+            pipe_or_adapter,
+            remove_cached_stats,
+            remove_cached_stats,
+            remove_cached_stats,
+        )
+
+    @classmethod
+    def release_hooks(
+        cls,
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            BlockAdapter,
+        ],
+        _release_blocks: Callable,
+        _release_transformer: Callable,
+        _release_pipeline: Callable,
+    ):
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            pipe = pipe_or_adapter
+            _release_pipeline(pipe)
+            if hasattr(pipe, "transformer"):
+                _release_transformer(pipe.transformer)
+            if hasattr(pipe, "transformer_2"):  # Wan 2.2
+                _release_transformer(pipe.transformer_2)
+        elif isinstance(pipe_or_adapter, BlockAdapter):
+            adapter = pipe_or_adapter
+            BlockAdapter.assert_normalized(adapter)
+            _release_pipeline(adapter.pipe)
+            for transformer in BlockAdapter.flatten(adapter.transformer):
+                _release_transformer(transformer)
+            for blocks in BlockAdapter.flatten(adapter.blocks):
+                _release_blocks(blocks)
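
The new maybe_release_hooks/release_hooks pair adds a teardown path: release_hooks dispatches the three callbacks over either a plain pipeline (pipe.transformer, plus transformer_2 for Wan 2.2) or a normalized BlockAdapter, and maybe_release_hooks uses it to restore the original __call__/forward, drop the patched params, and strip cached stats. A hedged teardown sketch; the prompt and the assumption that a cached pipeline is already in hand are illustrative, not taken from the diff:

    # Sketch under assumptions: `pipe` is a DiffusionPipeline, prompt/kwargs are illustrative.
    pipe = CachedAdapter.apply(pipe, enable_spearate_cfg=False)   # install cache hooks
    image = pipe("a cabin in the snow").images[0]                 # cached inference
    CachedAdapter.maybe_release_hooks(pipe)                       # undo __call__/forward patches,
                                                                  # params hooks and cached stats
    image = pipe("a cabin in the snow").images[0]                 # back to the stock pipeline
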
@@ -1,5 +1,6 @@
 import torch
 
+from typing import Dict, Any
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
@@ -31,7 +32,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
         # encoder_hidden_states: None Pattern 3, else 4, 5
-        hidden_states, encoder_hidden_states = self.call_Fn_blocks(
+        hidden_states, new_encoder_hidden_states = self.call_Fn_blocks(
             hidden_states,
             *args,
             **kwargs,
@@ -60,11 +61,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         if can_use_cache:
             self.cache_manager.add_cached_step()
             del Fn_hidden_states_residual
-            hidden_states, encoder_hidden_states = (
+            hidden_states, new_encoder_hidden_states = (
                 self.cache_manager.apply_cache(
                     hidden_states,
-                    # None Pattern 3, else 4, 5
-                    encoder_hidden_states,
+                    new_encoder_hidden_states,  # encoder_hidden_states not use cache
                     prefix=(
                         f"{self.cache_prefix}_Bn_residual"
                         if self.cache_manager.is_cache_residual()
@@ -80,12 +80,12 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
-            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
-                hidden_states,
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
+            if self.cache_manager.Bn_compute_blocks() > 0:
+                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
         else:
             self.cache_manager.set_Fn_buffer(
                 Fn_hidden_states_residual,
@@ -99,19 +99,20 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
+            old_encoder_hidden_states = new_encoder_hidden_states
             (
                 hidden_states,
-                encoder_hidden_states,
+                new_encoder_hidden_states,
                 hidden_states_residual,
-                # None Pattern 3, else 4, 5
-                encoder_hidden_states_residual,
             ) = self.call_Mn_blocks(  # middle
                 hidden_states,
-                # None Pattern 3, else 4, 5
-                encoder_hidden_states,
                 *args,
                 **kwargs,
             )
+            if new_encoder_hidden_states is not None:
+                new_encoder_hidden_states_residual = (
+                    new_encoder_hidden_states - old_encoder_hidden_states
+                )
             torch._dynamo.graph_break()
             if self.cache_manager.is_cache_residual():
                 self.cache_manager.set_Bn_buffer(
@@ -119,34 +120,32 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                # TaylorSeer
                 self.cache_manager.set_Bn_buffer(
                     hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
+
             if self.cache_manager.is_encoder_cache_residual():
-                self.cache_manager.set_Bn_encoder_buffer(
-                    # None Pattern 3, else 4, 5
-                    encoder_hidden_states_residual,
-                    prefix=f"{self.cache_prefix}_Bn_residual",
-                )
+                if new_encoder_hidden_states is not None:
+                    self.cache_manager.set_Bn_encoder_buffer(
+                        new_encoder_hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_Bn_residual",
+                    )
             else:
-                # TaylorSeer
-                self.cache_manager.set_Bn_encoder_buffer(
-                    # None Pattern 3, else 4, 5
-                    encoder_hidden_states,
-                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
-                )
+                if new_encoder_hidden_states is not None:
+                    self.cache_manager.set_Bn_encoder_buffer(
+                        new_encoder_hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_Bn_hidden_states",
+                    )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
-            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
-                hidden_states,
-                # None Pattern 3, else 4, 5
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
+            if self.cache_manager.Bn_compute_blocks() > 0:
+                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
 
         torch._dynamo.graph_break()
 
@@ -154,12 +153,21 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             hidden_states
             if self.forward_pattern.Return_H_Only
             else (
-                (hidden_states, encoder_hidden_states)
+                (hidden_states, new_encoder_hidden_states)
                 if self.forward_pattern.Return_H_First
-                else (encoder_hidden_states, hidden_states)
+                else (new_encoder_hidden_states, hidden_states)
             )
         )
 
+    @torch.compiler.disable
+    def maybe_update_kwargs(
+        self, encoder_hidden_states, kwargs: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        # if "encoder_hidden_states" in kwargs:
+        #     kwargs["encoder_hidden_states"] = encoder_hidden_states
+        #     return kwargs
+        return kwargs
+
     def call_Fn_blocks(
         self,
         hidden_states: torch.Tensor,
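
Note that, as committed, maybe_update_kwargs is effectively a placeholder: the branch that would write new_encoder_hidden_states back into kwargs["encoder_hidden_states"] is commented out, so the per-block kwargs are returned unchanged and subsequent blocks still receive the caller's original encoder_hidden_states.
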
@@ -172,7 +180,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
            f"the number of transformer blocks {len(self.transformer_blocks)}"
        )
-        new_encoder_hidden_states = None  # Pattern 3
+        new_encoder_hidden_states = None
        for block in self._Fn_blocks():
            hidden_states = block(
                hidden_states,
@@ -180,25 +188,27 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-                hidden_states, encoder_hidden_states = hidden_states
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )
 
-        return hidden_states, encoder_hidden_states
+        return hidden_states, new_encoder_hidden_states
 
     def call_Mn_blocks(
         self,
         hidden_states: torch.Tensor,
-        # None Pattern 3, else 4, 5
-        encoder_hidden_states: torch.Tensor | None,
         *args,
         **kwargs,
     ):
         original_hidden_states = hidden_states
-        original_encoder_hidden_states = encoder_hidden_states
+        new_encoder_hidden_states = None
         for block in self._Mn_blocks():
             hidden_states = block(
                 hidden_states,
@@ -206,44 +216,33 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-                hidden_states, encoder_hidden_states = hidden_states
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )
 
         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
         hidden_states_residual = hidden_states - original_hidden_states
-        if (
-            original_encoder_hidden_states is not None
-            and encoder_hidden_states is not None
-        ):  # Pattern 4, 5
-            encoder_hidden_states_residual = (
-                encoder_hidden_states - original_encoder_hidden_states
-            )
-        else:
-            encoder_hidden_states_residual = None  # Pattern 3
 
         return (
             hidden_states,
-            encoder_hidden_states,
+            new_encoder_hidden_states,
             hidden_states_residual,
-            encoder_hidden_states_residual,
         )
 
     def call_Bn_blocks(
         self,
         hidden_states: torch.Tensor,
-        # None Pattern 3, else 4, 5
-        encoder_hidden_states: torch.Tensor | None,
         *args,
         **kwargs,
     ):
-        if self.cache_manager.Bn_compute_blocks() == 0:
-            return hidden_states, encoder_hidden_states
-
         assert self.cache_manager.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
@@ -264,11 +263,15 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
-                hidden_states, encoder_hidden_states = hidden_states
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )
 
-        return hidden_states, encoder_hidden_states
+        return hidden_states, new_encoder_hidden_states
@@ -23,3 +23,19 @@ def patch_cached_stats(
     module._residual_diffs = cache_manager.get_residual_diffs()
     module._cfg_cached_steps = cache_manager.get_cfg_cached_steps()
     module._cfg_residual_diffs = cache_manager.get_cfg_residual_diffs()
+
+
+def remove_cached_stats(
+    module: torch.nn.Module | Any,
+):
+    if module is None:
+        return
+
+    if hasattr(module, "_cached_steps"):
+        del module._cached_steps
+    if hasattr(module, "_residual_diffs"):
+        del module._residual_diffs
+    if hasattr(module, "_cfg_cached_steps"):
+        del module._cfg_cached_steps
+    if hasattr(module, "_cfg_residual_diffs"):
+        del module._cfg_residual_diffs
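
remove_cached_stats mirrors patch_cached_stats by deleting the four stats attributes when they are present. A minimal standalone illustration of the same attach/detach pattern, using a plain torch module as a stand-in (not cache-dit code):

    import torch

    module = torch.nn.Linear(4, 4)
    module._cached_steps = [3, 5, 8]        # the kind of stats patch_cached_stats attaches
    module._residual_diffs = {"8": 0.012}

    # remove_cached_stats-style cleanup: delete only the attributes that exist.
    for name in ("_cached_steps", "_residual_diffs",
                 "_cfg_cached_steps", "_cfg_residual_diffs"):
        if hasattr(module, name):
            delattr(module, name)

    assert not hasattr(module, "_cached_steps")
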