cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
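Most of the churn between 1.0.3 and 1.0.14 in the listing above is package restructuring: `cache_dit/cache_factory/*` moves to `cache_dit/caching/*`, `cache_dit/custom_ops/*` moves to `cache_dit/kernels/*`, and new `parallelism/` and `quantize/backends/` trees are added. For downstream code the rename mainly surfaces as changed import paths. Below is a minimal, hedged sketch of how an import site could absorb the move; it assumes the relocated modules keep exporting the same names (the old paths are taken from the removed file shown below, the new `cache_dit.caching` paths are inferred from the file moves above and are not a documented compatibility guarantee).

```python
# Hedged sketch: tolerate the cache_factory -> caching rename shown in the file
# list above. Assumes the moved modules keep the same public names; the new
# `cache_dit.caching` paths are inferred from the renames, not from documentation.
try:
    # cache-dit >= 1.0.14 layout (cache_dit/caching/...)
    from cache_dit.caching import ForwardPattern
    from cache_dit.caching.cache_blocks.pattern_base import CachedBlocks_Pattern_Base
except ImportError:
    # cache-dit 1.0.3 layout (matches the imports in the removed file below)
    from cache_dit.cache_factory import ForwardPattern
    from cache_dit.cache_factory.cache_blocks.pattern_base import (
        CachedBlocks_Pattern_Base,
    )
```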
```diff
--- a/cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py
+++ /dev/null
@@ -1,306 +0,0 @@
-import torch
-
-from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory.cache_contexts.cache_manager import (
-    CacheNotExistError,
-)
-from cache_dit.cache_factory.cache_blocks.pattern_base import (
-    CachedBlocks_Pattern_Base,
-)
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
-    _supported_patterns = [
-        ForwardPattern.Pattern_3,
-        ForwardPattern.Pattern_4,
-        ForwardPattern.Pattern_5,
-    ]
-
-    def call_blocks(
-        self,
-        hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        # Call all blocks to process the hidden states without cache.
-        new_encoder_hidden_states = None
-        for block in self.transformer_blocks:
-            hidden_states = block(
-                hidden_states,
-                *args,
-                **kwargs,
-            )
-            hidden_states, new_encoder_hidden_states = (
-                self._process_block_outputs(hidden_states)
-            )
-
-        return hidden_states, new_encoder_hidden_states
-
-    @torch.compiler.disable
-    def _process_block_outputs(
-        self, hidden_states: torch.Tensor | tuple
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        # Process the outputs for the block.
-        new_encoder_hidden_states = None
-        if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-            if len(hidden_states) == 2:
-                if isinstance(hidden_states[1], torch.Tensor):
-                    hidden_states, new_encoder_hidden_states = hidden_states
-                    if not self.forward_pattern.Return_H_First:
-                        hidden_states, new_encoder_hidden_states = (
-                            new_encoder_hidden_states,
-                            hidden_states,
-                        )
-                elif isinstance(hidden_states[0], torch.Tensor):
-                    hidden_states = hidden_states[0]
-                else:
-                    raise ValueError("Unexpected hidden_states format.")
-            else:
-                assert (
-                    len(hidden_states) == 1
-                ), f"Unexpected output length: {len(hidden_states)}"
-                hidden_states = hidden_states[0]
-        return hidden_states, new_encoder_hidden_states
-
-    @torch.compiler.disable
-    def _process_forward_outputs(
-        self,
-        hidden_states: torch.Tensor,
-        new_encoder_hidden_states: torch.Tensor | None,
-    ) -> (
-        torch.Tensor
-        | tuple[torch.Tensor, torch.Tensor]
-        | tuple[torch.Tensor, None]
-    ):
-        if self.forward_pattern.Return_H_Only:
-            return hidden_states
-        else:
-            if self.forward_pattern.Return_H_First:
-                return (hidden_states, new_encoder_hidden_states)
-            else:
-                return (new_encoder_hidden_states, hidden_states)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        # Use it's own cache context.
-        try:
-            self.cache_manager.set_context(self.cache_context)
-            self._check_cache_params()
-        except CacheNotExistError as e:
-            logger.warning(f"Cache context not exist: {e}, skip cache.")
-            hidden_states, new_encoder_hidden_states = self.call_blocks(
-                hidden_states,
-                *args,
-                **kwargs,
-            )
-            return self._process_forward_outputs(
-                hidden_states, new_encoder_hidden_states
-            )
-
-        original_hidden_states = hidden_states
-        # Call first `n` blocks to process the hidden states for
-        # more stable diff calculation.
-        hidden_states, new_encoder_hidden_states = self.call_Fn_blocks(
-            hidden_states,
-            *args,
-            **kwargs,
-        )
-
-        Fn_hidden_states_residual = hidden_states - original_hidden_states.to(
-            hidden_states.device
-        )
-        del original_hidden_states
-
-        self.cache_manager.mark_step_begin()
-        # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = self.cache_manager.can_cache(
-            (
-                Fn_hidden_states_residual
-                if not self.cache_manager.is_l1_diff_enabled()
-                else hidden_states
-            ),
-            parallelized=self._is_parallelized(),
-            prefix=(
-                f"{self.cache_prefix}_Fn_residual"
-                if not self.cache_manager.is_l1_diff_enabled()
-                else f"{self.cache_prefix}_Fn_hidden_states"
-            ),
-        )
-
-        torch._dynamo.graph_break()
-        if can_use_cache:
-            self.cache_manager.add_cached_step()
-            del Fn_hidden_states_residual
-            hidden_states, new_encoder_hidden_states = (
-                self.cache_manager.apply_cache(
-                    hidden_states,
-                    new_encoder_hidden_states,  # encoder_hidden_states not use cache
-                    prefix=(
-                        f"{self.cache_prefix}_Bn_residual"
-                        if self.cache_manager.is_cache_residual()
-                        else f"{self.cache_prefix}_Bn_hidden_states"
-                    ),
-                    encoder_prefix=(
-                        f"{self.cache_prefix}_Bn_residual"
-                        if self.cache_manager.is_encoder_cache_residual()
-                        else f"{self.cache_prefix}_Bn_hidden_states"
-                    ),
-                )
-            )
-            torch._dynamo.graph_break()
-            # Call last `n` blocks to further process the hidden states
-            # for higher precision.
-            if self.cache_manager.Bn_compute_blocks() > 0:
-                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
-                    hidden_states,
-                    *args,
-                    **kwargs,
-                )
-        else:
-            self.cache_manager.set_Fn_buffer(
-                Fn_hidden_states_residual,
-                prefix=f"{self.cache_prefix}_Fn_residual",
-            )
-            if self.cache_manager.is_l1_diff_enabled():
-                # for hidden states L1 diff
-                self.cache_manager.set_Fn_buffer(
-                    hidden_states,
-                    f"{self.cache_prefix}_Fn_hidden_states",
-                )
-            del Fn_hidden_states_residual
-            torch._dynamo.graph_break()
-            old_encoder_hidden_states = new_encoder_hidden_states
-            (
-                hidden_states,
-                new_encoder_hidden_states,
-                hidden_states_residual,
-            ) = self.call_Mn_blocks(  # middle
-                hidden_states,
-                *args,
-                **kwargs,
-            )
-
-            torch._dynamo.graph_break()
-            if self.cache_manager.is_cache_residual():
-                self.cache_manager.set_Bn_buffer(
-                    hidden_states_residual,
-                    prefix=f"{self.cache_prefix}_Bn_residual",
-                )
-            else:
-                self.cache_manager.set_Bn_buffer(
-                    hidden_states,
-                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
-                )
-
-            if new_encoder_hidden_states is not None:
-                new_encoder_hidden_states_residual = (
-                    new_encoder_hidden_states - old_encoder_hidden_states
-                )
-            if self.cache_manager.is_encoder_cache_residual():
-                if new_encoder_hidden_states is not None:
-                    self.cache_manager.set_Bn_encoder_buffer(
-                        new_encoder_hidden_states_residual,
-                        prefix=f"{self.cache_prefix}_Bn_residual",
-                    )
-            else:
-                if new_encoder_hidden_states is not None:
-                    self.cache_manager.set_Bn_encoder_buffer(
-                        new_encoder_hidden_states_residual,
-                        prefix=f"{self.cache_prefix}_Bn_hidden_states",
-                    )
-            torch._dynamo.graph_break()
-            # Call last `n` blocks to further process the hidden states
-            # for higher precision.
-            if self.cache_manager.Bn_compute_blocks() > 0:
-                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
-                    hidden_states,
-                    *args,
-                    **kwargs,
-                )
-
-        torch._dynamo.graph_break()
-
-        return self._process_forward_outputs(
-            hidden_states,
-            new_encoder_hidden_states,
-        )
-
-    def call_Fn_blocks(
-        self,
-        hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        new_encoder_hidden_states = None
-        for block in self._Fn_blocks():
-            hidden_states = block(
-                hidden_states,
-                *args,
-                **kwargs,
-            )
-            hidden_states, new_encoder_hidden_states = (
-                self._process_block_outputs(hidden_states)
-            )
-
-        return hidden_states, new_encoder_hidden_states
-
-    def call_Mn_blocks(
-        self,
-        hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        original_hidden_states = hidden_states
-        new_encoder_hidden_states = None
-        for block in self._Mn_blocks():
-            hidden_states = block(
-                hidden_states,
-                *args,
-                **kwargs,
-            )
-
-            hidden_states, new_encoder_hidden_states = (
-                self._process_block_outputs(hidden_states)
-            )
-
-        # compute hidden_states residual
-        hidden_states = hidden_states.contiguous()
-        hidden_states_residual = hidden_states - original_hidden_states.to(
-            hidden_states.device
-        )
-
-        return (
-            hidden_states,
-            new_encoder_hidden_states,
-            hidden_states_residual,
-        )
-
-    def call_Bn_blocks(
-        self,
-        hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        new_encoder_hidden_states = None
-        if self.cache_manager.Bn_compute_blocks() == 0:
-            return hidden_states, new_encoder_hidden_states
-
-        for block in self._Bn_blocks():
-            hidden_states = block(
-                hidden_states,
-                *args,
-                **kwargs,
-            )
-
-            hidden_states, new_encoder_hidden_states = (
-                self._process_block_outputs(hidden_states)
-            )
-
-        return hidden_states, new_encoder_hidden_states
```