broccoli-ml 4.0.1-py3-none-any.whl → 5.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +194 -74
- broccoli/vit.py +52 -0
- {broccoli_ml-4.0.1.dist-info → broccoli_ml-5.1.0.dist-info}/METADATA +1 -1
- {broccoli_ml-4.0.1.dist-info → broccoli_ml-5.1.0.dist-info}/RECORD +6 -6
- {broccoli_ml-4.0.1.dist-info → broccoli_ml-5.1.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-4.0.1.dist-info → broccoli_ml-5.1.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
@@ -1,5 +1,5 @@
 import math
-from typing import Optional
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
@@ -102,6 +102,15 @@ class MHAttention(nn.Module):

         self.head_dim = self.embed_dim // self.n_heads

+        if self.scaling == "sqrtd":
+            self.scaling_factor = 1 / math.sqrt(self.head_dim)
+        elif self.scaling == "d":
+            # 8/d_model for backwards compatibility,
+            # per https://github.com/microsoft/mup
+            self.scaling_factor = 8 / self.head_dim
+        else:
+            raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
+
         self.q_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
         self.k_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
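Note: the attention scale is now computed once in the constructor and reused by both the flash-attention and manual paths below. A standalone sketch of the two options (not from the package; the head size of 64 is a hypothetical value): "d" scaling follows muP's 1/d rule, and the factor of 8 makes it coincide with the standard 1/sqrt(d) at d = 64.

    import math

    head_dim = 64  # hypothetical head size

    sqrtd_scale = 1 / math.sqrt(head_dim)  # standard attention scaling
    d_scale = 8 / head_dim                 # muP-style 1/d scaling

    assert sqrtd_scale == d_scale == 0.125  # the two coincide at d = 64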
@@ -122,6 +131,8 @@ class MHAttention(nn.Module):
         self.source_size = source_size
         self.bos_tokens = bos_tokens

+        self.reset_parameters()
+
     @property
     def _kv_distance(self) -> float:
         """
@@ -141,7 +152,71 @@ class MHAttention(nn.Module):

         return 1 - similarity

-    def forward(self, q, k, v):
+    def add_axial_rope(
+        self, q: torch.Tensor, k: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Apply Axial RoPE to all tokens except BOS tokens
+        """
+
+        if len(self.source_size) == 1:
+            spatial_dimension_names = "D1"
+            spatial_dimension_values = {"D1": self.source_size[0]}
+        elif len(self.source_size) == 2:
+            spatial_dimension_names = "D1 D2"
+            spatial_dimension_values = {
+                "D1": self.source_size[0],
+                "D2": self.source_size[1],
+            }
+        elif len(self.source_size) == 3:
+            spatial_dimension_names = "D1 D2 D3"
+            spatial_dimension_values = {
+                "D1": self.source_size[0],
+                "D2": self.source_size[1],
+                "D3": self.source_size[2],
+            }
+        else:
+            raise NotImplementedError(
+                "`source_size` must be a tuple of 1, 2 or 3 integers"
+            )
+
+        q_bos, q_img = q[:, : self.bos_tokens, :], q[:, self.bos_tokens :, :]
+        k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
+
+        q_img = rearrange(
+            q_img,
+            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            **spatial_dimension_values,
+        )
+        k_img = rearrange(
+            k_img,
+            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            **spatial_dimension_values,
+        )
+
+        freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
+
+        q_img = apply_rotary_emb(freqs, q_img)
+        k_img = apply_rotary_emb(freqs, k_img)
+
+        q_img = rearrange(
+            q_img,
+            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+        )
+        k_img = rearrange(
+            k_img,
+            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+        )
+
+        # Re-combine the BOS tokens and the RoPE-enhanced image tokens
+        q = torch.cat([q_bos, q_img], dim=1)
+        k = torch.cat([k_bos, k_img], dim=1)
+
+        return q, k
+
+    def project_qkv(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_batch_size, query_tokens, query_features = q.size()
         key_batch_size, key_tokens, key_features = k.size()

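Note: `add_axial_rope` factors the rotary-embedding logic out of `forward`. The core trick is an einops round-trip: unflatten the image tokens into their spatial grid, apply per-axis rotary frequencies, then flatten again. A minimal standalone sketch of that round-trip for a 2D grid (all sizes hypothetical; the rotary application itself is elided):

    import torch
    from einops import rearrange

    bos_tokens, H, W, d = 1, 4, 4, 8  # hypothetical sizes
    q = torch.randn(2, bos_tokens + H * W, d)

    # BOS tokens carry no spatial position, so they are split off first
    q_bos, q_img = q[:, :bos_tokens, :], q[:, bos_tokens:, :]

    # Unflatten the token sequence into its 2D grid...
    q_img = rearrange(q_img, "b (D1 D2) d -> b D1 D2 d", D1=H, D2=W)

    # ...axial rotary frequencies would be applied here, one axis at a time...

    # ...then flatten back and re-attach the BOS tokens
    q_img = rearrange(q_img, "b D1 D2 d -> b (D1 D2) d")
    q = torch.cat([q_bos, q_img], dim=1)
    assert q.shape == (2, bos_tokens + H * W, d)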
@@ -154,74 +229,18 @@ class MHAttention(nn.Module):

         if self.causal:
             assert query_tokens == key_tokens
-            assert query_tokens == self.
+            assert query_tokens == self.seq_len

-
-        q = self.q_proj(q)
-        k = self.k_proj(k)
-        v = self.v_proj(v)
+        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)

-        # Rearrange dimensions and add RoPE if needed
         if self.rotary_embedding is not None:
+            q, k = self.add_axial_rope(q, k)

-            if len(self.source_size) == 1:
-                spatial_dimension_names = "D1"
-                spatial_dimension_values = {"D1": self.source_size[0]}
-            elif len(self.source_size) == 2:
-                spatial_dimension_names = "D1 D2"
-                spatial_dimension_values = {
-                    "D1": self.source_size[0],
-                    "D2": self.source_size[1],
-                }
-            elif len(self.source_size) == 3:
-                spatial_dimension_names = "D1 D2 D3"
-                spatial_dimension_values = {
-                    "D1": self.source_size[0],
-                    "D2": self.source_size[1],
-                    "D3": self.source_size[2],
-                }
-            else:
-                raise NotImplementedError(
-                    "`source_size` must be a tuple of 1, 2 or 3 integers"
-                )
-
-            q_bos, q_img = q[:, : self.bos_tokens, :], q[:, self.bos_tokens :, :]
-            k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
-
-            q_img = rearrange(
-                q_img,
-                f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
-                **spatial_dimension_values,
-            )
-            k_img = rearrange(
-                k_img,
-                f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
-                **spatial_dimension_values,
-            )
-            freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
-            q_img = apply_rotary_emb(freqs, q_img)
-            k_img = apply_rotary_emb(freqs, k_img)
-
-            q_img = rearrange(
-                q_img,
-                f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
-            )
-            k_img = rearrange(
-                k_img,
-                f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
-            )
+        return q, k, v

-
-            q = torch.cat([q_bos, q_img], dim=1)
-            k = torch.cat([k_bos, k_img], dim=1)
+    def forward(self, q, k, v):

-        if self.scaling == "sqrtd":
-            scaling_factor = 1 / math.sqrt(self.head_dim)
-        elif self.scaling == "d":
-            # for backwards compatibility, per https://github.com/microsoft/mup
-            scaling_factor = 8 / self.head_dim
-        else:
-            raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
+        q, k, v = self.project_qkv(q, k, v)

         if FLASH_ATTN:
             # Divide Q/K/V into heads
@@ -234,7 +253,7 @@ class MHAttention(nn.Module):
                 k,
                 v,
                 dropout_p=self.dropout.p if self.training else 0.0,
-                softmax_scale=scaling_factor,
+                softmax_scale=self.scaling_factor,
                 causal=self.causal,
             )

@@ -249,7 +268,7 @@ class MHAttention(nn.Module):

         qk_scores = q @ k.transpose(-1, -2)

-        qk_scores *= scaling_factor
+        qk_scores *= self.scaling_factor

         # Apply mask if causal (must come before softmax)
         if self.causal:
@@ -265,6 +284,34 @@ class MHAttention(nn.Module):

         return self.out_proj(output_without_heads)

+    def attention_scores(self, q, k, v):
+
+        q, k, v = self.project_qkv(q, k, v)
+
+        # Divide Q/K/V into heads
+        q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
+        k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
+        v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)
+
+        qk_scores = q @ k.transpose(-1, -2)
+
+        qk_scores *= self.scaling_factor
+
+        # Apply mask if causal (must come before softmax)
+        if self.causal:
+            qk_scores.masked_fill_(self.mask, float("-inf"))
+
+        qk_scores = F.softmax(qk_scores, dim=-1)
+
+        return qk_scores  # (batch, head, seq_len, seq_len)
+
+    def reset_parameters(self):
+        # Default nn.Linear init is kaiming_uniform, which is fine
+        self.q_proj.reset_parameters()
+        self.k_proj.reset_parameters()
+        self.v_proj.reset_parameters()
+        self.out_proj.reset_parameters()
+

 class FeedforwardBlock(nn.Module):
     """
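Note: `attention_scores` re-runs the projection and masking steps of `forward` but stops after the softmax, returning the per-head attention maps instead of the attended values. A hedged usage sketch, where `attn` stands for an already-constructed, non-causal `MHAttention` (its constructor arguments are not reproduced here, and the sizes are hypothetical):

    import torch

    x = torch.randn(2, 16, 64)  # hypothetical (batch, tokens, embed_dim) input

    scores = attn.attention_scores(x, x, x)  # self-attention
    assert scores.shape == (2, attn.n_heads, 16, 16)

    # After the softmax, each query row is a distribution over key positions
    assert torch.allclose(scores.sum(dim=-1), torch.ones(2, attn.n_heads, 16))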
@@ -285,9 +332,11 @@ class FeedforwardBlock(nn.Module):
         normformer=False,
         post_norm=True,
         residual_path=True,
+        checkpoint=True,
     ):
         super().__init__()

+        self.checkpoint = checkpoint
         self.residual_path = residual_path
         self.post_norm = post_norm

@@ -325,13 +374,30 @@ class FeedforwardBlock(nn.Module):
             ]
         )

+        self.reset_parameters()
+
     def forward(self, x):
+
+        if self.checkpoint:
+            processed = checkpoint(self.process, x, use_reentrant=False)
+        else:
+            processed = self.process(x)
+
         if self.residual_path and self.post_norm:
-            return self.layernorm(x + self.process(x))
+            return self.layernorm(x + processed)
         elif self.residual_path:
-            return x + self.process(x)
+            return x + processed
         else:
-            return self.process(x)
+            return processed
+
+    def reset_parameters(self):
+        if self.post_norm:
+            self.layernorm.reset_parameters()
+
+        # Iterate over the sequential block to reset parameters
+        for module in self.process:
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()


 class TransformerBlock(nn.Module):
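Note: with `checkpoint=True`, the feedforward body runs under `torch.utils.checkpoint`, which discards intermediate activations during the forward pass and recomputes them during backward, trading extra compute for lower memory. A self-contained sketch of the mechanism (the module and sizes are hypothetical, not from the package):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    mlp = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
    x = torch.randn(8, 64, requires_grad=True)

    # Activations inside `mlp` are recomputed on backward rather than stored
    y = checkpoint(mlp, x, use_reentrant=False)
    y.sum().backward()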
@@ -365,6 +431,7 @@ class TransformerBlock(nn.Module):
         pre_norm=True,
         post_norm=False,
         normformer=False,
+        checkpoint_ff=True,
     ):
         """
         Args:
@@ -433,8 +500,11 @@ class TransformerBlock(nn.Module):
             normformer=normformer,
             post_norm=False,  # Handled outside the block
            residual_path=False,  # Handled outside the block
+            checkpoint=checkpoint_ff,
         )

+        self.reset_parameters()
+
     @property
     def _kv_distance(self) -> float:
         return self.attn._kv_distance
@@ -445,25 +515,43 @@ class TransformerBlock(nn.Module):
             x = self.layer_norm_1(x)
             x = x + self.drop_path(self.attn(x, x, x))
             x = self.layer_norm_2(x)
-            x = x + self.drop_path(
+            x = x + self.drop_path(self.ff(x))
             if self.post_norm:  # i.e. in addition! Pre and post.
                 x = self.layer_norm_3(x)
         elif self.post_norm:  # i.e. only, not prenorm, just post
             x = x + self.drop_path(self.attn(x, x, x))
             x = self.layer_norm_1(x)
-            x = x + self.drop_path(
+            x = x + self.drop_path(self.ff(x))
             x = self.layer_norm_2(x)
         else:  # Not pre or post norm. Stand well back.
             x = x + self.drop_path(self.attn(x, x, x))
-            x = x + self.drop_path(
+            x = x + self.drop_path(self.ff(x))

         return x

+    def attention_scores(self, x):
+        """
+        Give back the attention scores used in this layer.
+        """
+        if self.pre_norm:
+            x = self.layer_norm_1(x)
+            return self.attn(x, x, x)
+        else:
+            return self.attn(x, x, x)
+
+    def reset_parameters(self):
+        self.layer_norm_1.reset_parameters()
+        self.layer_norm_2.reset_parameters()
+        self.layer_norm_3.reset_parameters()
+
+        self.attn.reset_parameters()
+        self.ff.reset_parameters()
+

 class TransformerEncoder(nn.Module):
     """
     This assumes we already get a sequence of embeddings (e.g. word or image
-    patch embeddings).
+    patch embeddings).
     """

     def __init__(
@@ -491,6 +579,7 @@ class TransformerEncoder(nn.Module):
         post_norm=False,
         normformer=False,
         msa_scaling="d",
+        checkpoint_ff=True,
     ):
         """
         Args:
@@ -567,16 +656,19 @@ class TransformerEncoder(nn.Module):
                     pre_norm=pre_norm,
                    post_norm=post_norm,
                    normformer=normformer,
+                    checkpoint_ff=checkpoint_ff,
                )
                for i in range(n_layers)
            ]
        )

+        self.reset_parameters()
+
     @property
     def _kv_distances(self) -> float:
         return ",".join([str(block._kv_distance) for block in self.blocks])

-    def forward(self, x):
+    def preprocess(self, x):
         if self._bos_tokens:
             x = torch.cat([self._bos_embedding.expand(x.size(0), -1, -1), x], dim=1)
         else:
@@ -591,6 +683,10 @@ class TransformerEncoder(nn.Module):
             )  # to shape (1, seq_len) to broadcast over batch
         )

+    def forward(self, x):
+
+        x = self.preprocess(x)
+
         for block in self.blocks:
             x = block(x)

@@ -598,3 +694,27 @@ class TransformerEncoder(nn.Module):
             return x[:, self._bos_tokens :, :]
         else:
             return x
+
+    def attention_scores(self, x):
+
+        x = self.preprocess(x)
+
+        layer_scores = []
+
+        for block in self.blocks:
+            # Get attention scores with shape (batch, 1, head, seq_len, seq_len)
+            layer_attention_scores = block.attention_scores(x).unsqueeze(1)
+            layer_scores.append(layer_attention_scores)
+            x = block(x)
+
+        return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)
+
+    def reset_parameters(self):
+        if self._bos_embedding is not None:
+            nn.init.normal_(self._bos_embedding, mean=0.0, std=1.0)
+
+        if self.absolute_position_embedding is not None:
+            self.absolute_position_embedding.reset_parameters()
+
+        for block in self.blocks:
+            block.reset_parameters()
broccoli/vit.py
CHANGED
@@ -9,7 +9,9 @@ from .utils import PadTensor
 from einops import einsum
 from einops.layers.torch import Rearrange

+import torch
 import torch.nn as nn
+import torch.nn.functional as F


 class GetCLSToken(nn.Module):
@@ -31,10 +33,18 @@ class SequencePool(nn.Module):
             ]
         )

+        self.reset_parameters()
+
     def forward(self, x):
         weights = self.attention(x)
         return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")

+    def reset_parameters(self):
+        # Iterate over modules in the sequential block
+        for module in self.attention:
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()
+

 class ClassificationHead(nn.Module):
     """
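Note: this `reset_parameters` pattern recurs throughout the release: iterate over a container and duck-type on `reset_parameters`, so parameterised children (e.g. `nn.Linear`) are re-initialised while parameter-free ones (e.g. activations) are skipped. In isolation (a hypothetical block, not from the package):

    import torch.nn as nn

    block = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
    for module in block:
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()  # nn.ReLU is skipped by the hasattr check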
@@ -71,9 +81,16 @@ class ClassificationHead(nn.Module):
             ]
         )

+        self.reset_parameters()
+
     def forward(self, x):
         return self.classification_process(x)

+    def reset_parameters(self):
+        for module in self.classification_process:
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()
+

 class SequencePoolClassificationHead(ClassificationHead):
     """
@@ -106,6 +123,8 @@ class SequencePoolClassificationHead(ClassificationHead):
             ]
         )

+        self.reset_parameters()
+

 class ViTEncoder(nn.Module):
     """
@@ -160,6 +179,7 @@ class ViTEncoder(nn.Module):
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
+        transformer_checkpoint_ff=True,
         linear_module=nn.Linear,
     ):
         super().__init__()
@@ -321,6 +341,7 @@ class ViTEncoder(nn.Module):
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
+                checkpoint_ff=transformer_checkpoint_ff,
             )
         else:
             self.transformer = nn.Identity()
@@ -354,6 +375,7 @@ class ViTEncoder(nn.Module):
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 residual_path=transformer_initial_ff_residual_path,
+                checkpoint=transformer_checkpoint_ff,
             )
         else:
             self.initial_ff = nn.Identity()
@@ -373,9 +395,20 @@ class ViTEncoder(nn.Module):
             ]
         )

+        self.reset_parameters()
+
     def forward(self, x):
         return self.encoder(x)

+    def attention_scores(self, x):
+        x = self.encoder[:-1](x)
+        return self.encoder[-1].attention_scores(x)
+
+    def reset_parameters(self):
+        for module in self.encoder:
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()
+

 class ViT(nn.Module):
     """
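Note: `ViTEncoder.attention_scores` relies on `nn.Sequential` supporting slicing: `self.encoder[:-1]` is itself a callable `Sequential` covering every stage before the final transformer, whose output is handed to `self.encoder[-1].attention_scores`. The same idea on a hypothetical three-stage pipeline:

    import torch
    import torch.nn as nn

    encoder = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))
    x = torch.randn(2, 8)

    stem, last = encoder[:-1], encoder[-1]  # slicing returns a new Sequential
    assert torch.equal(last(stem(x)), encoder(x))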
@@ -426,6 +459,7 @@ class ViT(nn.Module):
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
+        transformer_checkpoint_ff=True,
         head=SequencePoolClassificationHead,
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
@@ -492,6 +526,7 @@ class ViT(nn.Module):
             transformer_mlp_dropout=transformer_mlp_dropout,
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
+            transformer_checkpoint_ff=transformer_checkpoint_ff,
             linear_module=linear_module,
         )

@@ -502,9 +537,26 @@ class ViT(nn.Module):
             batch_norm_logits=batch_norm_logits,
         )

+        self.reset_parameters()
+
     @property
     def sequence_length(self):
         return self.encoder.sequence_length

     def forward(self, x):
         return self.pool(self.encoder(x))
+
+    def attention_scores(self, x):
+        return self.encoder.attention_scores(x)
+
+    def head_to_bos_token_attention(self, x):
+        all_attention = self.attention_scores(x)
+        batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
+        sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
+        n_bos_tokens = self.encoder.encoder._bos_tokens
+        just_bos = sequence_averages[:, :, :n_bos_tokens]
+        return F.softmax(just_bos, dim=-1)  # (layer, head, bos_token)
+
+    def reset_parameters(self):
+        self.encoder.reset_parameters()
+        self.pool.reset_parameters()
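Note: `head_to_bos_token_attention` reduces the full `(batch, layer, head, seq, seq)` score tensor to a per-layer, per-head weighting over the BOS tokens. The reduction chain in isolation (sizes and the single BOS token are hypothetical):

    import torch
    import torch.nn.functional as F

    # (batch, layer, head, seq, seq), already softmaxed over keys
    scores = torch.randn(2, 6, 4, 17, 17).softmax(dim=-1)
    n_bos = 1

    avg = scores.mean(dim=0)               # (layer, head, seq, seq): average over batch
    avg = avg.mean(dim=-1)                 # (layer, head, seq): average out the last axis
    just_bos = avg[:, :, :n_bos]           # keep only the BOS positions
    weights = F.softmax(just_bos, dim=-1)  # (layer, head, bos_token), rows sum to 1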
{broccoli_ml-4.0.1.dist-info → broccoli_ml-5.1.0.dist-info}/RECORD
CHANGED

@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
 broccoli/linear.py,sha256=Y7s-DzcwsOipRboNHc4HTScw4mJRalNoVFsNcxOB6a4,4872
 broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=
+broccoli/transformer.py,sha256=x3Mo6_1x6fGG6lPDPx9srxn6UdwKEpvjFAO8zoMwAMI,23052
 broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
-broccoli/vit.py,sha256=
-broccoli_ml-
-broccoli_ml-
-broccoli_ml-
-broccoli_ml-
+broccoli/vit.py,sha256=tUYQyoDsBc5ZR_M5_J0huj0T3OAy-vn1f19hCGVDCrM,20425
+broccoli_ml-5.1.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-5.1.0.dist-info/METADATA,sha256=3986lqn1iuWJ53O8ckM9LVU3tTjr32i19SeIXauWDXw,1368
+broccoli_ml-5.1.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-5.1.0.dist-info/RECORD,,

{broccoli_ml-4.0.1.dist-info → broccoli_ml-5.1.0.dist-info}/LICENSE
File without changes

{broccoli_ml-4.0.1.dist-info → broccoli_ml-5.1.0.dist-info}/WHEEL
File without changes