broccoli-ml 5.1.1__tar.gz → 9.5.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/PKG-INFO +1 -1
- {broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/broccoli/activation.py +1 -4
- broccoli_ml-9.5.1/broccoli/linear.py +352 -0
- {broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/broccoli/transformer.py +107 -50
- {broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/broccoli/vit.py +66 -28
- {broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/pyproject.toml +1 -1
- broccoli_ml-5.1.1/broccoli/linear.py +0 -138
- {broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/LICENSE +0 -0
- {broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/README.md +0 -0
- {broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/broccoli/__init__.py +0 -0
- {broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/broccoli/cnn.py +0 -0
- {broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/broccoli/rope.py +0 -0
- {broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/broccoli/tensor.py +0 -0
- {broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/broccoli/utils.py +0 -0
{broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/broccoli/activation.py

@@ -46,10 +46,7 @@ class GELU(nn.Module):
 
 class Swish(nn.Module):
     """
-    Implementation of (beta)
-    (https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
-
-    Halves the incoming parameter count, which should be scaled up before input.
+    Implementation of (beta) Swish
     """
 
     def __init__(self) -> None:
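For reference on the docstring fix above: the "(beta) Swish" activation is usually defined as swish(x) = x * sigmoid(beta * x). A minimal illustrative sketch of that definition (beta is fixed at 1.0 here; whether broccoli's Swish treats beta as learnable is not visible in this diff):

    import torch
    from torch import nn

    class SwishSketch(nn.Module):
        """Illustrative only: swish(x) = x * sigmoid(beta * x)."""

        def __init__(self, beta: float = 1.0) -> None:
            super().__init__()
            self.beta = beta

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * torch.sigmoid(self.beta * x)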
broccoli_ml-9.5.1/broccoli/linear.py (new file)

@@ -0,0 +1,352 @@
+import math
+import random
+import warnings
+from typing import Union, List, Iterable
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
+
+
+class SpectralNormLinear(nn.Module):
+    """
+    Inspired by Apple's Spectral Normed Linear Layers
+    (https://github.com/apple/ml-sigma-reparam)
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
+        else:
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = SigmaReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        # Optional: A nice representation for printing the module.
+        return (
+            f"SpectralNormFeedForward(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
+        )
+
+
+class AnchoredLinear(nn.Module):
+    """
+    ...
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
+        else:
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = AnchoredReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        # Optional: A nice representation for printing the module.
+        return (
+            f"AnchoredLinear(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
+        )
+
+
+class WeightNormedLinear(nn.Module):
+    """
+    ...
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
+        else:
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = NormReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        return (
+            f"WeightNormedLinear(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
+        )
+
+
+class RecyclingLinear(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        row_recycling_rate: float = 0.0,
+        column_recycling_rate: float = 0.0,
+        adaptive=False,
+        xglu=False,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.bias = bias
+        self.xglu = xglu
+        self.linear = nn.Linear(in_features, out_features, bias=bias)
+        self.row_recycling_rate = row_recycling_rate
+        self.column_recycling_rate = column_recycling_rate
+        self.adaptive = adaptive
+        self.optimisers = []
+        self.initial_learning_rates = []
+        self._warned_about_registration = False
+
+    def register_optimiser(self, optimiser: torch.optim.Optimizer):
+        self.optimisers.append(optimiser)
+        self.initial_learning_rates.append(self._get_learning_rate(optimiser))
+        if self.initial_learning_rates[-1] == 0.0:
+            warnings.warn(
+                "Learning rate of registered optimiser was 0.0 - make sure "
+                "you haven't initialised a scheduler before registering the "
+                "optimiser",
+                stacklevel=2,
+            )
+
+    def _get_learning_rate(self, optimiser: torch.optim.Optimizer):
+        for group in optimiser.param_groups:
+            for param in group["params"]:
+                if param is self.linear.weight:
+                    return group["lr"]
+
+    def _get_multiplier(self):
+        if not self.adaptive or not self.optimisers:
+            return 1.0
+        else:
+            init = self.initial_learning_rates
+            current = [self._get_learning_rate(o) for o in self.optimisers]
+            pairs = zip(current, init, strict=True)
+            multipliers = [a / b for a, b in pairs if b != 0.0]
+            return min(multipliers) if multipliers else 0.0
+
+    def reset_rows(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.size(0):
+            value_indices = indices
+            centred_value_weights = self._mean_value_weights()
+            centred_value_weights = centred_value_weights.expand(indices.size(0), -1)
+            if self.xglu:
+                gate_indices = indices
+                value_indices = indices + (self.linear.out_features // 2)
+                centred_gate_weights = self._mean_gate_weights()
+                centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
+                self._update_weights(
+                    gate_indices, 0, centred_gate_weights, self.optimisers  # dim
+                )
+            self._update_weights(
+                value_indices, 0, centred_value_weights, self.optimisers
+            )
+        else:
+            return
+
+    def reset_columns(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.size(0):
+            random_weights = self._random_weights(
+                self.linear.weight.size(0), indices.size(0)
+            )
+            # Make random col weights quiet so they don't introduce loud noise...
+            # ...but not so quiet that FP16 zeros them and ruins symmetry breaking!
+            random_weights *= 0.1
+            self._update_weights(indices, 1, random_weights, self.optimisers)  # dim
+        else:
+            return
+
+    def forward(self, x):
+        if self.training and self.optimisers:
+            self.reset_rows(self.get_reset_indices(0))
+            self.reset_columns(self.get_reset_indices(1))
+        elif self.training and not self._warned_about_registration:
+            warnings.warn(
+                "RecyclingLinear: No optimiser registered. Recycling disabled.",
+                stacklevel=2,
+            )
+            self._warned_about_registration = True
+
+        return self.linear(x)
+
+    def get_reset_indices(self, dim):
+        base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
+        p = base_rate * self._get_multiplier()
+        if dim == 0:
+            if self.xglu:
+                sample_space = self.linear.out_features // 2
+            else:
+                sample_space = self.linear.out_features
+        elif dim == 1:
+            sample_space = self.linear.in_features
+        else:
+            raise ValueError("`dim` must be 0 or 1")
+
+        # Sample the indices
+        probs = torch.rand(sample_space, device=self.linear.weight.device)
+        mask = probs < p
+        if mask.any():
+            return torch.nonzero(mask).squeeze(-1)
+        else:
+            return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)
+
+    def _random_weights(self, rows, columns):
+        device = self.linear.weight.device
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        random_weights = torch.rand(rows, columns, device=device)
+        random_weights -= 0.5  # Range [-0.5, +0.5]
+        random_weights *= 2.0 * stdv  # Range [-stdv, +stdv]
+        return random_weights
+
+    def _mean_value_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        if self.xglu:
+            return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
+        else:
+            return self.linear.weight.data.mean(dim=0, keepdim=True)
+
+    def _mean_gate_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)
+
+    def _update_weights(
+        self,
+        indices: Iterable[int],
+        dim: int,
+        data: torch.Tensor,
+        optimisers: Union[
+            List[torch.optim.Optimizer], torch.optim.Optimizer, None
+        ] = None,
+    ):
+        if optimisers is None:
+            optimisers = []
+        if not isinstance(optimisers, list):
+            optimisers = [optimisers]
+
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.numel() == 0:
+            return
+
+        with torch.no_grad():
+            if dim == 0:
+                self.linear.weight.data[idx_tensor] = data
+            elif dim == 1:
+                self.linear.weight.data[:, idx_tensor] = data
+            else:
+                raise ValueError("`dim` must be 0 or 1")
+
+        self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)
+
+    def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
+        """
+        Zeroes out the optimizer state for the given indices in a single operation.
+        """
+        for optimiser in optimisers:
+            if param not in optimiser.state:
+                continue
+            state = optimiser.state[param]
+
+            for _, buffer in state.items():
+                if torch.is_tensor(buffer) and buffer.shape == param.shape:
+                    # Vectorized zeroing
+                    if dim == 0:
+                        buffer[idx_tensor] = 0.0
+                    else:
+                        buffer[:, idx_tensor] = 0.0
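A minimal usage sketch for the new RecyclingLinear layer, based only on the hunk above (the dimensions, learning rate, and recycling rate below are illustrative, not package defaults):

    import torch
    from broccoli.linear import RecyclingLinear

    layer = RecyclingLinear(64, 128, row_recycling_rate=1e-4, adaptive=True)
    optimiser = torch.optim.AdamW(layer.parameters(), lr=3e-4)
    # Register the optimiser before any scheduler changes its learning rate, otherwise the
    # stored initial learning rate (which `adaptive` scaling relies on) is recorded wrongly.
    layer.register_optimiser(optimiser)

    layer.train()
    out = layer(torch.randn(8, 64))  # while training, rows/columns are re-initialised at the configured rate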
{broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/broccoli/transformer.py

@@ -13,6 +13,7 @@ from .rope import RotaryEmbedding, apply_rotary_emb
 try:
     from flash_attn import flash_attn_func
 
+    print("Using flash-attn.")
     FLASH_ATTN = True
 except ImportError:
     pass
@@ -76,7 +77,7 @@ class MHAttention(nn.Module):
         causal=False,
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
-
+        utility_tokens=0,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
@@ -129,7 +130,7 @@ class MHAttention(nn.Module):
         )
         self.rotary_embedding = rotary_embedding
         self.source_size = source_size
-        self.
+        self.utility_tokens = utility_tokens
 
         self.reset_parameters()
 
@@ -156,7 +157,7 @@ class MHAttention(nn.Module):
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Apply Axial RoPE to all tokens except
+        Apply Axial RoPE to all tokens except utility tokens
        """
 
        if len(self.source_size) == 1:
@@ -180,8 +181,8 @@ class MHAttention(nn.Module):
                "`source_size` must be a tuple of 1, 2 or 3 integers"
            )
 
-
-
+        q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
+        k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
 
         q_img = rearrange(
             q_img,
@@ -208,9 +209,9 @@ class MHAttention(nn.Module):
             f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
         )
 
-        # Re-combine the
-        q = torch.cat([
-        k = torch.cat([
+        # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
+        q = torch.cat([q_util, q_img], dim=1)
+        k = torch.cat([k_util, k_img], dim=1)
 
         return q, k
 
@@ -284,7 +285,7 @@ class MHAttention(nn.Module):
 
         return self.out_proj(output_without_heads)
 
-    def
+    def attention_logits(self, q, k, v):
 
         q, k, v = self.project_qkv(q, k, v)
 
@@ -301,8 +302,6 @@ class MHAttention(nn.Module):
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))
 
-        qk_scores = F.softmax(qk_scores, dim=-1)
-
         return qk_scores  # (batch, head, seq_len, seq_len)
 
     def reset_parameters(self):
@@ -326,6 +325,8 @@ class FeedforwardBlock(nn.Module):
         activation=nn.ReLU,
         activation_kwargs=None,
         dropout=0.0,
+        inner_dropout=None,
+        outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
         pre_norm=True,
@@ -339,6 +340,7 @@ class FeedforwardBlock(nn.Module):
         self.checkpoint = checkpoint
         self.residual_path = residual_path
         self.post_norm = post_norm
+        self.xglu = activation.__name__.endswith("GLU")
 
         if self.residual_path and (output_features < input_features):
             raise ValueError(
@@ -355,29 +357,63 @@ class FeedforwardBlock(nn.Module):
         else:
             self.activation = activation()
 
-        self.
+        self.inner_dropout = nn.Dropout(
+            inner_dropout if inner_dropout is not None else dropout
+        )
+        self.outer_dropout = nn.Dropout(
+            outer_dropout if outer_dropout is not None else dropout
+        )
 
         self.max_features = (
-            2 * ratio * output_features
-            if activation.__name__.endswith("GLU")
-            else ratio * output_features
+            2 * ratio * output_features if self.xglu else ratio * output_features
         )
 
+        self.linear_in = linear_module_up(input_features, self.max_features)
+        self.linear_out = linear_module_down(ratio * output_features, output_features)
+
         self.process = nn.Sequential(
             *[
                 nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
-
+                self.linear_in,
                 self.activation,
+                self.inner_dropout,
                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
-
-                self.
+                self.linear_out,
+                self.outer_dropout,
             ]
         )
 
+        self.recycling_enabled = False
+        if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
+            self.linear_out, "column_recycling_rate"
+        ):
+            self.recycling_enabled = True
+            self.master_recycling_rate = self.linear_in.row_recycling_rate
+            self.linear_in.row_recycling_rate = 0.0
+            self.linear_out.column_recycling_rate = 0.0
+        if (
+            hasattr(self.linear_in, "column_recycling_rate")
+            and self.linear_in.column_recycling_rate > 0
+        ) or (
+            hasattr(self.linear_out, "row_recycling_rate")
+            and self.linear_out.row_recycling_rate > 0
+        ):
+            raise NotImplementedError(
+                "At the moment this layer can only support recycling linear "
+                "layers if the in layer resets only rows and the out layer "
+                "resets only columns."
+            )
+
         self.reset_parameters()
 
     def forward(self, x):
 
+        # Recycle weights if using recycling linear layers
+        if self.training and self.recycling_enabled:
+            indices = self.linear_out.get_reset_indices(1)
+            self.linear_in.reset_rows(indices)
+            self.linear_out.reset_columns(indices)
+
         if self.checkpoint:
             processed = checkpoint(self.process, x, use_reentrant=False)
         else:
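To make the new dropout knobs concrete, here is the ordering that the rebuilt FeedforwardBlock.process pipeline above corresponds to, written as a plain nn.Sequential (the dimensions and rates are made up; in the real block `inner_dropout`/`outer_dropout` fall back to `dropout` when left as None):

    import torch.nn as nn

    d_in, d_hidden, d_out = 256, 512, 256
    process = nn.Sequential(
        nn.LayerNorm(d_in),          # pre_norm
        nn.Linear(d_in, d_hidden),   # linear_in (linear_module_up)
        nn.ReLU(),                   # activation
        nn.Dropout(0.1),             # inner_dropout
        nn.Identity(),               # nn.LayerNorm(d_hidden) when normformer=True
        nn.Linear(d_hidden, d_out),  # linear_out (linear_module_down)
        nn.Dropout(0.1),             # outer_dropout
    )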
{broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/broccoli/transformer.py (continued)

@@ -416,14 +452,16 @@ class TransformerBlock(nn.Module):
         n_heads,
         relative_position_embedding=False,
         source_size=None,
-
+        utility_tokens=0,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
         ff_linear_module_up=None,
         ff_linear_module_down=None,
         msa_scaling="d",
-
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         identity_probability=0.0,
         causal=False,
@@ -474,7 +512,7 @@ class TransformerBlock(nn.Module):
             linear_module=linear_module,
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
-
+            utility_tokens=utility_tokens,
             scaling=msa_scaling,
         )
 
@@ -485,7 +523,9 @@ class TransformerBlock(nn.Module):
             d_model,
             activation=activation,
             activation_kwargs=activation_kwargs,
-            dropout=
+            dropout=ff_dropout,
+            inner_dropout=ff_inner_dropout,
+            outer_dropout=ff_outer_dropout,
             linear_module_up=(
                 ff_linear_module_up
                 if ff_linear_module_up is not None
@@ -529,15 +569,15 @@ class TransformerBlock(nn.Module):
 
         return x
 
-    def
+    def attention_logits(self, x):
         """
         Give back the attention scores used in this layer.
         """
         if self.pre_norm:
             x = self.layer_norm_1(x)
-            return self.attn.
+            return self.attn.attention_logits(x, x, x)
         else:
-            return self.attn.
+            return self.attn.attention_logits(x, x, x)
 
     def reset_parameters(self):
         self.layer_norm_1.reset_parameters()
@@ -568,13 +608,15 @@ class TransformerEncoder(nn.Module):
         activation_kwargs: Optional[dict] = None,
         ff_linear_module_up=None,
         ff_linear_module_down=None,
-
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         stochastic_depth=0.0,
         causal=False,
         linear_module=nn.Linear,
-
-
+        utility_tokens=0,
+        return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
@@ -592,22 +634,33 @@ class TransformerEncoder(nn.Module):
         if relative_position_embedding and (source_size is None):
             raise ValueError(
                 "`source_size` for TransformerEncoder cannot be None if"
-                " `
+                " `relative_position_embedding` is True"
+            )
+
+        if absolute_position_embedding and (seq_len is None):
+            raise ValueError(
+                "`seq_len` for TransformerEncoder cannot be None if"
+                " `absolute_position_embedding` is True"
             )
 
         super().__init__()
         self.seq_len = seq_len
         self.n_heads = n_heads
-        self.
-        self.
-
-        # Initialise
-        if self.
-            self.
-
-
+        self._utility_tokens = utility_tokens
+        self.return_utility_tokens = return_utility_tokens
+
+        # Initialise utility tokens with normal init, like usual Pytorch embeddings
+        if self._utility_tokens:
+            self._utility_token_embedding = nn.Parameter(
+                torch.empty(self._utility_tokens, d_model)
+            )
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
+        else:
+            self._utility_token_embedding = None
+
+        if self._utility_tokens and (self.seq_len is not None):
+            self.full_sequence_length = self.seq_len + self._utility_tokens
         else:
-            self._bos_embedding = None
             self.full_sequence_length = self.seq_len
 
         self.d_model = d_model
@@ -619,7 +672,7 @@ class TransformerEncoder(nn.Module):
         else:
             self.absolute_position_embedding = None
 
-        self.mlp_dropout =
+        self.mlp_dropout = ff_dropout
         self.msa_dropout = msa_dropout
         self.stochastic_depth = stochastic_depth
 
@@ -641,14 +694,16 @@ class TransformerEncoder(nn.Module):
                 n_heads,
                 relative_position_embedding=relative_position_embedding,
                 source_size=source_size,
-
+                utility_tokens=utility_tokens,
                 mlp_ratio=mlp_ratio,
                 activation=activation,
                 activation_kwargs=activation_kwargs,
                 ff_linear_module_up=ff_linear_module_up,
                 ff_linear_module_down=ff_linear_module_down,
                 msa_scaling=msa_scaling,
-
+                ff_dropout=ff_dropout,
+                ff_inner_dropout=ff_inner_dropout,
+                ff_outer_dropout=ff_outer_dropout,
                 msa_dropout=msa_dropout,
                 identity_probability=self.stochastic_depth_probabilities[i],
                 causal=causal,
@@ -669,8 +724,10 @@ class TransformerEncoder(nn.Module):
         return ",".join([str(block._kv_distance) for block in self.blocks])
 
     def preprocess(self, x):
-        if self.
-            x = torch.cat(
+        if self._utility_tokens:
+            x = torch.cat(
+                [self._utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
+            )
         else:
             x = x
 
@@ -692,12 +749,12 @@ class TransformerEncoder(nn.Module):
         for block in self.blocks:
             x = block(x)
 
-        if self.
-            return x[:, self.
+        if self._utility_tokens and not self.return_utility_tokens:
+            return x[:, self._utility_tokens :, :]
         else:
             return x
 
-    def
+    def attention_logits(self, x):
 
         x = self.preprocess(x)
 
@@ -705,15 +762,15 @@ class TransformerEncoder(nn.Module):
 
         for block in self.blocks:
             # Get attention scores with shape (batch, 1, head, seq_len, seq_len)
-
-            layer_scores.append(
+            layer_attention_logits = block.attention_logits(x).unsqueeze(1)
+            layer_scores.append(layer_attention_logits)
             x = block(x)
 
         return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)
 
     def reset_parameters(self):
-        if self.
-            nn.init.normal_(self.
+        if self._utility_token_embedding is not None:
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
 
         if self.absolute_position_embedding is not None:
             self.absolute_position_embedding.reset_parameters()
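A shape walkthrough of the new `utility_tokens` behaviour in TransformerEncoder, mirroring `preprocess` and the slice in `forward` outside the class (the sizes are arbitrary):

    import torch

    batch, seq_len, d_model, utility_tokens = 2, 16, 64, 4
    x = torch.randn(batch, seq_len, d_model)
    util = torch.nn.Parameter(torch.empty(utility_tokens, d_model))
    torch.nn.init.normal_(util, mean=0.0, std=1.0)

    # preprocess(): learned utility tokens are prepended to every sequence in the batch
    x_full = torch.cat([util.expand(batch, -1, -1), x], dim=1)
    assert x_full.shape == (batch, utility_tokens + seq_len, d_model)

    # forward(): unless return_utility_tokens=True, they are stripped from the output again
    out = x_full[:, utility_tokens:, :]
    assert out.shape == (batch, seq_len, d_model)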
{broccoli_ml-5.1.1 → broccoli_ml-9.5.1}/broccoli/vit.py

@@ -11,7 +11,6 @@ from einops.layers.torch import Rearrange
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class GetCLSToken(nn.Module):
@@ -39,6 +38,9 @@ class SequencePool(nn.Module):
         weights = self.attention(x)
         return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")
 
+    def attention_scores(self, x):
+        return self.attention(x)
+
     def reset_parameters(self):
         # Iterate over modules in the sequential block
         for module in self.attention:
@@ -159,7 +161,9 @@ class ViTEncoder(nn.Module):
         transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
-
+        transformer_initial_ff_dropout=None,
+        transformer_initial_ff_inner_dropout=None,
+        transformer_initial_ff_outer_dropout=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -169,14 +173,16 @@ class ViTEncoder(nn.Module):
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
-
-
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
         transformer_msa_scaling="d",
-
+        transformer_ff_dropout=0.0,
+        transformer_ff_inner_dropout=0.0,
+        transformer_ff_outer_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
@@ -331,13 +337,15 @@ class ViTEncoder(nn.Module):
             ff_linear_module_up=transformer_ff_linear_module_up,
             ff_linear_module_down=transformer_ff_linear_module_down,
             msa_scaling=transformer_msa_scaling,
-
+            ff_dropout=transformer_ff_dropout,
+            ff_inner_dropout=transformer_ff_inner_dropout,
+            ff_outer_dropout=transformer_ff_outer_dropout,
             msa_dropout=transformer_msa_dropout,
             stochastic_depth=transformer_stochastic_depth,
             causal=False,
             linear_module=linear_module,
-
-
+            utility_tokens=transformer_utility_tokens,
+            return_utility_tokens=transformer_return_utility_tokens,
             pre_norm=transformer_pre_norm,
             normformer=transformer_normformer,
             post_norm=transformer_post_norm,
@@ -355,9 +363,21 @@ class ViTEncoder(nn.Module):
             activation_kwargs=transformer_activation_kwargs,
             dropout=(
                 # First truthy assigned value
-
-                if
-                else
+                transformer_initial_ff_dropout
+                if transformer_initial_ff_dropout is not None
+                else transformer_ff_dropout
+            ),
+            inner_dropout=(
+                # First truthy assigned value
+                transformer_initial_ff_inner_dropout
+                if transformer_initial_ff_inner_dropout is not None
+                else transformer_ff_inner_dropout
+            ),
+            outer_dropout=(
+                # First truthy assigned value
+                transformer_initial_ff_outer_dropout
+                if transformer_initial_ff_outer_dropout is not None
+                else transformer_ff_outer_dropout
             ),
             linear_module_up=(
                 # First truthy assigned value
@@ -400,9 +420,9 @@ class ViTEncoder(nn.Module):
     def forward(self, x):
         return self.encoder(x)
 
-    def
+    def attention_logits(self, x):
         x = self.encoder[:-1](x)
-        return self.encoder[-1].
+        return self.encoder[-1].attention_logits(x)
 
     def reset_parameters(self):
         for module in self.encoder:
@@ -439,7 +459,9 @@ class ViT(nn.Module):
         transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
-
+        transformer_initial_ff_dropout=None,
+        transformer_initial_ff_inner_dropout=None,
+        transformer_initial_ff_outer_dropout=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -449,14 +471,16 @@ class ViT(nn.Module):
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
-
-
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
         transformer_msa_scaling="d",
-
+        transformer_ff_dropout=0.0,
+        transformer_ff_inner_dropout=0.0,
+        transformer_ff_outer_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
@@ -506,7 +530,9 @@ class ViT(nn.Module):
             transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
             transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
             transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
-
+            transformer_initial_ff_dropout=transformer_initial_ff_dropout,
+            transformer_initial_ff_inner_dropout=transformer_initial_ff_inner_dropout,
+            transformer_initial_ff_outer_dropout=transformer_initial_ff_outer_dropout,
             transformer_pre_norm=transformer_pre_norm,
             transformer_normformer=transformer_normformer,
             transformer_post_norm=transformer_post_norm,
@@ -516,14 +542,16 @@ class ViT(nn.Module):
             transformer_layers=transformer_layers,
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
-
-
+            transformer_utility_tokens=transformer_utility_tokens,
+            transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
             transformer_ff_linear_module_up=transformer_ff_linear_module_up,
             transformer_ff_linear_module_down=transformer_ff_linear_module_down,
             transformer_msa_scaling=transformer_msa_scaling,
-
+            transformer_ff_dropout=transformer_ff_dropout,
+            transformer_ff_inner_dropout=transformer_ff_inner_dropout,
+            transformer_ff_outer_dropout=transformer_ff_outer_dropout,
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
             transformer_checkpoint_ff=transformer_checkpoint_ff,
@@ -546,16 +574,26 @@ class ViT(nn.Module):
     def forward(self, x):
         return self.pool(self.encoder(x))
 
-    def
-        return self.encoder.
+    def attention_logits(self, x):
+        return self.encoder.attention_logits(x)
+
+    def pool_attention(self, x):
+        if hasattr(self.pool.summarize, "attention"):
+            return self.pool.summarize.attention(self.encoder(x))
+        else:
+            raise NotImplementedError(
+                "`pool_attention` is currently only implemented where"
+                " head class is SequencePoolClassificationHead"
+            )
 
-    def
-        all_attention = self.
+    def head_to_utility_token_attention_logits(self, x):
+        all_attention = self.attention_logits(x)
         batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
         sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
-
-
-
+        n_utility_tokens = self.encoder.encoder[-1]._utility_tokens
+        return sequence_averages[
+            :, :, :n_utility_tokens
+        ]  # (layer, head, utility_tokens)
 
     def reset_parameters(self):
         self.encoder.reset_parameters()
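A shape walkthrough for the new head_to_utility_token_attention_logits helper, using a dummy tensor in place of ViT.attention_logits(x) (the (batch, layer, head, seq_len, seq_len) shape comes from the comments in the encoder hunk above; the sizes here are arbitrary):

    import torch

    batch, layers, heads, seq_len, n_utility_tokens = 2, 7, 4, 68, 4
    all_attention = torch.randn(batch, layers, heads, seq_len, seq_len)

    batch_averages = torch.mean(all_attention, dim=0, keepdim=False)       # (layer, head, seq_len, seq_len)
    sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)  # (layer, head, seq_len)
    per_utility_token = sequence_averages[:, :, :n_utility_tokens]         # (layer, head, utility_tokens)
    assert per_utility_token.shape == (layers, heads, n_utility_tokens)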
broccoli_ml-5.1.1/broccoli/linear.py (removed)

@@ -1,138 +0,0 @@
-import math
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
-
-
-class SpectralNormLinear(nn.Module):
-    """
-    Inspired by Apple's Spectral Normed Linear Layers
-    (https://github.com/apple/ml-sigma-reparam)
-    """
-
-    def __init__(self, in_features: int, out_features: int, bias: bool = True):
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.use_bias = bias
-
-        self.weights = None
-
-        # Define the bias vector as a learnable parameter if required.
-        if self.use_bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
-        else:
-            # If no bias, register it as None.
-            # This is important so that PyTorch doesn't complain when saving/loading the model.
-            self.register_parameter("bias", None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        weights = torch.empty(self.out_features, self.in_features)
-        stdv = 1.0 / math.sqrt(self.in_features)
-        nn.init.uniform_(weights, a=-stdv, b=stdv)
-        if self.use_bias:
-            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            nn.init.uniform_(self.bias, -bound, bound)
-        self.weights = SigmaReparamTensor(weights)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.linear(x, self.weights(), self.bias)
-
-    def __repr__(self) -> str:
-        # Optional: A nice representation for printing the module.
-        return (
-            f"SpectralNormFeedForward(in_features={self.in_features},"
-            f"out_features={self.out_features}, bias={self.use_bias})"
-        )
-
-
-class AnchoredLinear(nn.Module):
-    """
-    ...
-    """
-
-    def __init__(self, in_features: int, out_features: int, bias: bool = True):
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.use_bias = bias
-
-        self.weights = None
-
-        # Define the bias vector as a learnable parameter if required.
-        if self.use_bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
-        else:
-            # If no bias, register it as None.
-            # This is important so that PyTorch doesn't complain when saving/loading the model.
-            self.register_parameter("bias", None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        weights = torch.empty(self.out_features, self.in_features)
-        stdv = 1.0 / math.sqrt(self.in_features)
-        nn.init.uniform_(weights, a=-stdv, b=stdv)
-        if self.use_bias:
-            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            nn.init.uniform_(self.bias, -bound, bound)
-        self.weights = AnchoredReparamTensor(weights)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.linear(x, self.weights(), self.bias)
-
-    def __repr__(self) -> str:
-        # Optional: A nice representation for printing the module.
-        return (
-            f"AnchoredLinear(in_features={self.in_features},"
-            f"out_features={self.out_features}, bias={self.use_bias})"
-        )
-
-
-class WeightNormedLinear(nn.Module):
-    """
-    ...
-    """
-
-    def __init__(self, in_features: int, out_features: int, bias: bool = True):
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.use_bias = bias
-
-        self.weights = None
-
-        # Define the bias vector as a learnable parameter if required.
-        if self.use_bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
-        else:
-            # If no bias, register it as None.
-            # This is important so that PyTorch doesn't complain when saving/loading the model.
-            self.register_parameter("bias", None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        weights = torch.empty(self.out_features, self.in_features)
-        stdv = 1.0 / math.sqrt(self.in_features)
-        nn.init.uniform_(weights, a=-stdv, b=stdv)
-        if self.use_bias:
-            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            nn.init.uniform_(self.bias, -bound, bound)
-        self.weights = NormReparamTensor(weights)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.linear(x, self.weights(), self.bias)
-
-    def __repr__(self) -> str:
-        return (
-            f"WeightNormedLinear(in_features={self.in_features},"
-            f"out_features={self.out_features}, bias={self.use_bias})"
-        )