broccoli-ml 0.29.1__py3-none-any.whl → 10.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/activation.py +1 -4
- broccoli/cnn.py +1 -289
- broccoli/linear.py +237 -7
- broccoli/rope.py +19 -4
- broccoli/tensor.py +36 -31
- broccoli/transformer.py +523 -186
- broccoli/utils.py +13 -7
- broccoli/vit.py +214 -56
- {broccoli_ml-0.29.1.dist-info → broccoli_ml-10.0.1.dist-info}/METADATA +5 -3
- broccoli_ml-10.0.1.dist-info/RECORD +13 -0
- broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl +0 -0
- broccoli/assets/cifar100_eigenvectors_size_2.pt +0 -0
- broccoli/assets/cifar100_eigenvectors_size_3.pt +0 -0
- broccoli/eigenpatches.py +0 -49
- broccoli_ml-0.29.1.dist-info/RECORD +0 -17
- {broccoli_ml-0.29.1.dist-info → broccoli_ml-10.0.1.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.29.1.dist-info → broccoli_ml-10.0.1.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
@@ -1,16 +1,72 @@
+import warnings
 import math
-from
-from typing import Optional
-from numpy import random
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint

 from einops import rearrange

 from .rope import RotaryEmbedding, apply_rotary_emb
-
+
+try:
+    from flash_attn import flash_attn_func
+
+    print("Using flash-attn.")
+    FLASH_ATTN = True
+except ImportError:
+    pass
+    FLASH_ATTN = False
+
+
+class LayerScale(nn.Module):
+    def __init__(self, dim, init_values=1e-4):
+        super().__init__()
+        self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        return x * self.nondecay_scale
+
+
+def drop_path(
+    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+    """
+    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    Copyright 2019 Ross Wightman
+    See documentation and licence there.
+    """
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+
+class DropPath(nn.Module):
+    """
+    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    Copyright 2019 Ross Wightman
+    See documentation and licence there.
+    """
+
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f"drop_prob={round(self.drop_prob, 3):0.3f}"


 class MHAttention(nn.Module):
@@ -21,45 +77,6 @@ class MHAttention(nn.Module):
     are the same shape.

     Assumes bias=False and batch_first=True, as God intended.
-
-    Optionally adds various bells and whistles suggested in the
-    literature, including:
-
-    Noam Shazeer's scaled attention per "Attention is All You Need"
-    (https://arxiv.org/abs/1706.03762).
-
-    Max subtract softmax as discussed in "Attention As An RNN"
-    (https://arxiv.org/abs/2405.13956)
-
-    Log-length scaled softmax per "Overcoming a Theoretical Limitation of
-    Self-Attention" (https://arxiv.org/abs/2202.12172).
-
-    Quiet softmax per
-    https://www.evanmiller.org/attention-is-off-by-one.html
-
-    Args:
-        d_model: ...
-        n_heads: ...
-        dropout: ...
-        causal: should a causal mask be applied to the logits before attention
-            is applied? This is standard when using self-attention. Cannot be
-            True if inputs won't be square (e.g. if sequence length for
-            encoder and decoder are different)
-        sequence_length: ...
-        share_kv: ...
-        linear_module: ...
-        max_subtract: if True, the maximum logit value is subtracted from all
-            logits before performing the softmax operation to create a more
-            numerically stable softmax. This is discussed in "Attention As An
-            RNN" (https://arxiv.org/abs/2405.13956).
-        d_model_scale: ...
-        log_length_scale: if True, multiplies logits by the log length of
-            the decoder sequence before performing the softmax operation, as
-            proposed in "Overcoming a Theoretical Limitation of Self-Attention"
-            (https://arxiv.org/abs/2202.12172).
-        quiet: if True, adds 1 to the denominator of the softmax operation,
-            allowing some tokens to attend to no other tokens as described in
-            https://www.evanmiller.org/attention-is-off-by-one.html.
     """

     def __init__(
@@ -70,10 +87,19 @@ class MHAttention(nn.Module):
         causal=False,
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
-
+        utility_tokens=0,
+        talking_heads=False,
         rotary_embedding=None,
         source_size=None,
+        scaling="d",
     ):
+        """
+        Args:
+            scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
         super().__init__()

         if rotary_embedding is not None:
@@ -81,12 +107,31 @@ class MHAttention(nn.Module):
         if causal:
             assert seq_len is not None

+        self.talking_heads = talking_heads
+
+        if self.talking_heads:
+            self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
+            self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
+        else:
+            self.head_projection = None
+            self.sample_projection = None
+
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
+        self.scaling = scaling

         self.head_dim = self.embed_dim // self.n_heads

+        if self.scaling == "sqrtd":
+            self.scaling_factor = 1 / math.sqrt(self.head_dim)
+        elif self.scaling == "d":
+            # 8/d_model for backwards compatibility,
+            # per https://github.com/microsoft/mup
+            self.scaling_factor = 8 / self.head_dim
+        else:
+            raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
+
         self.q_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
         self.k_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
@@ -105,7 +150,9 @@ class MHAttention(nn.Module):
         )
         self.rotary_embedding = rotary_embedding
         self.source_size = source_size
-        self.
+        self.utility_tokens = utility_tokens
+
+        self.reset_parameters()

     @property
     def _kv_distance(self) -> float:
@@ -126,7 +173,71 @@ class MHAttention(nn.Module):

         return 1 - similarity

-    def
+    def add_axial_rope(
+        self, q: torch.Tensor, k: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Apply Axial RoPE to all tokens except utility tokens
+        """
+
+        if len(self.source_size) == 1:
+            spatial_dimension_names = "D1"
+            spatial_dimension_values = {"D1": self.source_size[0]}
+        elif len(self.source_size) == 2:
+            spatial_dimension_names = "D1 D2"
+            spatial_dimension_values = {
+                "D1": self.source_size[0],
+                "D2": self.source_size[1],
+            }
+        elif len(self.source_size) == 3:
+            spatial_dimension_names = "D1 D2 D3"
+            spatial_dimension_values = {
+                "D1": self.source_size[0],
+                "D2": self.source_size[1],
+                "D3": self.source_size[2],
+            }
+        else:
+            raise NotImplementedError(
+                "`source_size` must be a tuple of 1, 2 or 3 integers"
+            )
+
+        q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
+        k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
+
+        q_img = rearrange(
+            q_img,
+            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            **spatial_dimension_values,
+        )
+        k_img = rearrange(
+            k_img,
+            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            **spatial_dimension_values,
+        )
+
+        freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
+
+        q_img = apply_rotary_emb(freqs, q_img)
+        k_img = apply_rotary_emb(freqs, k_img)
+
+        q_img = rearrange(
+            q_img,
+            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+        )
+        k_img = rearrange(
+            k_img,
+            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+        )
+
+        # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
+        q = torch.cat([q_util, q_img], dim=1)
+        k = torch.cat([k_util, k_img], dim=1)
+
+        return q, k
+
+    def project_qkv(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_batch_size, query_tokens, query_features = q.size()
         key_batch_size, key_tokens, key_features = k.size()

@@ -139,66 +250,74 @@ class MHAttention(nn.Module):

         if self.causal:
             assert query_tokens == key_tokens
-            assert query_tokens == self.
+            assert query_tokens == self.seq_len

-
-        q = self.q_proj(q)
-        k = self.k_proj(k)
-        v = self.v_proj(v)
+        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)

-        # Rearrange dimensions and add RoPE if needed
         if self.rotary_embedding is not None:
+            q, k = self.add_axial_rope(q, k)

-
-                spatial_dimension_names = "D1"
-                spatial_dimension_values = {"D1": self.source_size[0]}
-            elif len(self.source_size) == 2:
-                spatial_dimension_names = "D1 D2"
-                spatial_dimension_values = {
-                    "D1": self.source_size[0],
-                    "D2": self.source_size[1],
-                }
-            elif len(self.source_size) == 3:
-                spatial_dimension_names = "D1 D2 D3"
-                spatial_dimension_values = {
-                    "D1": self.source_size[0],
-                    "D2": self.source_size[1],
-                    "D3": self.source_size[2],
-                }
-            else:
-                raise NotImplementedError(
-                    "`source_size` must be a tuple of 1, 2 or 3 integers"
-                )
+        return q, k, v

-
-            k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
+    def forward(self, q, k, v):

-
-
-
-
-            )
-
-
-
-
+        q, k, v = self.project_qkv(q, k, v)
+
+        if FLASH_ATTN and not self.talking_heads:
+            # Divide Q/K/V into heads
+            q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+            k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+            v = rearrange(v, "b t (h d) -> b t h d", h=self.n_heads)
+
+            output_with_heads = flash_attn_func(
+                q,
+                k,
+                v,
+                dropout_p=self.dropout.p if self.training else 0.0,
+                softmax_scale=self.scaling_factor,
+                causal=self.causal,
             )
-            freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
-            q_img = apply_rotary_emb(freqs, q_img)
-            k_img = apply_rotary_emb(freqs, k_img)

-
-
-
-
-
-
-
-            )
+            output_without_heads = rearrange(output_with_heads, "b t h d -> b t (h d)")
+
+            return self.out_proj(output_without_heads)
+        else:
+            # Divide Q/K/V into heads
+            q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
+            k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
+            v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)

-
-
-
+            qk_scores = q @ k.transpose(-1, -2)
+
+            qk_scores *= self.scaling_factor
+
+            if self.talking_heads:
+                qk_scores = torch.einsum(
+                    "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
+                )
+
+            # Apply mask if causal (must come before softmax)
+            if self.causal:
+                qk_scores.masked_fill_(self.mask, float("-inf"))
+
+            qk_scores = F.softmax(qk_scores, dim=-1)
+
+            if self.talking_heads:
+                qk_scores = torch.einsum(
+                    "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
+                )
+
+            qk_scores = self.dropout(qk_scores)
+
+            output_with_heads = qk_scores @ v
+
+            output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")
+
+            return self.out_proj(output_without_heads)
+
+    def attention_logits(self, q, k, v):
+
+        q, k, v = self.project_qkv(q, k, v)

         # Divide Q/K/V into heads
         q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
@@ -207,19 +326,24 @@ class MHAttention(nn.Module):

         qk_scores = q @ k.transpose(-1, -2)

-        qk_scores
+        qk_scores *= self.scaling_factor

         # Apply mask if causal (must come before softmax)
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))

-        qk_scores
-
-        output_with_heads = qk_scores @ v
+        return qk_scores  # (batch, head, seq_len, seq_len)

-
-
-
+    def reset_parameters(self):
+        # Default nn.Linear init is kaiming_uniform, which is fine
+        self.q_proj.reset_parameters()
+        self.k_proj.reset_parameters()
+        self.v_proj.reset_parameters()
+        self.out_proj.reset_parameters()
+        if self.talking_heads:
+            # Initialize close to identity
+            nn.init.eye_(self.head_projection.weight)
+            nn.init.eye_(self.sample_projection.weight)


 class FeedforwardBlock(nn.Module):
@@ -235,44 +359,123 @@ class FeedforwardBlock(nn.Module):
         activation=nn.ReLU,
         activation_kwargs=None,
         dropout=0.0,
-
+        inner_dropout=None,
+        outer_dropout=None,
+        linear_module_up=nn.Linear,
+        linear_module_down=nn.Linear,
         pre_norm=True,
         normformer=False,
-
+        post_norm=True,
+        residual_path=True,
+        checkpoint=True,
     ):
         super().__init__()

+        self.checkpoint = checkpoint
+        self.residual_path = residual_path
+        self.post_norm = post_norm
+        self.xglu = activation.__name__.endswith("GLU")
+
+        if self.residual_path and (output_features < input_features):
+            raise ValueError(
+                "If the number of output features will be less than "
+                "the number of input features, then `residual_path` "
+                "should be set to False."
+            )
+
+        if self.post_norm:
+            self.layernorm = nn.LayerNorm(output_features)
+
         if activation_kwargs is not None:
             self.activation = activation(**activation_kwargs)
         else:
             self.activation = activation()

-
-
-
-
-
-
+        self.inner_dropout = nn.Dropout(
+            inner_dropout if inner_dropout is not None else dropout
+        )
+        self.outer_dropout = nn.Dropout(
+            outer_dropout if outer_dropout is not None else dropout
+        )

         self.max_features = (
-            2 * ratio * output_features
-            if
-            else ratio * output_features
+            2 * int(ratio * output_features)
+            if self.xglu
+            else int(ratio * output_features)
+        )
+
+        self.linear_in = linear_module_up(input_features, self.max_features)
+        self.linear_out = linear_module_down(
+            int(ratio * output_features), output_features
         )

         self.process = nn.Sequential(
            *[
                nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
-
+                self.linear_in,
                self.activation,
-
-
-
+                self.inner_dropout,
+                (
+                    nn.LayerNorm(int(ratio * output_features))
+                    if normformer
+                    else nn.Identity()
+                ),
+                self.linear_out,
+                self.outer_dropout,
            ]
         )

+        self.recycling_enabled = False
+        if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
+            self.linear_out, "column_recycling_rate"
+        ):
+            self.recycling_enabled = True
+            self.master_recycling_rate = self.linear_in.row_recycling_rate
+            self.linear_in.row_recycling_rate = 0.0
+            self.linear_out.column_recycling_rate = 0.0
+            if (
+                hasattr(self.linear_in, "column_recycling_rate")
+                and self.linear_in.column_recycling_rate > 0
+            ) or (
+                hasattr(self.linear_out, "row_recycling_rate")
+                and self.linear_out.row_recycling_rate > 0
+            ):
+                raise NotImplementedError(
+                    "At the moment this layer can only support recycling linear "
+                    "layers if the in layer resets only rows and the out layer "
+                    "resets only columns."
+                )
+
+        self.reset_parameters()
+
     def forward(self, x):
-
+
+        # Recycle weights if using recycling linear layers
+        if self.training and self.recycling_enabled:
+            indices = self.linear_out.get_reset_indices(1)
+            self.linear_in.reset_rows(indices, incoming_data=x)
+            self.linear_out.reset_columns(indices)
+
+        if self.checkpoint:
+            processed = checkpoint(self.process, x, use_reentrant=False)
+        else:
+            processed = self.process(x)
+
+        if self.residual_path and self.post_norm:
+            return self.layernorm(x + processed)
+        elif self.residual_path:
+            return x + processed
+        else:
+            return processed
+
+    def reset_parameters(self):
+        if self.post_norm:
+            self.layernorm.reset_parameters()
+
+        # Iterate over the sequential block to reset parameters
+        for module in self.process:
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()


 class TransformerBlock(nn.Module):
@@ -289,30 +492,57 @@ class TransformerBlock(nn.Module):
         seq_len,
         d_model,
         n_heads,
-
+        relative_position_embedding=False,
         source_size=None,
-
+        utility_tokens=0,
+        talking_heads=False,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
-
+        ff_linear_module_up=None,
+        ff_linear_module_down=None,
+        msa_scaling="d",
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         identity_probability=0.0,
         causal=False,
         linear_module=nn.Linear,
         pre_norm=True,
+        post_norm=False,
         normformer=False,
+        checkpoint_ff=True,
+        layerscale=True,
     ):
+        """
+        Args:
+            msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
+
         super().__init__()

         self.pre_norm = pre_norm
+        self.post_norm = post_norm
+        self.normformer = normformer

-        self.
+        self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)

         self.layer_norm_1 = nn.LayerNorm(d_model)
         self.layer_norm_2 = nn.LayerNorm(d_model)
+        self.layer_norm_3 = nn.LayerNorm(d_model)

-        if
+        if layerscale:
+            self.layerscale1 = LayerScale(d_model)
+            self.layerscale2 = LayerScale(d_model)
+        else:
+            self.layerscale1 = nn.Identity()
+            self.layerscale2 = nn.Identity()
+
+        if relative_position_embedding:
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
             if d_model < 16:
                 dim = d_model
@@ -333,63 +563,87 @@ class TransformerBlock(nn.Module):
             linear_module=linear_module,
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
-
+            utility_tokens=utility_tokens,
+            talking_heads=talking_heads,
+            scaling=msa_scaling,
         )

-        #
+        # Submodule for the feedforward process
         self.ff = FeedforwardBlock(
             d_model,
             mlp_ratio,
             d_model,
             activation=activation,
             activation_kwargs=activation_kwargs,
-            dropout=
-
-
+            dropout=ff_dropout,
+            inner_dropout=ff_inner_dropout,
+            outer_dropout=ff_outer_dropout,
+            linear_module_up=(
+                ff_linear_module_up
+                if ff_linear_module_up is not None
+                else linear_module
+            ),
+            linear_module_down=(
+                ff_linear_module_down
+                if ff_linear_module_down is not None
+                else linear_module
+            ),
+            pre_norm=False,  # Handled outside the block
             normformer=normformer,
+            post_norm=False,  # Handled outside the block
+            residual_path=False,  # Handled outside the block
+            checkpoint=checkpoint_ff,
         )

+        self.reset_parameters()
+
     @property
     def _kv_distance(self) -> float:
         return self.attn._kv_distance

     def forward(self, x):
-        if not self.training:
-            identity_probability = 0.0
-        else:
-            identity_probability = self.identity_probability
-
-        # perform the identity operation for some rows in the batch
-        identity_count = random.binomial(n=x.size(0), p=identity_probability)
-        shuffle_indices = torch.randperm(x.size(0), device=x.device)
-        unshuffle_indices = torch.argsort(shuffle_indices)
-        shuffled = x[shuffle_indices, :, :]
-        identity_x = shuffled[:identity_count, :, :]
-        process_x = shuffled[identity_count:, :, :]

         if self.pre_norm:
-
-
-
-            )
-
-
-
-
-
-
-
-
-
-
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = self.layer_norm_2(x)
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
+            if self.post_norm:  # i.e. in addition! Pre and post.
+                x = self.layer_norm_3(x)
+        elif self.post_norm:  # i.e. only, not prenorm, just post
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
+            x = self.layer_norm_2(x)
+        else:  # Not pre or post norm. Stand well back.
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))

         return x

+    def attention_logits(self, x):
+        """
+        Give back the attention scores used in this layer.
+        """
+        if self.pre_norm:
+            x = self.layer_norm_1(x)
+            return self.attn.attention_logits(x, x, x)
+        else:
+            return self.attn.attention_logits(x, x, x)
+
+    def reset_parameters(self):
+        self.layer_norm_1.reset_parameters()
+        self.layer_norm_2.reset_parameters()
+        self.layer_norm_3.reset_parameters()
+
+        self.attn.reset_parameters()
+        self.ff.reset_parameters()
+

 class TransformerEncoder(nn.Module):
     """
     This assumes we already get a sequence of embeddings (e.g. word or image
-    patch embeddings).
+    patch embeddings).
     """

     def __init__(
@@ -398,53 +652,93 @@ class TransformerEncoder(nn.Module):
         d_model,
         n_layers,
         n_heads,
-
+        absolute_position_embedding=True,
+        relative_position_embedding=False,
         source_size=None,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
-
+        ff_linear_module_up=None,
+        ff_linear_module_down=None,
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         stochastic_depth=0.0,
         causal=False,
         linear_module=nn.Linear,
-
-
+        utility_tokens=0,
+        talking_heads=False,
+        return_utility_tokens=False,
         pre_norm=True,
+        post_norm=False,
         normformer=False,
+        msa_scaling="d",
+        checkpoint_ff=True,
+        layerscale=True,
     ):
-
-
+        """
+        Args:
+            msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
+
+        if relative_position_embedding and (source_size is None):
+            raise ValueError(
+                "`source_size` for TransformerEncoder cannot be None if"
+                " `relative_position_embedding` is True"
+            )
+
+        if absolute_position_embedding and (seq_len is None):
+            raise ValueError(
+                "`seq_len` for TransformerEncoder cannot be None if"
+                " `absolute_position_embedding` is True"
+            )

         super().__init__()
+
+        if FLASH_ATTN and talking_heads:
+            warnings.warn(
+                "Using talking heads currently prevents using flash attention.",
+                stacklevel=2,
+            )
+
         self.seq_len = seq_len
         self.n_heads = n_heads
-        self.
-        self.
-
-        # Initialise
-        if self.
-            self.
-
-
+        self._utility_tokens = utility_tokens
+        self.return_utility_tokens = return_utility_tokens
+
+        # Initialise utility tokens with normal init, like usual Pytorch embeddings
+        if self._utility_tokens:
+            self._utility_token_embedding = nn.Parameter(
+                torch.empty(self._utility_tokens, d_model)
+            )
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
+        else:
+            self._utility_token_embedding = None
+
+        if self._utility_tokens and (self.seq_len is not None):
+            self.full_sequence_length = self.seq_len + self._utility_tokens
         else:
-            self._bos_embedding = None
             self.full_sequence_length = self.seq_len

         self.d_model = d_model

-
-
-        if self.position_embedding_type == "absolute":
+        if absolute_position_embedding:
             self.absolute_position_embedding = nn.Embedding(
                 self.full_sequence_length, d_model
             )
+        else:
+            self.absolute_position_embedding = None

-        self.mlp_dropout =
+        self.mlp_dropout = ff_dropout
         self.msa_dropout = msa_dropout
         self.stochastic_depth = stochastic_depth

-        assert isinstance(n_layers, int)
+        assert isinstance(n_layers, int)
+
         if n_layers == 1:
             self.stochastic_depth_probabilities = [0.0]
         else:
@@ -459,35 +753,48 @@ class TransformerEncoder(nn.Module):
                     self.full_sequence_length,
                     d_model,
                     n_heads,
-
+                    relative_position_embedding=relative_position_embedding,
                     source_size=source_size,
-
+                    utility_tokens=utility_tokens,
+                    talking_heads=talking_heads,
                     mlp_ratio=mlp_ratio,
                     activation=activation,
                     activation_kwargs=activation_kwargs,
-
+                    ff_linear_module_up=ff_linear_module_up,
+                    ff_linear_module_down=ff_linear_module_down,
+                    msa_scaling=msa_scaling,
+                    ff_dropout=ff_dropout,
+                    ff_inner_dropout=ff_inner_dropout,
+                    ff_outer_dropout=ff_outer_dropout,
                     msa_dropout=msa_dropout,
                     identity_probability=self.stochastic_depth_probabilities[i],
                     causal=causal,
                     linear_module=linear_module,
                     pre_norm=pre_norm,
+                    post_norm=post_norm,
                     normformer=normformer,
+                    checkpoint_ff=checkpoint_ff,
+                    layerscale=layerscale,
                 )
                 for i in range(n_layers)
             ]
         )

+        self.reset_parameters()
+
     @property
     def _kv_distances(self) -> float:
         return ",".join([str(block._kv_distance) for block in self.blocks])

-    def
-        if self.
-            x = torch.cat(
+    def preprocess(self, x):
+        if self._utility_tokens:
+            x = torch.cat(
+                [self._utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
+            )
         else:
             x = x

-        if self.
+        if self.absolute_position_embedding is not None:
             x = x + self.absolute_position_embedding(
                 torch.arange(
                     0, self.full_sequence_length, dtype=torch.long, device=x.device
@@ -496,10 +803,40 @@ class TransformerEncoder(nn.Module):
                )  # to shape (1, seq_len) to broadcast over batch
             )

+        return x
+
+    def forward(self, x):
+
+        x = self.preprocess(x)
+
         for block in self.blocks:
             x = block(x)

-        if self.
-            return x[:, self.
+        if self._utility_tokens and not self.return_utility_tokens:
+            return x[:, self._utility_tokens :, :]
         else:
             return x
+
+    def attention_logits(self, x):
+
+        x = self.preprocess(x)
+
+        layer_scores = []
+
+        for block in self.blocks:
+            # Get attention scores with shape (batch, 1, head, seq_len, seq_len)
+            layer_attention_logits = block.attention_logits(x).unsqueeze(1)
+            layer_scores.append(layer_attention_logits)
+            x = block(x)
+
+        return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)
+
+    def reset_parameters(self):
+        if self._utility_token_embedding is not None:
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
+
+        if self.absolute_position_embedding is not None:
+            self.absolute_position_embedding.reset_parameters()
+
+        for block in self.blocks:
+            block.reset_parameters()
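For readers skimming this diff, the sketch below exercises the main options that 10.0.1 adds to TransformerEncoder (utility_tokens, talking_heads, msa_scaling, post_norm, checkpoint_ff, layerscale) together with the new attention_logits helper. It is illustrative only: the keyword names are taken from the hunks above, everything is passed by keyword because the full positional signature is not shown in this diff, and the sizes are arbitrary.

import torch
from broccoli.transformer import TransformerEncoder

# Hypothetical configuration; only the keyword names come from the diff above.
encoder = TransformerEncoder(
    seq_len=196,                       # e.g. a 14 x 14 grid of patch embeddings
    d_model=256,
    n_layers=4,
    n_heads=8,
    absolute_position_embedding=True,  # flag replacing the old position_embedding_type
    relative_position_embedding=False,
    utility_tokens=1,                  # learnable tokens prepended to the sequence
    return_utility_tokens=False,       # strip them from the output again
    talking_heads=False,               # True disables the flash-attn fast path
    msa_scaling="d",                   # 8 / head_dim logit scaling (the new default)
    post_norm=False,
    checkpoint_ff=True,                # checkpoint the feedforward blocks
    layerscale=True,
)

x = torch.randn(2, 196, 256)           # (batch, tokens, features)
y = encoder(x)                         # utility tokens removed: (2, 196, 256)

# Pre-softmax attention scores stacked over layers:
# (batch, layer, head, tokens, tokens), where tokens includes the utility token.
scores = encoder.attention_logits(x)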