broccoli-ml 0.36.0__py3-none-any.whl → 10.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/activation.py +1 -4
- broccoli/cnn.py +1 -289
- broccoli/linear.py +237 -7
- broccoli/rope.py +19 -4
- broccoli/transformer.py +502 -142
- broccoli/utils.py +13 -7
- broccoli/vit.py +192 -51
- {broccoli_ml-0.36.0.dist-info → broccoli_ml-10.0.1.dist-info}/METADATA +5 -3
- broccoli_ml-10.0.1.dist-info/RECORD +13 -0
- broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl +0 -0
- broccoli/assets/cifar100_eigenvectors_size_2.pt +0 -0
- broccoli/assets/cifar100_eigenvectors_size_3.pt +0 -0
- broccoli/eigenpatches.py +0 -49
- broccoli_ml-0.36.0.dist-info/RECORD +0 -17
- {broccoli_ml-0.36.0.dist-info → broccoli_ml-10.0.1.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.36.0.dist-info → broccoli_ml-10.0.1.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
```diff
@@ -1,16 +1,72 @@
+import warnings
 import math
-from
-from typing import Optional
-from numpy import random
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint

 from einops import rearrange

 from .rope import RotaryEmbedding, apply_rotary_emb
-
+
+try:
+    from flash_attn import flash_attn_func
+
+    print("Using flash-attn.")
+    FLASH_ATTN = True
+except ImportError:
+    pass
+    FLASH_ATTN = False
+
+
+class LayerScale(nn.Module):
+    def __init__(self, dim, init_values=1e-4):
+        super().__init__()
+        self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        return x * self.nondecay_scale
+
+
+def drop_path(
+    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+    """
+    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    Copyright 2019 Ross Wightman
+    See documentation and licence there.
+    """
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+
+class DropPath(nn.Module):
+    """
+    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    Copyright 2019 Ross Wightman
+    See documentation and licence there.
+    """
+
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f"drop_prob={round(self.drop_prob, 3):0.3f}"


 class MHAttention(nn.Module):
```
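The `LayerScale` module added here is later wrapped around each residual branch of `TransformerBlock` (see the forward rewrite further down). A minimal sketch of the effect, using an arbitrary linear layer as a stand-in for the attention or feedforward branch:

```python
import torch
import torch.nn as nn


class LayerScale(nn.Module):
    # Same as the module added above: a learnable per-channel gain,
    # initialised near zero so each residual branch starts almost silent.
    def __init__(self, dim, init_values=1e-4):
        super().__init__()
        self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x * self.nondecay_scale


x = torch.randn(4, 16, 32)    # (batch, tokens, features)
branch = nn.Linear(32, 32)    # stand-in for the attention or feedforward branch
ls = LayerScale(32)

out = x + ls(branch(x))       # how TransformerBlock wires it up later in this diff
print((out - x).abs().max())  # tiny: the block begins close to the identity map
```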
```diff
@@ -31,10 +87,19 @@ class MHAttention(nn.Module):
         causal=False,
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
-
+        utility_tokens=0,
+        talking_heads=False,
         rotary_embedding=None,
         source_size=None,
+        scaling="d",
     ):
+        """
+        Args:
+            scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
         super().__init__()

         if rotary_embedding is not None:
```
```diff
@@ -42,12 +107,31 @@ class MHAttention(nn.Module):
         if causal:
             assert seq_len is not None

+        self.talking_heads = talking_heads
+
+        if self.talking_heads:
+            self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
+            self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
+        else:
+            self.head_projection = None
+            self.sample_projection = None
+
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
+        self.scaling = scaling

         self.head_dim = self.embed_dim // self.n_heads

+        if self.scaling == "sqrtd":
+            self.scaling_factor = 1 / math.sqrt(self.head_dim)
+        elif self.scaling == "d":
+            # 8/d_model for backwards compatibility,
+            # per https://github.com/microsoft/mup
+            self.scaling_factor = 8 / self.head_dim
+        else:
+            raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
+
         self.q_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
         self.k_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
```
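The new `scaling` option only changes the constant multiplied into the attention logits: `"sqrtd"` gives the classic `1 / sqrt(head_dim)`, while the default `"d"` gives the muP-style `8 / head_dim` shown above. A quick numerical comparison (the head sizes are arbitrary examples):

```python
import math

for head_dim in (64, 128):
    sqrtd_factor = 1 / math.sqrt(head_dim)  # "Attention is All You Need" scaling
    d_factor = 8 / head_dim                 # muP-style scaling, the new default
    print(head_dim, sqrtd_factor, d_factor)

# 64  0.125   0.125   (identical when head_dim == 64)
# 128 0.0884  0.0625  (the "d" variant shrinks logits faster for wider heads)
```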
```diff
@@ -66,7 +150,9 @@ class MHAttention(nn.Module):
         )
         self.rotary_embedding = rotary_embedding
         self.source_size = source_size
-        self.
+        self.utility_tokens = utility_tokens
+
+        self.reset_parameters()

     @property
     def _kv_distance(self) -> float:
```
```diff
@@ -87,7 +173,71 @@ class MHAttention(nn.Module):

         return 1 - similarity

-    def
+    def add_axial_rope(
+        self, q: torch.Tensor, k: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Apply Axial RoPE to all tokens except utility tokens
+        """
+
+        if len(self.source_size) == 1:
+            spatial_dimension_names = "D1"
+            spatial_dimension_values = {"D1": self.source_size[0]}
+        elif len(self.source_size) == 2:
+            spatial_dimension_names = "D1 D2"
+            spatial_dimension_values = {
+                "D1": self.source_size[0],
+                "D2": self.source_size[1],
+            }
+        elif len(self.source_size) == 3:
+            spatial_dimension_names = "D1 D2 D3"
+            spatial_dimension_values = {
+                "D1": self.source_size[0],
+                "D2": self.source_size[1],
+                "D3": self.source_size[2],
+            }
+        else:
+            raise NotImplementedError(
+                "`source_size` must be a tuple of 1, 2 or 3 integers"
+            )
+
+        q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
+        k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
+
+        q_img = rearrange(
+            q_img,
+            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            **spatial_dimension_values,
+        )
+        k_img = rearrange(
+            k_img,
+            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            **spatial_dimension_values,
+        )
+
+        freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
+
+        q_img = apply_rotary_emb(freqs, q_img)
+        k_img = apply_rotary_emb(freqs, k_img)
+
+        q_img = rearrange(
+            q_img,
+            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+        )
+        k_img = rearrange(
+            k_img,
+            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+        )
+
+        # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
+        q = torch.cat([q_util, q_img], dim=1)
+        k = torch.cat([k_util, k_img], dim=1)
+
+        return q, k
+
+    def project_qkv(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_batch_size, query_tokens, query_features = q.size()
         key_batch_size, key_tokens, key_features = k.size()

@@ -100,66 +250,74 @@ class MHAttention(nn.Module):

         if self.causal:
             assert query_tokens == key_tokens
-            assert query_tokens == self.
+            assert query_tokens == self.seq_len

-
-        q = self.q_proj(q)
-        k = self.k_proj(k)
-        v = self.v_proj(v)
+        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)

-        # Rearrange dimensions and add RoPE if needed
         if self.rotary_embedding is not None:
+            q, k = self.add_axial_rope(q, k)

-
-                spatial_dimension_names = "D1"
-                spatial_dimension_values = {"D1": self.source_size[0]}
-            elif len(self.source_size) == 2:
-                spatial_dimension_names = "D1 D2"
-                spatial_dimension_values = {
-                    "D1": self.source_size[0],
-                    "D2": self.source_size[1],
-                }
-            elif len(self.source_size) == 3:
-                spatial_dimension_names = "D1 D2 D3"
-                spatial_dimension_values = {
-                    "D1": self.source_size[0],
-                    "D2": self.source_size[1],
-                    "D3": self.source_size[2],
-                }
-            else:
-                raise NotImplementedError(
-                    "`source_size` must be a tuple of 1, 2 or 3 integers"
-                )
+        return q, k, v

-
-            k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
+    def forward(self, q, k, v):

-
-
-
-
+        q, k, v = self.project_qkv(q, k, v)
+
+        if FLASH_ATTN and not self.talking_heads:
+            # Divide Q/K/V into heads
+            q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+            k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+            v = rearrange(v, "b t (h d) -> b t h d", h=self.n_heads)
+
+            output_with_heads = flash_attn_func(
+                q,
+                k,
+                v,
+                dropout_p=self.dropout.p if self.training else 0.0,
+                softmax_scale=self.scaling_factor,
+                causal=self.causal,
             )
-            k_img = rearrange(
-                k_img,
-                f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
-                **spatial_dimension_values,
-            )
-            freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
-            q_img = apply_rotary_emb(freqs, q_img)
-            k_img = apply_rotary_emb(freqs, k_img)

-
-                q_img,
-                f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
-            )
-            k_img = rearrange(
-                k_img,
-                f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
-            )
+            output_without_heads = rearrange(output_with_heads, "b t h d -> b t (h d)")

-
-
-
+            return self.out_proj(output_without_heads)
+        else:
+            # Divide Q/K/V into heads
+            q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
+            k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
+            v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)
+
+            qk_scores = q @ k.transpose(-1, -2)
+
+            qk_scores *= self.scaling_factor
+
+            if self.talking_heads:
+                qk_scores = torch.einsum(
+                    "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
+                )
+
+            # Apply mask if causal (must come before softmax)
+            if self.causal:
+                qk_scores.masked_fill_(self.mask, float("-inf"))
+
+            qk_scores = F.softmax(qk_scores, dim=-1)
+
+            if self.talking_heads:
+                qk_scores = torch.einsum(
+                    "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
+                )
+
+            qk_scores = self.dropout(qk_scores)
+
+            output_with_heads = qk_scores @ v
+
+            output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")
+
+            return self.out_proj(output_without_heads)
+
+    def attention_logits(self, q, k, v):
+
+        q, k, v = self.project_qkv(q, k, v)

         # Divide Q/K/V into heads
         q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
@@ -168,19 +326,24 @@ class MHAttention(nn.Module):

         qk_scores = q @ k.transpose(-1, -2)

-        qk_scores
+        qk_scores *= self.scaling_factor

         # Apply mask if causal (must come before softmax)
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))

-        qk_scores
-
-        output_with_heads = qk_scores @ v
+        return qk_scores  # (batch, head, seq_len, seq_len)

-
-
-
+    def reset_parameters(self):
+        # Default nn.Linear init is kaiming_uniform, which is fine
+        self.q_proj.reset_parameters()
+        self.k_proj.reset_parameters()
+        self.v_proj.reset_parameters()
+        self.out_proj.reset_parameters()
+        if self.talking_heads:
+            # Initialize close to identity
+            nn.init.eye_(self.head_projection.weight)
+            nn.init.eye_(self.sample_projection.weight)


 class FeedforwardBlock(nn.Module):
@@ -196,17 +359,29 @@ class FeedforwardBlock(nn.Module):
         activation=nn.ReLU,
         activation_kwargs=None,
         dropout=0.0,
+        inner_dropout=None,
+        outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
         pre_norm=True,
         normformer=False,
         post_norm=True,
         residual_path=True,
+        checkpoint=True,
     ):
         super().__init__()

+        self.checkpoint = checkpoint
         self.residual_path = residual_path
         self.post_norm = post_norm
+        self.xglu = activation.__name__.endswith("GLU")
+
+        if self.residual_path and (output_features < input_features):
+            raise ValueError(
+                "If the number of output features will be less than "
+                "the number of input features, then `residual_path` "
+                "should be set to False."
+            )

         if self.post_norm:
             self.layernorm = nn.LayerNorm(output_features)
@@ -216,32 +391,91 @@ class FeedforwardBlock(nn.Module):
         else:
             self.activation = activation()

-        self.
+        self.inner_dropout = nn.Dropout(
+            inner_dropout if inner_dropout is not None else dropout
+        )
+        self.outer_dropout = nn.Dropout(
+            outer_dropout if outer_dropout is not None else dropout
+        )

         self.max_features = (
-            2 * ratio * output_features
-            if
-            else ratio * output_features
+            2 * int(ratio * output_features)
+            if self.xglu
+            else int(ratio * output_features)
+        )
+
+        self.linear_in = linear_module_up(input_features, self.max_features)
+        self.linear_out = linear_module_down(
+            int(ratio * output_features), output_features
         )

         self.process = nn.Sequential(
             *[
                 nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
-
+                self.linear_in,
                 self.activation,
-
-
-
+                self.inner_dropout,
+                (
+                    nn.LayerNorm(int(ratio * output_features))
+                    if normformer
+                    else nn.Identity()
+                ),
+                self.linear_out,
+                self.outer_dropout,
             ]
         )

+        self.recycling_enabled = False
+        if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
+            self.linear_out, "column_recycling_rate"
+        ):
+            self.recycling_enabled = True
+            self.master_recycling_rate = self.linear_in.row_recycling_rate
+            self.linear_in.row_recycling_rate = 0.0
+            self.linear_out.column_recycling_rate = 0.0
+            if (
+                hasattr(self.linear_in, "column_recycling_rate")
+                and self.linear_in.column_recycling_rate > 0
+            ) or (
+                hasattr(self.linear_out, "row_recycling_rate")
+                and self.linear_out.row_recycling_rate > 0
+            ):
+                raise NotImplementedError(
+                    "At the moment this layer can only support recycling linear "
+                    "layers if the in layer resets only rows and the out layer "
+                    "resets only columns."
+                )
+
+        self.reset_parameters()
+
     def forward(self, x):
+
+        # Recycle weights if using recycling linear layers
+        if self.training and self.recycling_enabled:
+            indices = self.linear_out.get_reset_indices(1)
+            self.linear_in.reset_rows(indices, incoming_data=x)
+            self.linear_out.reset_columns(indices)
+
+        if self.checkpoint:
+            processed = checkpoint(self.process, x, use_reentrant=False)
+        else:
+            processed = self.process(x)
+
         if self.residual_path and self.post_norm:
-            return self.layernorm(x +
+            return self.layernorm(x + processed)
         elif self.residual_path:
-            return x +
+            return x + processed
         else:
-            return
+            return processed
+
+    def reset_parameters(self):
+        if self.post_norm:
+            self.layernorm.reset_parameters()
+
+        # Iterate over the sequential block to reset parameters
+        for module in self.process:
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()


 class TransformerBlock(nn.Module):
@@ -258,13 +492,19 @@ class TransformerBlock(nn.Module):
         seq_len,
         d_model,
         n_heads,
-
+        relative_position_embedding=False,
         source_size=None,
-
+        utility_tokens=0,
+        talking_heads=False,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
-
+        ff_linear_module_up=None,
+        ff_linear_module_down=None,
+        msa_scaling="d",
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         identity_probability=0.0,
         causal=False,
@@ -272,19 +512,37 @@ class TransformerBlock(nn.Module):
         pre_norm=True,
         post_norm=False,
         normformer=False,
+        checkpoint_ff=True,
+        layerscale=True,
     ):
+        """
+        Args:
+            msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
+
         super().__init__()

         self.pre_norm = pre_norm
         self.post_norm = post_norm
         self.normformer = normformer

-        self.
+        self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)

         self.layer_norm_1 = nn.LayerNorm(d_model)
         self.layer_norm_2 = nn.LayerNorm(d_model)
+        self.layer_norm_3 = nn.LayerNorm(d_model)

-        if
+        if layerscale:
+            self.layerscale1 = LayerScale(d_model)
+            self.layerscale2 = LayerScale(d_model)
+        else:
+            self.layerscale1 = nn.Identity()
+            self.layerscale2 = nn.Identity()
+
+        if relative_position_embedding:
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
             if d_model < 16:
                 dim = d_model
@@ -305,7 +563,9 @@ class TransformerBlock(nn.Module):
             linear_module=linear_module,
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
-
+            utility_tokens=utility_tokens,
+            talking_heads=talking_heads,
+            scaling=msa_scaling,
         )

         # Submodule for the feedforward process
@@ -315,56 +575,75 @@ class TransformerBlock(nn.Module):
             d_model,
             activation=activation,
             activation_kwargs=activation_kwargs,
-            dropout=
-
-
-
+            dropout=ff_dropout,
+            inner_dropout=ff_inner_dropout,
+            outer_dropout=ff_outer_dropout,
+            linear_module_up=(
+                ff_linear_module_up
+                if ff_linear_module_up is not None
+                else linear_module
+            ),
+            linear_module_down=(
+                ff_linear_module_down
+                if ff_linear_module_down is not None
+                else linear_module
+            ),
+            pre_norm=False,  # Handled outside the block
             normformer=normformer,
-            post_norm=
-            residual_path=
+            post_norm=False,  # Handled outside the block
+            residual_path=False,  # Handled outside the block
+            checkpoint=checkpoint_ff,
         )

+        self.reset_parameters()
+
     @property
     def _kv_distance(self) -> float:
         return self.attn._kv_distance

     def forward(self, x):
-        if not self.training:
-            identity_probability = 0.0
-        else:
-            identity_probability = self.identity_probability
-
-        # perform the identity operation for some rows in the batch
-        dist = torch.distributions.Binomial(x.size(0), identity_probability)
-        identity_count = int(dist.sample().item())
-
-        shuffle_indices = torch.randperm(x.size(0), device=x.device)
-        unshuffle_indices = torch.argsort(shuffle_indices)
-        shuffled = x[shuffle_indices, :, :]
-        identity_x = shuffled[:identity_count, :, :]
-        process_x = shuffled[identity_count:, :, :]
-
-        residual_x = process_x

         if self.pre_norm:
-
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = self.layer_norm_2(x)
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
+            if self.post_norm:  # i.e. in addition! Pre and post.
+                x = self.layer_norm_3(x)
+        elif self.post_norm:  # i.e. only, not prenorm, just post
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
+            x = self.layer_norm_2(x)
+        else:  # Not pre or post norm. Stand well back.
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))

-
-
-        if self.post_norm:
-            process_x = self.layer_norm_2(process_x)
+        return x

-
+    def attention_logits(self, x):
+        """
+        Give back the attention scores used in this layer.
+        """
+        if self.pre_norm:
+            x = self.layer_norm_1(x)
+            return self.attn.attention_logits(x, x, x)
+        else:
+            return self.attn.attention_logits(x, x, x)

-
+    def reset_parameters(self):
+        self.layer_norm_1.reset_parameters()
+        self.layer_norm_2.reset_parameters()
+        self.layer_norm_3.reset_parameters()

-
+        self.attn.reset_parameters()
+        self.ff.reset_parameters()


 class TransformerEncoder(nn.Module):
     """
     This assumes we already get a sequence of embeddings (e.g. word or image
-    patch embeddings).
+    patch embeddings).
     """

     def __init__(
@@ -373,54 +652,93 @@ class TransformerEncoder(nn.Module):
         d_model,
         n_layers,
         n_heads,
-
+        absolute_position_embedding=True,
+        relative_position_embedding=False,
         source_size=None,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
-
+        ff_linear_module_up=None,
+        ff_linear_module_down=None,
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         stochastic_depth=0.0,
         causal=False,
         linear_module=nn.Linear,
-
-
+        utility_tokens=0,
+        talking_heads=False,
+        return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
+        msa_scaling="d",
+        checkpoint_ff=True,
+        layerscale=True,
     ):
-
-
+        """
+        Args:
+            msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
+
+        if relative_position_embedding and (source_size is None):
+            raise ValueError(
+                "`source_size` for TransformerEncoder cannot be None if"
+                " `relative_position_embedding` is True"
+            )
+
+        if absolute_position_embedding and (seq_len is None):
+            raise ValueError(
+                "`seq_len` for TransformerEncoder cannot be None if"
+                " `absolute_position_embedding` is True"
+            )

         super().__init__()
+
+        if FLASH_ATTN and talking_heads:
+            warnings.warn(
+                "Using talking heads currently prevents using flash attention.",
+                stacklevel=2,
+            )
+
         self.seq_len = seq_len
         self.n_heads = n_heads
-        self.
-        self.
-
-        # Initialise
-        if self.
-            self.
-
-
+        self._utility_tokens = utility_tokens
+        self.return_utility_tokens = return_utility_tokens
+
+        # Initialise utility tokens with normal init, like usual Pytorch embeddings
+        if self._utility_tokens:
+            self._utility_token_embedding = nn.Parameter(
+                torch.empty(self._utility_tokens, d_model)
+            )
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
+        else:
+            self._utility_token_embedding = None
+
+        if self._utility_tokens and (self.seq_len is not None):
+            self.full_sequence_length = self.seq_len + self._utility_tokens
         else:
-            self._bos_embedding = None
             self.full_sequence_length = self.seq_len

         self.d_model = d_model

-
-
-        if self.position_embedding_type == "absolute":
+        if absolute_position_embedding:
             self.absolute_position_embedding = nn.Embedding(
                 self.full_sequence_length, d_model
             )
+        else:
+            self.absolute_position_embedding = None

-        self.mlp_dropout =
+        self.mlp_dropout = ff_dropout
         self.msa_dropout = msa_dropout
         self.stochastic_depth = stochastic_depth

-        assert isinstance(n_layers, int)
+        assert isinstance(n_layers, int)
+
         if n_layers == 1:
             self.stochastic_depth_probabilities = [0.0]
         else:
@@ -435,13 +753,19 @@ class TransformerEncoder(nn.Module):
                     self.full_sequence_length,
                     d_model,
                     n_heads,
-
+                    relative_position_embedding=relative_position_embedding,
                     source_size=source_size,
-
+                    utility_tokens=utility_tokens,
+                    talking_heads=talking_heads,
                     mlp_ratio=mlp_ratio,
                     activation=activation,
                     activation_kwargs=activation_kwargs,
-
+                    ff_linear_module_up=ff_linear_module_up,
+                    ff_linear_module_down=ff_linear_module_down,
+                    msa_scaling=msa_scaling,
+                    ff_dropout=ff_dropout,
+                    ff_inner_dropout=ff_inner_dropout,
+                    ff_outer_dropout=ff_outer_dropout,
                     msa_dropout=msa_dropout,
                     identity_probability=self.stochastic_depth_probabilities[i],
                     causal=causal,
@@ -449,22 +773,28 @@ class TransformerEncoder(nn.Module):
                     pre_norm=pre_norm,
                     post_norm=post_norm,
                     normformer=normformer,
+                    checkpoint_ff=checkpoint_ff,
+                    layerscale=layerscale,
                 )
                 for i in range(n_layers)
             ]
        )

+        self.reset_parameters()
+
     @property
     def _kv_distances(self) -> float:
         return ",".join([str(block._kv_distance) for block in self.blocks])

-    def
-        if self.
-            x = torch.cat(
+    def preprocess(self, x):
+        if self._utility_tokens:
+            x = torch.cat(
+                [self._utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
+            )
         else:
             x = x

-        if self.
+        if self.absolute_position_embedding is not None:
             x = x + self.absolute_position_embedding(
                 torch.arange(
                     0, self.full_sequence_length, dtype=torch.long, device=x.device
@@ -473,10 +803,40 @@ class TransformerEncoder(nn.Module):
                 )  # to shape (1, seq_len) to broadcast over batch
             )

+        return x
+
+    def forward(self, x):
+
+        x = self.preprocess(x)
+
         for block in self.blocks:
             x = block(x)

-        if self.
-            return x[:, self.
+        if self._utility_tokens and not self.return_utility_tokens:
+            return x[:, self._utility_tokens :, :]
         else:
             return x
+
+    def attention_logits(self, x):
+
+        x = self.preprocess(x)
+
+        layer_scores = []
+
+        for block in self.blocks:
+            # Get attention scores with shape (batch, 1, head, seq_len, seq_len)
+            layer_attention_logits = block.attention_logits(x).unsqueeze(1)
+            layer_scores.append(layer_attention_logits)
+            x = block(x)
+
+        return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)
+
+    def reset_parameters(self):
+        if self._utility_token_embedding is not None:
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
+
+        if self.absolute_position_embedding is not None:
+            self.absolute_position_embedding.reset_parameters()
+
+        for block in self.blocks:
+            block.reset_parameters()
```