broccoli-ml 9.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/__init__.py +6 -0
- broccoli/activation.py +118 -0
- broccoli/cnn.py +157 -0
- broccoli/linear.py +352 -0
- broccoli/rope.py +407 -0
- broccoli/tensor.py +128 -0
- broccoli/transformer.py +779 -0
- broccoli/utils.py +15 -0
- broccoli/vit.py +600 -0
- broccoli_ml-9.5.1.dist-info/LICENSE +21 -0
- broccoli_ml-9.5.1.dist-info/METADATA +43 -0
- broccoli_ml-9.5.1.dist-info/RECORD +13 -0
- broccoli_ml-9.5.1.dist-info/WHEEL +4 -0
broccoli/transformer.py
ADDED
@@ -0,0 +1,779 @@
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from einops import rearrange

from .rope import RotaryEmbedding, apply_rotary_emb

try:
    from flash_attn import flash_attn_func

    print("Using flash-attn.")
    FLASH_ATTN = True
except ImportError:
    pass
    FLASH_ATTN = False

def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """
    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
    Copyright 2019 Ross Wightman
    See documentation and licence there.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """
    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
    Copyright 2019 Ross Wightman
    See documentation and licence there.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob, 3):0.3f}"

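A minimal usage sketch of the stochastic-depth helper above (illustrative only: the tensor shape and drop probability are assumptions, and `broccoli.transformer` is assumed to be importable from the installed wheel).

import torch

from broccoli.transformer import DropPath

dp = DropPath(drop_prob=0.25)
x = torch.ones(8, 4, 16)  # (batch, tokens, features)

dp.train()
y = dp(x)  # roughly a quarter of the samples are zeroed; survivors are scaled by 1 / 0.75

dp.eval()
assert torch.equal(dp(x), x)  # identity when not training (or when drop_prob == 0.0)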
class MHAttention(nn.Module):
    """
    Multi-head self-attention using einops and optionally a custom linear layer.

    Forward method assumes q, k and v have the same embedding size and k and v
    are the same shape.

    Assumes bias=False and batch_first=True, as God intended.
    """

    def __init__(
        self,
        embed_dim,
        n_heads,
        dropout=0.0,
        causal=False,
        seq_len=None,
        linear_module: nn.Module = nn.Linear,
        utility_tokens=0,
        rotary_embedding=None,
        source_size=None,
        scaling="d",
    ):
        """
        Args:
            scaling: how should the attention logits be scaled? Can be "sqrtd"
                to mimic the original Attention is All You Need approach of
                dividing by the sqrt of the embedding dimension, or "d" per
                "Tensor Programs V...". Default "d".
        """
        super().__init__()

        if rotary_embedding is not None:
            assert source_size is not None
        if causal:
            assert seq_len is not None

        self.embed_dim = embed_dim
        self.n_heads = n_heads
        assert embed_dim % n_heads == 0
        self.scaling = scaling

        self.head_dim = self.embed_dim // self.n_heads

        if self.scaling == "sqrtd":
            self.scaling_factor = 1 / math.sqrt(self.head_dim)
        elif self.scaling == "d":
            # 8/d_model for backwards compatibility,
            # per https://github.com/microsoft/mup
            self.scaling_factor = 8 / self.head_dim
        else:
            raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')

        self.q_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
        self.k_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
        self.v_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)

        self.out_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)

        self.causal = causal
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        if self.causal:
            self.register_buffer(
                "mask",
                (torch.triu(torch.ones(seq_len, seq_len), diagonal=1) == 1)
                .unsqueeze(0)
                .unsqueeze(0),
            )
        self.rotary_embedding = rotary_embedding
        self.source_size = source_size
        self.utility_tokens = utility_tokens

        self.reset_parameters()

    @property
    def _kv_distance(self) -> float:
        """
        Calculates the cosine distance between the weight tensors of `self.k_proj`
        and `self.v_proj`.

        The cosine distance is defined as 1 - cosine_similarity (i.e. a value
        closer to 0 indicates higher similarity).
        """

        similarity = F.cosine_similarity(
            self.k_proj.weight.detach().flatten(),
            self.v_proj.weight.detach().flatten(),
            dim=0,
            eps=1e-8,
        ).item()

        return 1 - similarity

    def add_axial_rope(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply Axial RoPE to all tokens except utility tokens
        """

        if len(self.source_size) == 1:
            spatial_dimension_names = "D1"
            spatial_dimension_values = {"D1": self.source_size[0]}
        elif len(self.source_size) == 2:
            spatial_dimension_names = "D1 D2"
            spatial_dimension_values = {
                "D1": self.source_size[0],
                "D2": self.source_size[1],
            }
        elif len(self.source_size) == 3:
            spatial_dimension_names = "D1 D2 D3"
            spatial_dimension_values = {
                "D1": self.source_size[0],
                "D2": self.source_size[1],
                "D3": self.source_size[2],
            }
        else:
            raise NotImplementedError(
                "`source_size` must be a tuple of 1, 2 or 3 integers"
            )

        q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
        k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]

        q_img = rearrange(
            q_img,
            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
            **spatial_dimension_values,
        )
        k_img = rearrange(
            k_img,
            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
            **spatial_dimension_values,
        )

        freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)

        q_img = apply_rotary_emb(freqs, q_img)
        k_img = apply_rotary_emb(freqs, k_img)

        q_img = rearrange(
            q_img,
            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
        )
        k_img = rearrange(
            k_img,
            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
        )

        # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
        q = torch.cat([q_util, q_img], dim=1)
        k = torch.cat([k_util, k_img], dim=1)

        return q, k

    def project_qkv(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        query_batch_size, query_tokens, query_features = q.size()
        key_batch_size, key_tokens, key_features = k.size()

        assert k.size() == v.size()
        assert query_features == key_features
        assert (
            (query_batch_size == key_batch_size)  # batch sizes are the same...
            or query_batch_size == 1  # ... or query is broadcastable
        )

        if self.causal:
            assert query_tokens == key_tokens
            assert query_tokens == self.seq_len

        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)

        if self.rotary_embedding is not None:
            q, k = self.add_axial_rope(q, k)

        return q, k, v

    def forward(self, q, k, v):

        q, k, v = self.project_qkv(q, k, v)

        if FLASH_ATTN:
            # Divide Q/K/V into heads
            q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
            k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
            v = rearrange(v, "b t (h d) -> b t h d", h=self.n_heads)

            output_with_heads = flash_attn_func(
                q,
                k,
                v,
                dropout_p=self.dropout.p if self.training else 0.0,
                softmax_scale=self.scaling_factor,
                causal=self.causal,
            )

            output_without_heads = rearrange(output_with_heads, "b t h d -> b t (h d)")

            return self.out_proj(output_without_heads)
        else:
            # Divide Q/K/V into heads
            q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
            k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
            v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)

            qk_scores = q @ k.transpose(-1, -2)

            qk_scores *= self.scaling_factor

            # Apply mask if causal (must come before softmax)
            if self.causal:
                qk_scores.masked_fill_(self.mask, float("-inf"))

            qk_scores = F.softmax(qk_scores, dim=-1)

            qk_scores = self.dropout(qk_scores)

            output_with_heads = qk_scores @ v

            output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")

            return self.out_proj(output_without_heads)

    def attention_logits(self, q, k, v):

        q, k, v = self.project_qkv(q, k, v)

        # Divide Q/K/V into heads
        q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
        k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
        v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)

        qk_scores = q @ k.transpose(-1, -2)

        qk_scores *= self.scaling_factor

        # Apply mask if causal (must come before softmax)
        if self.causal:
            qk_scores.masked_fill_(self.mask, float("-inf"))

        return qk_scores  # (batch, head, seq_len, seq_len)

    def reset_parameters(self):
        # Default nn.Linear init is kaiming_uniform, which is fine
        self.q_proj.reset_parameters()
        self.k_proj.reset_parameters()
        self.v_proj.reset_parameters()
        self.out_proj.reset_parameters()

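A minimal self-attention sketch for the class above (illustrative: the shapes and hyperparameters are assumptions, and it exercises the pure-PyTorch branch of `forward`, i.e. it assumes flash-attn is not installed).

import torch

from broccoli.transformer import MHAttention

attn = MHAttention(embed_dim=64, n_heads=8, causal=True, seq_len=10)
x = torch.randn(2, 10, 64)  # (batch, tokens, embed_dim)

out = attn(x, x, x)  # self-attention: q, k and v are the same tensor
print(out.shape)  # torch.Size([2, 10, 64])

logits = attn.attention_logits(x, x, x)
print(logits.shape)  # torch.Size([2, 8, 10, 10]) -> (batch, head, seq_len, seq_len)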
class FeedforwardBlock(nn.Module):
    """
    ...
    """

    def __init__(
        self,
        input_features,
        ratio,
        output_features,
        activation=nn.ReLU,
        activation_kwargs=None,
        dropout=0.0,
        inner_dropout=None,
        outer_dropout=None,
        linear_module_up=nn.Linear,
        linear_module_down=nn.Linear,
        pre_norm=True,
        normformer=False,
        post_norm=True,
        residual_path=True,
        checkpoint=True,
    ):
        super().__init__()

        self.checkpoint = checkpoint
        self.residual_path = residual_path
        self.post_norm = post_norm
        self.xglu = activation.__name__.endswith("GLU")

        if self.residual_path and (output_features < input_features):
            raise ValueError(
                "If the number of output features will be less than "
                "the number of input features, then `residual_path` "
                "should be set to False."
            )

        if self.post_norm:
            self.layernorm = nn.LayerNorm(output_features)

        if activation_kwargs is not None:
            self.activation = activation(**activation_kwargs)
        else:
            self.activation = activation()

        self.inner_dropout = nn.Dropout(
            inner_dropout if inner_dropout is not None else dropout
        )
        self.outer_dropout = nn.Dropout(
            outer_dropout if outer_dropout is not None else dropout
        )

        self.max_features = (
            2 * ratio * output_features if self.xglu else ratio * output_features
        )

        self.linear_in = linear_module_up(input_features, self.max_features)
        self.linear_out = linear_module_down(ratio * output_features, output_features)

        self.process = nn.Sequential(
            *[
                nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                self.linear_in,
                self.activation,
                self.inner_dropout,
                nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
                self.linear_out,
                self.outer_dropout,
            ]
        )

        self.recycling_enabled = False
        if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
            self.linear_out, "column_recycling_rate"
        ):
            self.recycling_enabled = True
            self.master_recycling_rate = self.linear_in.row_recycling_rate
            self.linear_in.row_recycling_rate = 0.0
            self.linear_out.column_recycling_rate = 0.0
            if (
                hasattr(self.linear_in, "column_recycling_rate")
                and self.linear_in.column_recycling_rate > 0
            ) or (
                hasattr(self.linear_out, "row_recycling_rate")
                and self.linear_out.row_recycling_rate > 0
            ):
                raise NotImplementedError(
                    "At the moment this layer can only support recycling linear "
                    "layers if the in layer resets only rows and the out layer "
                    "resets only columns."
                )

        self.reset_parameters()

    def forward(self, x):

        # Recycle weights if using recycling linear layers
        if self.training and self.recycling_enabled:
            indices = self.linear_out.get_reset_indices(1)
            self.linear_in.reset_rows(indices)
            self.linear_out.reset_columns(indices)

        if self.checkpoint:
            processed = checkpoint(self.process, x, use_reentrant=False)
        else:
            processed = self.process(x)

        if self.residual_path and self.post_norm:
            return self.layernorm(x + processed)
        elif self.residual_path:
            return x + processed
        else:
            return processed

    def reset_parameters(self):
        if self.post_norm:
            self.layernorm.reset_parameters()

        # Iterate over the sequential block to reset parameters
        for module in self.process:
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()

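A minimal sketch of the feedforward block above with the default ReLU activation (illustrative: the sizes are assumptions, and activation checkpointing is switched off only to keep the toy forward pass simple). Per the `xglu` branch above, passing an activation class whose name ends in "GLU" would instead double the width of `linear_in`.

import torch

from broccoli.transformer import FeedforwardBlock

ff = FeedforwardBlock(input_features=64, ratio=4, output_features=64, checkpoint=False)
x = torch.randn(2, 10, 64)

# LayerNorm -> Linear(64, 256) -> ReLU -> Linear(256, 64), then residual add and post-norm
out = ff(x)
print(out.shape)  # torch.Size([2, 10, 64])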
class TransformerBlock(nn.Module):
    """
    Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
    which is also what is seen in e.g.
    https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
    and is recommended by https://arxiv.org/abs/2002.04745

    """

    def __init__(
        self,
        seq_len,
        d_model,
        n_heads,
        relative_position_embedding=False,
        source_size=None,
        utility_tokens=0,
        mlp_ratio=4,
        activation: nn.Module = nn.ReLU,
        activation_kwargs: Optional[dict] = None,
        ff_linear_module_up=None,
        ff_linear_module_down=None,
        msa_scaling="d",
        ff_dropout=0.0,
        ff_inner_dropout=0.0,
        ff_outer_dropout=0.0,
        msa_dropout=0.0,
        identity_probability=0.0,
        causal=False,
        linear_module=nn.Linear,
        pre_norm=True,
        post_norm=False,
        normformer=False,
        checkpoint_ff=True,
    ):
        """
        Args:
            msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
                to mimic the original Attention is All You Need approach of
                dividing by the sqrt of the embedding dimension, or "d" per
                "Tensor Programs V...". Default "d".
        """

        super().__init__()

        self.pre_norm = pre_norm
        self.post_norm = post_norm
        self.normformer = normformer

        self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)

        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.layer_norm_3 = nn.LayerNorm(d_model)

        if relative_position_embedding:
            max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
            if d_model < 16:
                dim = d_model
            else:
                dim = 16
            self.rotary_embedding = RotaryEmbedding(
                dim=dim, freqs_for="pixel", max_freq=max_freq
            )
        else:
            self.rotary_embedding = None

        self.attn = MHAttention(  # Handles QKV projection
            d_model,
            n_heads,
            dropout=msa_dropout,
            causal=causal,
            seq_len=seq_len,
            linear_module=linear_module,
            rotary_embedding=self.rotary_embedding,
            source_size=source_size,
            utility_tokens=utility_tokens,
            scaling=msa_scaling,
        )

        # Submodule for the feedforward process
        self.ff = FeedforwardBlock(
            d_model,
            mlp_ratio,
            d_model,
            activation=activation,
            activation_kwargs=activation_kwargs,
            dropout=ff_dropout,
            inner_dropout=ff_inner_dropout,
            outer_dropout=ff_outer_dropout,
            linear_module_up=(
                ff_linear_module_up
                if ff_linear_module_up is not None
                else linear_module
            ),
            linear_module_down=(
                ff_linear_module_down
                if ff_linear_module_down is not None
                else linear_module
            ),
            pre_norm=False,  # Handled outside the block
            normformer=normformer,
            post_norm=False,  # Handled outside the block
            residual_path=False,  # Handled outside the block
            checkpoint=checkpoint_ff,
        )

        self.reset_parameters()

    @property
    def _kv_distance(self) -> float:
        return self.attn._kv_distance

    def forward(self, x):

        if self.pre_norm:
            x = self.layer_norm_1(x)
            x = x + self.drop_path(self.attn(x, x, x))
            x = self.layer_norm_2(x)
            x = x + self.drop_path(self.ff(x))
            if self.post_norm:  # i.e. in addition! Pre and post.
                x = self.layer_norm_3(x)
        elif self.post_norm:  # i.e. only, not prenorm, just post
            x = x + self.drop_path(self.attn(x, x, x))
            x = self.layer_norm_1(x)
            x = x + self.drop_path(self.ff(x))
            x = self.layer_norm_2(x)
        else:  # Not pre or post norm. Stand well back.
            x = x + self.drop_path(self.attn(x, x, x))
            x = x + self.drop_path(self.ff(x))

        return x

    def attention_logits(self, x):
        """
        Give back the attention scores used in this layer.
        """
        if self.pre_norm:
            x = self.layer_norm_1(x)
            return self.attn.attention_logits(x, x, x)
        else:
            return self.attn.attention_logits(x, x, x)

    def reset_parameters(self):
        self.layer_norm_1.reset_parameters()
        self.layer_norm_2.reset_parameters()
        self.layer_norm_3.reset_parameters()

        self.attn.reset_parameters()
        self.ff.reset_parameters()

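A quick sketch of a single block (illustrative: the hyperparameters are assumptions and flash-attn is assumed absent). With the defaults this exercises the pre-norm path of `forward` above; `pre_norm=False, post_norm=True` would take the post-norm path instead.

import torch

from broccoli.transformer import TransformerBlock

block = TransformerBlock(seq_len=16, d_model=64, n_heads=8, checkpoint_ff=False)
x = torch.randn(2, 16, 64)

print(block(x).shape)  # torch.Size([2, 16, 64])
print(block.attention_logits(x).shape)  # torch.Size([2, 8, 16, 16])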
class TransformerEncoder(nn.Module):
    """
    This assumes we already get a sequence of embeddings (e.g. word or image
    patch embeddings).
    """

    def __init__(
        self,
        seq_len,
        d_model,
        n_layers,
        n_heads,
        absolute_position_embedding=True,
        relative_position_embedding=False,
        source_size=None,
        mlp_ratio=4,
        activation: nn.Module = nn.ReLU,
        activation_kwargs: Optional[dict] = None,
        ff_linear_module_up=None,
        ff_linear_module_down=None,
        ff_dropout=0.0,
        ff_inner_dropout=0.0,
        ff_outer_dropout=0.0,
        msa_dropout=0.0,
        stochastic_depth=0.0,
        causal=False,
        linear_module=nn.Linear,
        utility_tokens=0,
        return_utility_tokens=False,
        pre_norm=True,
        post_norm=False,
        normformer=False,
        msa_scaling="d",
        checkpoint_ff=True,
    ):
        """
        Args:
            msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
                to mimic the original Attention is All You Need approach of
                dividing by the sqrt of the embedding dimension, or "d" per
                "Tensor Programs V...". Default "d".
        """

        if relative_position_embedding and (source_size is None):
            raise ValueError(
                "`source_size` for TransformerEncoder cannot be None if"
                " `relative_position_embedding` is True"
            )

        if absolute_position_embedding and (seq_len is None):
            raise ValueError(
                "`seq_len` for TransformerEncoder cannot be None if"
                " `absolute_position_embedding` is True"
            )

        super().__init__()
        self.seq_len = seq_len
        self.n_heads = n_heads
        self._utility_tokens = utility_tokens
        self.return_utility_tokens = return_utility_tokens

        # Initialise utility tokens with normal init, like usual Pytorch embeddings
        if self._utility_tokens:
            self._utility_token_embedding = nn.Parameter(
                torch.empty(self._utility_tokens, d_model)
            )
            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
        else:
            self._utility_token_embedding = None

        if self._utility_tokens and (self.seq_len is not None):
            self.full_sequence_length = self.seq_len + self._utility_tokens
        else:
            self.full_sequence_length = self.seq_len

        self.d_model = d_model

        if absolute_position_embedding:
            self.absolute_position_embedding = nn.Embedding(
                self.full_sequence_length, d_model
            )
        else:
            self.absolute_position_embedding = None

        self.mlp_dropout = ff_dropout
        self.msa_dropout = msa_dropout
        self.stochastic_depth = stochastic_depth

        assert isinstance(n_layers, int)

        if n_layers == 1:
            self.stochastic_depth_probabilities = [0.0]
        else:
            step_size = self.stochastic_depth / (n_layers - 1)
            self.stochastic_depth_probabilities = [
                i * step_size for i in range(n_layers)
            ]

        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    self.full_sequence_length,
                    d_model,
                    n_heads,
                    relative_position_embedding=relative_position_embedding,
                    source_size=source_size,
                    utility_tokens=utility_tokens,
                    mlp_ratio=mlp_ratio,
                    activation=activation,
                    activation_kwargs=activation_kwargs,
                    ff_linear_module_up=ff_linear_module_up,
                    ff_linear_module_down=ff_linear_module_down,
                    msa_scaling=msa_scaling,
                    ff_dropout=ff_dropout,
                    ff_inner_dropout=ff_inner_dropout,
                    ff_outer_dropout=ff_outer_dropout,
                    msa_dropout=msa_dropout,
                    identity_probability=self.stochastic_depth_probabilities[i],
                    causal=causal,
                    linear_module=linear_module,
                    pre_norm=pre_norm,
                    post_norm=post_norm,
                    normformer=normformer,
                    checkpoint_ff=checkpoint_ff,
                )
                for i in range(n_layers)
            ]
        )

        self.reset_parameters()

    @property
    def _kv_distances(self) -> str:
        return ",".join([str(block._kv_distance) for block in self.blocks])

    def preprocess(self, x):
        if self._utility_tokens:
            x = torch.cat(
                [self._utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
            )
        else:
            x = x

        if self.absolute_position_embedding is not None:
            x = x + self.absolute_position_embedding(
                torch.arange(
                    0, self.full_sequence_length, dtype=torch.long, device=x.device
                ).unsqueeze(
                    0
                )  # to shape (1, seq_len) to broadcast over batch
            )

        return x

    def forward(self, x):

        x = self.preprocess(x)

        for block in self.blocks:
            x = block(x)

        if self._utility_tokens and not self.return_utility_tokens:
            return x[:, self._utility_tokens :, :]
        else:
            return x

    def attention_logits(self, x):

        x = self.preprocess(x)

        layer_scores = []

        for block in self.blocks:
            # Get attention scores with shape (batch, 1, head, seq_len, seq_len)
            layer_attention_logits = block.attention_logits(x).unsqueeze(1)
            layer_scores.append(layer_attention_logits)
            x = block(x)

        return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)

    def reset_parameters(self):
        if self._utility_token_embedding is not None:
            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)

        if self.absolute_position_embedding is not None:
            self.absolute_position_embedding.reset_parameters()

        for block in self.blocks:
            block.reset_parameters()
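A minimal end-to-end sketch of the encoder above (illustrative: the sizes are assumptions, `broccoli.transformer` is assumed importable from the installed wheel, and the pure-PyTorch attention path is assumed, i.e. flash-attn not installed). With one utility token, the internal sequence length grows from 16 to 17, and the token is stripped from the output because `return_utility_tokens` defaults to False.

import torch

from broccoli.transformer import TransformerEncoder

encoder = TransformerEncoder(
    seq_len=16,
    d_model=64,
    n_layers=2,
    n_heads=8,
    utility_tokens=1,  # one learned token prepended to the sequence
    checkpoint_ff=False,
)
x = torch.randn(2, 16, 64)  # (batch, seq_len, d_model) word/patch embeddings

out = encoder(x)
print(out.shape)  # torch.Size([2, 16, 64]); the utility token is stripped by default

logits = encoder.attention_logits(x)
print(logits.shape)  # torch.Size([2, 2, 8, 17, 17]) -> (batch, layer, head, seq_len, seq_len)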