broccoli-ml 3.3.1__tar.gz → 5.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-3.3.1 → broccoli_ml-5.1.0}/PKG-INFO +1 -1
- {broccoli_ml-3.3.1 → broccoli_ml-5.1.0}/broccoli/transformer.py +229 -78
- {broccoli_ml-3.3.1 → broccoli_ml-5.1.0}/broccoli/vit.py +52 -0
- {broccoli_ml-3.3.1 → broccoli_ml-5.1.0}/pyproject.toml +1 -1
- {broccoli_ml-3.3.1 → broccoli_ml-5.1.0}/LICENSE +0 -0
- {broccoli_ml-3.3.1 → broccoli_ml-5.1.0}/README.md +0 -0
- {broccoli_ml-3.3.1 → broccoli_ml-5.1.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-3.3.1 → broccoli_ml-5.1.0}/broccoli/activation.py +0 -0
- {broccoli_ml-3.3.1 → broccoli_ml-5.1.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-3.3.1 → broccoli_ml-5.1.0}/broccoli/linear.py +0 -0
- {broccoli_ml-3.3.1 → broccoli_ml-5.1.0}/broccoli/rope.py +0 -0
- {broccoli_ml-3.3.1 → broccoli_ml-5.1.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-3.3.1 → broccoli_ml-5.1.0}/broccoli/utils.py +0 -0
--- a/broccoli/transformer.py
+++ b/broccoli/transformer.py

@@ -1,14 +1,23 @@
 import math
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 
 from einops import rearrange
 
 from .rope import RotaryEmbedding, apply_rotary_emb
 
+try:
+    from flash_attn import flash_attn_func
+
+    FLASH_ATTN = True
+except ImportError:
+    pass
+    FLASH_ATTN = False
+
 
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
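An editorial aside on the new optional dependency above: `FLASH_ATTN` is a module-level flag set once at import time, and `MHAttention.forward` (further down this diff) reads it at call time, so downstream code can inspect or override it. A hedged sketch, not part of the package:

```python
import broccoli.transformer as bt

print(bt.FLASH_ATTN)  # True only if `pip install flash-attn` succeeded

# Force the pure-PyTorch attention path (e.g. for CPU runs or debugging);
# since forward() reads the module flag at call time, this takes effect
# immediately for existing MHAttention instances.
bt.FLASH_ATTN = False
```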
@@ -93,6 +102,15 @@ class MHAttention(nn.Module):
 
         self.head_dim = self.embed_dim // self.n_heads
 
+        if self.scaling == "sqrtd":
+            self.scaling_factor = 1 / math.sqrt(self.head_dim)
+        elif self.scaling == "d":
+            # 8/d_model for backwards compatibility,
+            # per https://github.com/microsoft/mup
+            self.scaling_factor = 8 / self.head_dim
+        else:
+            raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
+
         self.q_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
         self.k_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
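A quick numeric check of the two scaling modes introduced above (an editorial example, not package code): the muP-style `"d"` scaling and the standard `"sqrtd"` scaling happen to coincide at a head width of 64 and diverge everywhere else.

```python
import math

for head_dim in (32, 64, 128):
    sqrtd = 1 / math.sqrt(head_dim)  # standard attention scaling
    mup = 8 / head_dim               # muP-style scaling
    print(head_dim, sqrtd, mup)      # identical (0.125) only at head_dim=64
```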
@@ -113,6 +131,8 @@ class MHAttention(nn.Module):
         self.source_size = source_size
         self.bos_tokens = bos_tokens
 
+        self.reset_parameters()
+
     @property
     def _kv_distance(self) -> float:
         """
@@ -132,7 +152,71 @@ class MHAttention(nn.Module):
 
         return 1 - similarity
 
-    def
+    def add_axial_rope(
+        self, q: torch.Tensor, k: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Apply Axial RoPE to all tokens except BOS tokens
+        """
+
+        if len(self.source_size) == 1:
+            spatial_dimension_names = "D1"
+            spatial_dimension_values = {"D1": self.source_size[0]}
+        elif len(self.source_size) == 2:
+            spatial_dimension_names = "D1 D2"
+            spatial_dimension_values = {
+                "D1": self.source_size[0],
+                "D2": self.source_size[1],
+            }
+        elif len(self.source_size) == 3:
+            spatial_dimension_names = "D1 D2 D3"
+            spatial_dimension_values = {
+                "D1": self.source_size[0],
+                "D2": self.source_size[1],
+                "D3": self.source_size[2],
+            }
+        else:
+            raise NotImplementedError(
+                "`source_size` must be a tuple of 1, 2 or 3 integers"
+            )
+
+        q_bos, q_img = q[:, : self.bos_tokens, :], q[:, self.bos_tokens :, :]
+        k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
+
+        q_img = rearrange(
+            q_img,
+            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            **spatial_dimension_values,
+        )
+        k_img = rearrange(
+            k_img,
+            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            **spatial_dimension_values,
+        )
+
+        freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
+
+        q_img = apply_rotary_emb(freqs, q_img)
+        k_img = apply_rotary_emb(freqs, k_img)
+
+        q_img = rearrange(
+            q_img,
+            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+        )
+        k_img = rearrange(
+            k_img,
+            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+        )
+
+        # Re-combine the BOS tokens and the RoPE-enhanced image tokens
+        q = torch.cat([q_bos, q_img], dim=1)
+        k = torch.cat([k_bos, k_img], dim=1)
+
+        return q, k
+
+    def project_qkv(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_batch_size, query_tokens, query_features = q.size()
         key_batch_size, key_tokens, key_features = k.size()
 
@@ -145,66 +229,64 @@ class MHAttention(nn.Module):
 
         if self.causal:
             assert query_tokens == key_tokens
-            assert query_tokens == self.
+            assert query_tokens == self.seq_len
 
-
-        q = self.q_proj(q)
-        k = self.k_proj(k)
-        v = self.v_proj(v)
+        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
 
-        # Rearrange dimensions and add RoPE if needed
         if self.rotary_embedding is not None:
+            q, k = self.add_axial_rope(q, k)
 
-            if len(self.source_size) == 1:
-                spatial_dimension_names = "D1"
-                spatial_dimension_values = {"D1": self.source_size[0]}
-            elif len(self.source_size) == 2:
-                spatial_dimension_names = "D1 D2"
-                spatial_dimension_values = {
-                    "D1": self.source_size[0],
-                    "D2": self.source_size[1],
-                }
-            elif len(self.source_size) == 3:
-                spatial_dimension_names = "D1 D2 D3"
-                spatial_dimension_values = {
-                    "D1": self.source_size[0],
-                    "D2": self.source_size[1],
-                    "D3": self.source_size[2],
-                }
-            else:
-                raise NotImplementedError(
-                    "`source_size` must be a tuple of 1, 2 or 3 integers"
-                )
+        return q, k, v
 
-            q_bos, q_img = q[:, : self.bos_tokens, :], q[:, self.bos_tokens :, :]
-            k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
+    def forward(self, q, k, v):
 
-            q_img = rearrange(
-                q_img,
-                f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
-                **spatial_dimension_values,
-            )
-            k_img = rearrange(
-                k_img,
-                f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
-                **spatial_dimension_values,
+        q, k, v = self.project_qkv(q, k, v)
+
+        if FLASH_ATTN:
+            # Divide Q/K/V into heads
+            q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+            k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+            v = rearrange(v, "b t (h d) -> b t h d", h=self.n_heads)
+
+            output_with_heads = flash_attn_func(
+                q,
+                k,
+                v,
+                dropout_p=self.dropout.p if self.training else 0.0,
+                softmax_scale=self.scaling_factor,
+                causal=self.causal,
             )
-            freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
-            q_img = apply_rotary_emb(freqs, q_img)
-            k_img = apply_rotary_emb(freqs, k_img)
 
-            q_img = rearrange(
-                q_img,
-                f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
-            )
-            k_img = rearrange(
-                k_img,
-                f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
-            )
+            output_without_heads = rearrange(output_with_heads, "b t h d -> b t (h d)")
+
+            return self.out_proj(output_without_heads)
+        else:
+            # Divide Q/K/V into heads
+            q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
+            k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
+            v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)
+
+            qk_scores = q @ k.transpose(-1, -2)
+
+            qk_scores *= self.scaling_factor
+
+            # Apply mask if causal (must come before softmax)
+            if self.causal:
+                qk_scores.masked_fill_(self.mask, float("-inf"))
+
+            qk_scores = F.softmax(qk_scores, dim=-1)
+
+            qk_scores = self.dropout(qk_scores)
 
-            # Re-combine the BOS tokens and the RoPE-enhanced image tokens
-            q = torch.cat([q_bos, q_img], dim=1)
-            k = torch.cat([k_bos, k_img], dim=1)
+            output_with_heads = qk_scores @ v
+
+            output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")
+
+            return self.out_proj(output_without_heads)
+
+    def attention_scores(self, q, k, v):
+
+        q, k, v = self.project_qkv(q, k, v)
 
         # Divide Q/K/V into heads
         q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
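One practical caveat on the `FLASH_ATTN` branch of `forward` above (an editorial note, with hypothetical names): flash-attn's `flash_attn_func` only accepts fp16/bf16 CUDA tensors, so a model using this path would typically run under autocast, while the pure-PyTorch `else` branch also serves as the CPU/fp32 fallback.

```python
import torch

# `model` and `x` stand in for any module wrapping MHAttention and its
# CUDA-resident input batch; neither is defined in the diff itself.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(x)  # takes the flash path, assuming flash-attn is installed
```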
@@ -213,13 +295,7 @@ class MHAttention(nn.Module):
 
         qk_scores = q @ k.transpose(-1, -2)
 
-        if self.scaling == "sqrtd":
-            qk_scores /= math.sqrt(self.head_dim)
-        elif self.scaling == "d":
-            # for backwards compatibility, per https://github.com/microsoft/mup
-            qk_scores *= 8 / self.head_dim
-        else:
-            raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
+        qk_scores *= self.scaling_factor
 
         # Apply mask if causal (must come before softmax)
         if self.causal:
@@ -227,11 +303,14 @@ class MHAttention(nn.Module):
 
         qk_scores = F.softmax(qk_scores, dim=-1)
 
-        return qk_scores
+        return qk_scores  # (batch, head, seq_len, seq_len)
 
-
-
-
+    def reset_parameters(self):
+        # Default nn.Linear init is kaiming_uniform, which is fine
+        self.q_proj.reset_parameters()
+        self.k_proj.reset_parameters()
+        self.v_proj.reset_parameters()
+        self.out_proj.reset_parameters()
 
 
 class FeedforwardBlock(nn.Module):
@@ -253,9 +332,11 @@ class FeedforwardBlock(nn.Module):
         normformer=False,
         post_norm=True,
         residual_path=True,
+        checkpoint=True,
     ):
         super().__init__()
 
+        self.checkpoint = checkpoint
         self.residual_path = residual_path
         self.post_norm = post_norm
 
@@ -293,13 +374,30 @@ class FeedforwardBlock(nn.Module):
             ]
         )
 
+        self.reset_parameters()
+
     def forward(self, x):
+
+        if self.checkpoint:
+            processed = checkpoint(self.process, x, use_reentrant=False)
+        else:
+            processed = self.process(x)
+
         if self.residual_path and self.post_norm:
-            return self.layernorm(x + self.process(x))
+            return self.layernorm(x + processed)
         elif self.residual_path:
-            return x + self.process(x)
+            return x + processed
         else:
-            return self.process(x)
+            return processed
+
+    def reset_parameters(self):
+        if self.post_norm:
+            self.layernorm.reset_parameters()
+
+        # Iterate over the sequential block to reset parameters
+        for module in self.process:
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()
 
 
 class TransformerBlock(nn.Module):
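For readers unfamiliar with the new `checkpoint=True` path above: `torch.utils.checkpoint.checkpoint` trades compute for memory by discarding intermediate activations during the forward pass and recomputing them during backward. A free-standing sketch of the trade-off (hypothetical shapes, not broccoli code):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

fn = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(8, 64, requires_grad=True)

y_direct = fn(x)                                 # stores activations for backward
y_ckpt = checkpoint(fn, x, use_reentrant=False)  # recomputes them during backward

assert torch.allclose(y_direct, y_ckpt)          # same result, lower peak memory
```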
@@ -333,6 +431,7 @@ class TransformerBlock(nn.Module):
         pre_norm=True,
         post_norm=False,
         normformer=False,
+        checkpoint_ff=True,
     ):
         """
         Args:
@@ -401,8 +500,11 @@ class TransformerBlock(nn.Module):
             normformer=normformer,
             post_norm=False,  # Handled outside the block
             residual_path=False,  # Handled outside the block
+            checkpoint=checkpoint_ff,
         )
 
+        self.reset_parameters()
+
     @property
     def _kv_distance(self) -> float:
         return self.attn._kv_distance
@@ -410,29 +512,46 @@ class TransformerBlock(nn.Module):
     def forward(self, x):
 
         if self.pre_norm:
-
-            x = x + self.drop_path(self.attn(
-
-            x = x + self.drop_path(self.ff(
-        elif self.post_norm:
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.attn(x, x, x))
+            x = self.layer_norm_2(x)
+            x = x + self.drop_path(self.ff(x))
+            if self.post_norm:  # i.e. in addition! Pre and post.
+                x = self.layer_norm_3(x)
+        elif self.post_norm:  # i.e. only, not prenorm, just post
             x = x + self.drop_path(self.attn(x, x, x))
             x = self.layer_norm_1(x)
             x = x + self.drop_path(self.ff(x))
             x = self.layer_norm_2(x)
-        else:
+        else:  # Not pre or post norm. Stand well back.
             x = x + self.drop_path(self.attn(x, x, x))
             x = x + self.drop_path(self.ff(x))
 
-        if self.pre_norm and self.post_norm:
-            x = self.layer_norm_3(x)
-
         return x
 
+    def attention_scores(self, x):
+        """
+        Give back the attention scores used in this layer.
+        """
+        if self.pre_norm:
+            x = self.layer_norm_1(x)
+            return self.attn(x, x, x)
+        else:
+            return self.attn(x, x, x)
+
+    def reset_parameters(self):
+        self.layer_norm_1.reset_parameters()
+        self.layer_norm_2.reset_parameters()
+        self.layer_norm_3.reset_parameters()
+
+        self.attn.reset_parameters()
+        self.ff.reset_parameters()
+
 
 class TransformerEncoder(nn.Module):
     """
     This assumes we already get a sequence of embeddings (e.g. word or image
-    patch embeddings).
+    patch embeddings).
     """
 
     def __init__(
@@ -460,6 +579,7 @@ class TransformerEncoder(nn.Module):
         post_norm=False,
         normformer=False,
         msa_scaling="d",
+        checkpoint_ff=True,
     ):
         """
         Args:
@@ -536,16 +656,19 @@ class TransformerEncoder(nn.Module):
                     pre_norm=pre_norm,
                     post_norm=post_norm,
                     normformer=normformer,
+                    checkpoint_ff=checkpoint_ff,
                 )
                 for i in range(n_layers)
             ]
         )
 
+        self.reset_parameters()
+
     @property
     def _kv_distances(self) -> float:
         return ",".join([str(block._kv_distance) for block in self.blocks])
 
-    def forward(self, x):
+    def preprocess(self, x):
         if self._bos_tokens:
             x = torch.cat([self._bos_embedding.expand(x.size(0), -1, -1), x], dim=1)
         else:
@@ -560,6 +683,10 @@ class TransformerEncoder(nn.Module):
                 )  # to shape (1, seq_len) to broadcast over batch
             )
 
+    def forward(self, x):
+
+        x = self.preprocess(x)
+
         for block in self.blocks:
             x = block(x)
 
@@ -567,3 +694,27 @@ class TransformerEncoder(nn.Module):
             return x[:, self._bos_tokens :, :]
         else:
             return x
+
+    def attention_scores(self, x):
+
+        x = self.preprocess(x)
+
+        layer_scores = []
+
+        for block in self.blocks:
+            # Get attention scores with shape (batch, 1, head, seq_len, seq_len)
+            layer_attention_scores = block.attention_scores(x).unsqueeze(1)
+            layer_scores.append(layer_attention_scores)
+            x = block(x)
+
+        return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)
+
+    def reset_parameters(self):
+        if self._bos_embedding is not None:
+            nn.init.normal_(self._bos_embedding, mean=0.0, std=1.0)
+
+        if self.absolute_position_embedding is not None:
+            self.absolute_position_embedding.reset_parameters()
+
+        for block in self.blocks:
+            block.reset_parameters()
--- a/broccoli/vit.py
+++ b/broccoli/vit.py

@@ -9,7 +9,9 @@ from .utils import PadTensor
 from einops import einsum
 from einops.layers.torch import Rearrange
 
+import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 class GetCLSToken(nn.Module):
@@ -31,10 +33,18 @@ class SequencePool(nn.Module):
             ]
         )
 
+        self.reset_parameters()
+
     def forward(self, x):
        weights = self.attention(x)
        return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")
 
+    def reset_parameters(self):
+        # Iterate over modules in the sequential block
+        for module in self.attention:
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()
+
 
 class ClassificationHead(nn.Module):
     """
@@ -71,9 +81,16 @@ class ClassificationHead(nn.Module):
             ]
         )
 
+        self.reset_parameters()
+
     def forward(self, x):
         return self.classification_process(x)
 
+    def reset_parameters(self):
+        for module in self.classification_process:
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()
+
 
 class SequencePoolClassificationHead(ClassificationHead):
     """
@@ -106,6 +123,8 @@ class SequencePoolClassificationHead(ClassificationHead):
             ]
         )
 
+        self.reset_parameters()
+
 
 class ViTEncoder(nn.Module):
     """
@@ -160,6 +179,7 @@ class ViTEncoder(nn.Module):
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
+        transformer_checkpoint_ff=True,
         linear_module=nn.Linear,
     ):
         super().__init__()
@@ -321,6 +341,7 @@ class ViTEncoder(nn.Module):
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
+                checkpoint_ff=transformer_checkpoint_ff,
             )
         else:
             self.transformer = nn.Identity()
@@ -354,6 +375,7 @@ class ViTEncoder(nn.Module):
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 residual_path=transformer_initial_ff_residual_path,
+                checkpoint=transformer_checkpoint_ff,
             )
         else:
             self.initial_ff = nn.Identity()
@@ -373,9 +395,20 @@ class ViTEncoder(nn.Module):
             ]
         )
 
+        self.reset_parameters()
+
     def forward(self, x):
         return self.encoder(x)
 
+    def attention_scores(self, x):
+        x = self.encoder[:-1](x)
+        return self.encoder[-1].attention_scores(x)
+
+    def reset_parameters(self):
+        for module in self.encoder:
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()
+
 
 class ViT(nn.Module):
     """
@@ -426,6 +459,7 @@ class ViT(nn.Module):
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
+        transformer_checkpoint_ff=True,
         head=SequencePoolClassificationHead,
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
@@ -492,6 +526,7 @@ class ViT(nn.Module):
             transformer_mlp_dropout=transformer_mlp_dropout,
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
+            transformer_checkpoint_ff=transformer_checkpoint_ff,
             linear_module=linear_module,
         )
 
@@ -502,9 +537,26 @@ class ViT(nn.Module):
             batch_norm_logits=batch_norm_logits,
         )
 
+        self.reset_parameters()
+
     @property
     def sequence_length(self):
         return self.encoder.sequence_length
 
     def forward(self, x):
         return self.pool(self.encoder(x))
+
+    def attention_scores(self, x):
+        return self.encoder.attention_scores(x)
+
+    def head_to_bos_token_attention(self, x):
+        all_attention = self.attention_scores(x)
+        batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
+        sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
+        n_bos_tokens = self.encoder.encoder._bos_tokens
+        just_bos = sequence_averages[:, :, :n_bos_tokens]
+        return F.softmax(just_bos, dim=-1)  # (layer, head, bos_token)
+
+    def reset_parameters(self):
+        self.encoder.reset_parameters()
+        self.pool.reset_parameters()
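To close, a hedged end-to-end sketch of the new diagnostics surface on `ViT` (constructor arguments and input sizes are illustrative assumptions; consult the package for the real signature):

```python
import torch
from broccoli.vit import ViT

model = ViT()                          # assuming the defaults yield a working model
images = torch.randn(4, 3, 224, 224)   # hypothetical input batch

logits = model(images)
scores = model.attention_scores(images)               # (batch, layer, head, seq, seq)
bos_attn = model.head_to_bos_token_attention(images)  # (layer, head, bos_token)
```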