broccoli-ml 7.0.0__tar.gz → 9.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/PKG-INFO +1 -1
- {broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/broccoli/activation.py +1 -4
- broccoli_ml-9.6.0/broccoli/linear.py +368 -0
- {broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/broccoli/transformer.py +57 -12
- {broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/broccoli/vit.py +36 -10
- {broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/pyproject.toml +1 -1
- broccoli_ml-7.0.0/broccoli/linear.py +0 -138
- {broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/LICENSE +0 -0
- {broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/README.md +0 -0
- {broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/broccoli/rope.py +0 -0
- {broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/broccoli/utils.py +0 -0
{broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/broccoli/activation.py
@@ -46,10 +46,7 @@ class GELU(nn.Module):
 
 class Swish(nn.Module):
     """
-    Implementation of (beta)
-    (https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
-
-    Halves the incoming parameter count, which should be scaled up before input.
+    Implementation of (beta) Swish
     """
 
     def __init__(self) -> None:
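The new docstring identifies the module as a (beta) Swish activation. The hunk does not show the module's forward pass, but the activation usually meant by that name is swish(x) = x * sigmoid(beta * x). A minimal, hypothetical sketch for reference (the class name BetaSwish and the learnable scalar beta are assumptions, not the package's code):

import torch
from torch import nn


class BetaSwish(nn.Module):
    """Hypothetical sketch of a Swish activation with a learnable beta."""

    def __init__(self) -> None:
        super().__init__()
        self.beta = nn.Parameter(torch.ones(1))  # assumed: scalar, learnable

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # swish(x) = x * sigmoid(beta * x)
        return x * torch.sigmoid(self.beta * x)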
broccoli_ml-9.6.0/broccoli/linear.py (new file)
@@ -0,0 +1,368 @@
+import math
+import random
+import warnings
+from typing import Union, List, Iterable
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
+
+
+class SpectralNormLinear(nn.Module):
+    """
+    Inspired by Apple's Spectral Normed Linear Layers
+    (https://github.com/apple/ml-sigma-reparam)
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
+        else:
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = SigmaReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        # Optional: A nice representation for printing the module.
+        return (
+            f"SpectralNormFeedForward(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
+        )
+
+
+class AnchoredLinear(nn.Module):
+    """
+    ...
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
+        else:
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = AnchoredReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        # Optional: A nice representation for printing the module.
+        return (
+            f"AnchoredLinear(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
+        )
+
+
+class WeightNormedLinear(nn.Module):
+    """
+    ...
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
+        else:
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = NormReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        return (
+            f"WeightNormedLinear(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
+        )
+
+
+class RecyclingLinear(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        row_recycling_rate: float = 0.0,
+        column_recycling_rate: float = 0.0,
+        adaptive=False,
+        xglu=False,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.bias = bias
+        self.xglu = xglu
+        self.linear = nn.Linear(in_features, out_features, bias=bias)
+        self.row_recycling_rate = row_recycling_rate
+        self.column_recycling_rate = column_recycling_rate
+        self.adaptive = adaptive
+        self.optimisers = []
+        self.initial_learning_rates = []
+        self._warned_about_registration = False
+
+    def register_optimiser(self, optimiser: torch.optim.Optimizer):
+        self.optimisers.append(optimiser)
+        self.initial_learning_rates.append(self._get_learning_rate(optimiser))
+        if self.initial_learning_rates[-1] == 0.0:
+            warnings.warn(
+                "Learning rate of registered optimiser was 0.0 - make sure "
+                "you haven't initialised a scheduler before registering the "
+                "optimiser",
+                stacklevel=2,
+            )
+
+    def _get_learning_rate(self, optimiser: torch.optim.Optimizer):
+        for group in optimiser.param_groups:
+            for param in group["params"]:
+                if param is self.linear.weight:
+                    return group["lr"]
+
+    def _get_multiplier(self):
+        if not self.adaptive or not self.optimisers:
+            return 1.0
+        else:
+            init = self.initial_learning_rates
+            current = [self._get_learning_rate(o) for o in self.optimisers]
+            pairs = zip(current, init, strict=True)
+            multipliers = [a / b for a, b in pairs if b != 0.0]
+            return min(multipliers) if multipliers else 0.0
+
+    def reset_rows(self, indices, incoming_data=None):
+        """
+        Resets rows.
+        If incoming_data is provided, resets to the centroid (mean) of that data.
+        If not, resets to the mean of existing weights.
+        """
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.numel() == 0:
+            return
+
+        if incoming_data is not None:
+            target_center = self._mean_input_weights(incoming_data)
+        else:
+            target_center = self._mean_value_weights()
+
+        target_center = target_center.expand(idx_tensor.size(0), -1)
+
+        if self.xglu:
+            gate_indices = idx_tensor
+            value_indices = idx_tensor + (self.linear.out_features // 2)
+            self._update_weights(gate_indices, 0, target_center, self.optimisers)
+            self._update_weights(value_indices, 0, target_center, self.optimisers)
+        else:
+            self._update_weights(idx_tensor, 0, target_center, self.optimisers)
+
+    def reset_columns(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.size(0):
+            random_weights = self._random_weights(
+                self.linear.weight.size(0), indices.size(0)
+            )
+            # Make random col weights quiet so they don't introduce loud noise...
+            # ...but not so quiet that FP16 zeros them and ruins symmetry breaking!
+            random_weights *= 0.1
+            self._update_weights(indices, 1, random_weights, self.optimisers)  # dim
+        else:
+            return
+
+    def forward(self, x):
+        if self.training and self.optimisers:
+            self.reset_rows(self.get_reset_indices(0))
+            self.reset_columns(self.get_reset_indices(1))
+        elif self.training and not self._warned_about_registration:
+            warnings.warn(
+                "RecyclingLinear: No optimiser registered. Recycling disabled.",
+                stacklevel=2,
+            )
+            self._warned_about_registration = True
+
+        return self.linear(x)
+
+    def get_reset_indices(self, dim):
+        base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
+        p = base_rate * self._get_multiplier()
+        if dim == 0:
+            if self.xglu:
+                sample_space = self.linear.out_features // 2
+            else:
+                sample_space = self.linear.out_features
+        elif dim == 1:
+            sample_space = self.linear.in_features
+        else:
+            raise ValueError("`dim` must be 0 or 1")
+
+        # Sample the indices
+        probs = torch.rand(sample_space, device=self.linear.weight.device)
+        mask = probs < p
+        if mask.any():
+            return torch.nonzero(mask).squeeze(-1)
+        else:
+            return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)
+
+    def _random_weights(self, rows, columns):
+        device = self.linear.weight.device
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        random_weights = torch.rand(rows, columns, device=device)
+        random_weights -= 0.5  # Range [-0.5, +0.5]
+        random_weights *= 2.0 * stdv  # Range [-stdv, +stdv]
+        return random_weights
+
+    def _mean_input_weights(self, input):
+        reduce_dims = list(range(input.ndim - 1))
+        data_mean = input.detach().mean(dim=reduce_dims, keepdim=True)
+
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        data_norm = data_mean.std() + 1e-6
+        scale_factor = stdv / data_norm
+
+        return data_mean * scale_factor
+
+    def _mean_value_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        if self.xglu:
+            return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
+        else:
+            return self.linear.weight.data.mean(dim=0, keepdim=True)
+
+    def _mean_gate_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)
+
+    def _update_weights(
+        self,
+        indices: Iterable[int],
+        dim: int,
+        data: torch.Tensor,
+        optimisers: Union[
+            List[torch.optim.Optimizer], torch.optim.Optimizer, None
+        ] = None,
+    ):
+        if optimisers is None:
+            optimisers = []
+        if not isinstance(optimisers, list):
+            optimisers = [optimisers]
+
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.numel() == 0:
+            return
+
+        with torch.no_grad():
+            if dim == 0:
+                self.linear.weight.data[idx_tensor] = data
+            elif dim == 1:
+                self.linear.weight.data[:, idx_tensor] = data
+            else:
+                raise ValueError("`dim` must be 0 or 1")
+
+        self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)
+
+    def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
+        """
+        Zeroes out the optimizer state for the given indices in a single operation.
+        """
+        for optimiser in optimisers:
+            if param not in optimiser.state:
+                continue
+            state = optimiser.state[param]
+
+            for _, buffer in state.items():
+                if torch.is_tensor(buffer) and buffer.shape == param.shape:
+                    # Vectorized zeroing
+                    if dim == 0:
+                        buffer[idx_tensor] = 0.0
+                    else:
+                        buffer[:, idx_tensor] = 0.0
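The new RecyclingLinear wraps a plain nn.Linear and, during training, randomly picks rows and/or columns of the weight matrix to re-initialise ("recycle"), zeroing the matching optimiser state so stale momentum does not immediately overwrite the reset. A rough usage sketch based only on the methods shown above; the layer sizes, rates and training step are illustrative assumptions:

import torch
from broccoli.linear import RecyclingLinear, SpectralNormLinear

# Illustrative sizes and rate; constructed like nn.Linear plus recycling options.
layer = RecyclingLinear(128, 256, row_recycling_rate=0.001)
optimiser = torch.optim.AdamW(layer.parameters(), lr=1e-3)

# Registering the optimiser lets the layer clear the matching optimiser state
# for recycled weights (and, with adaptive=True, scale the rate with the LR).
layer.register_optimiser(optimiser)

x = torch.randn(32, 128)
loss = layer(x).pow(2).mean()  # row/column resets happen inside forward() in training mode
loss.backward()
optimiser.step()

# The reparameterised layers are drop-in replacements for nn.Linear:
spectral = SpectralNormLinear(128, 256, bias=True)
y = spectral(x)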
{broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/broccoli/transformer.py
@@ -325,6 +325,8 @@ class FeedforwardBlock(nn.Module):
         activation=nn.ReLU,
         activation_kwargs=None,
         dropout=0.0,
+        inner_dropout=None,
+        outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
         pre_norm=True,
@@ -338,6 +340,7 @@ class FeedforwardBlock(nn.Module):
         self.checkpoint = checkpoint
         self.residual_path = residual_path
         self.post_norm = post_norm
+        self.xglu = activation.__name__.endswith("GLU")
 
         if self.residual_path and (output_features < input_features):
             raise ValueError(
@@ -354,29 +357,63 @@ class FeedforwardBlock(nn.Module):
         else:
             self.activation = activation()
 
-        self.
+        self.inner_dropout = nn.Dropout(
+            inner_dropout if inner_dropout is not None else dropout
+        )
+        self.outer_dropout = nn.Dropout(
+            outer_dropout if outer_dropout is not None else dropout
+        )
 
         self.max_features = (
-            2 * ratio * output_features
-            if activation.__name__.endswith("GLU")
-            else ratio * output_features
+            2 * ratio * output_features if self.xglu else ratio * output_features
         )
 
+        self.linear_in = linear_module_up(input_features, self.max_features)
+        self.linear_out = linear_module_down(ratio * output_features, output_features)
+
         self.process = nn.Sequential(
             *[
                 nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
-
+                self.linear_in,
                 self.activation,
+                self.inner_dropout,
                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
-
-                self.
+                self.linear_out,
+                self.outer_dropout,
             ]
         )
 
+        self.recycling_enabled = False
+        if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
+            self.linear_out, "column_recycling_rate"
+        ):
+            self.recycling_enabled = True
+            self.master_recycling_rate = self.linear_in.row_recycling_rate
+            self.linear_in.row_recycling_rate = 0.0
+            self.linear_out.column_recycling_rate = 0.0
+            if (
+                hasattr(self.linear_in, "column_recycling_rate")
+                and self.linear_in.column_recycling_rate > 0
+            ) or (
+                hasattr(self.linear_out, "row_recycling_rate")
+                and self.linear_out.row_recycling_rate > 0
+            ):
+                raise NotImplementedError(
+                    "At the moment this layer can only support recycling linear "
+                    "layers if the in layer resets only rows and the out layer "
+                    "resets only columns."
+                )
+
         self.reset_parameters()
 
     def forward(self, x):
 
+        # Recycle weights if using recycling linear layers
+        if self.training and self.recycling_enabled:
+            indices = self.linear_out.get_reset_indices(1)
+            self.linear_in.reset_rows(indices, incoming_data=x)
+            self.linear_out.reset_columns(indices)
+
         if self.checkpoint:
             processed = checkpoint(self.process, x, use_reentrant=False)
         else:
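With these changes FeedforwardBlock owns its up- and down-projections as self.linear_in / self.linear_out, applies separate dropout after the activation (inner) and after the down-projection (outer), and, when both projections expose recycling rates, drives the recycling from its own forward(): it samples one set of indices via linear_out.get_reset_indices(1) and applies it to both layers so the reset rows of linear_in and columns of linear_out stay aligned. A hedged construction sketch; the input_features/output_features keyword names, the feature sizes, and the remaining defaults are assumptions beyond what the hunks show:

from functools import partial
from torch import nn
from broccoli.linear import RecyclingLinear
from broccoli.transformer import FeedforwardBlock

# Hypothetical call: only arguments visible in the hunks are spelled out;
# feature sizes are illustrative and other arguments are assumed to keep defaults.
block = FeedforwardBlock(
    input_features=512,
    output_features=512,
    activation=nn.GELU,
    dropout=0.1,          # fallback used when inner/outer are None
    inner_dropout=0.05,   # applied after the activation
    outer_dropout=0.1,    # applied after the down-projection
    linear_module_up=partial(RecyclingLinear, row_recycling_rate=0.001),
    linear_module_down=partial(RecyclingLinear, column_recycling_rate=0.001),
)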
@@ -422,7 +459,9 @@ class TransformerBlock(nn.Module):
         ff_linear_module_up=None,
         ff_linear_module_down=None,
         msa_scaling="d",
-
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         identity_probability=0.0,
         causal=False,
@@ -484,7 +523,9 @@ class TransformerBlock(nn.Module):
             d_model,
             activation=activation,
             activation_kwargs=activation_kwargs,
-            dropout=
+            dropout=ff_dropout,
+            inner_dropout=ff_inner_dropout,
+            outer_dropout=ff_outer_dropout,
             linear_module_up=(
                 ff_linear_module_up
                 if ff_linear_module_up is not None
@@ -567,7 +608,9 @@ class TransformerEncoder(nn.Module):
         activation_kwargs: Optional[dict] = None,
         ff_linear_module_up=None,
         ff_linear_module_down=None,
-
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         stochastic_depth=0.0,
         causal=False,
@@ -629,7 +672,7 @@ class TransformerEncoder(nn.Module):
         else:
             self.absolute_position_embedding = None
 
-        self.mlp_dropout =
+        self.mlp_dropout = ff_dropout
         self.msa_dropout = msa_dropout
         self.stochastic_depth = stochastic_depth
 
@@ -658,7 +701,9 @@ class TransformerEncoder(nn.Module):
                 ff_linear_module_up=ff_linear_module_up,
                 ff_linear_module_down=ff_linear_module_down,
                 msa_scaling=msa_scaling,
-
+                ff_dropout=ff_dropout,
+                ff_inner_dropout=ff_inner_dropout,
+                ff_outer_dropout=ff_outer_dropout,
                 msa_dropout=msa_dropout,
                 identity_probability=self.stochastic_depth_probabilities[i],
                 causal=causal,
{broccoli_ml-7.0.0 → broccoli_ml-9.6.0}/broccoli/vit.py
@@ -161,7 +161,9 @@ class ViTEncoder(nn.Module):
         transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
-
+        transformer_initial_ff_dropout=None,
+        transformer_initial_ff_inner_dropout=None,
+        transformer_initial_ff_outer_dropout=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -178,7 +180,9 @@ class ViTEncoder(nn.Module):
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
         transformer_msa_scaling="d",
-
+        transformer_ff_dropout=0.0,
+        transformer_ff_inner_dropout=0.0,
+        transformer_ff_outer_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
@@ -333,7 +337,9 @@ class ViTEncoder(nn.Module):
             ff_linear_module_up=transformer_ff_linear_module_up,
             ff_linear_module_down=transformer_ff_linear_module_down,
             msa_scaling=transformer_msa_scaling,
-
+            ff_dropout=transformer_ff_dropout,
+            ff_inner_dropout=transformer_ff_inner_dropout,
+            ff_outer_dropout=transformer_ff_outer_dropout,
             msa_dropout=transformer_msa_dropout,
             stochastic_depth=transformer_stochastic_depth,
             causal=False,
@@ -357,9 +363,21 @@ class ViTEncoder(nn.Module):
             activation_kwargs=transformer_activation_kwargs,
             dropout=(
                 # First truthy assigned value
-
-                if
-                else
+                transformer_initial_ff_dropout
+                if transformer_initial_ff_dropout is not None
+                else transformer_ff_dropout
+            ),
+            inner_dropout=(
+                # First truthy assigned value
+                transformer_initial_ff_inner_dropout
+                if transformer_initial_ff_inner_dropout is not None
+                else transformer_ff_inner_dropout
+            ),
+            outer_dropout=(
+                # First truthy assigned value
+                transformer_initial_ff_outer_dropout
+                if transformer_initial_ff_outer_dropout is not None
+                else transformer_ff_outer_dropout
             ),
             linear_module_up=(
                 # First truthy assigned value
@@ -441,7 +459,9 @@ class ViT(nn.Module):
         transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
-
+        transformer_initial_ff_dropout=None,
+        transformer_initial_ff_inner_dropout=None,
+        transformer_initial_ff_outer_dropout=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -458,7 +478,9 @@ class ViT(nn.Module):
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
         transformer_msa_scaling="d",
-
+        transformer_ff_dropout=0.0,
+        transformer_ff_inner_dropout=0.0,
+        transformer_ff_outer_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
@@ -508,7 +530,9 @@ class ViT(nn.Module):
            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
            transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
            transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
-
+            transformer_initial_ff_dropout=transformer_initial_ff_dropout,
+            transformer_initial_ff_inner_dropout=transformer_initial_ff_inner_dropout,
+            transformer_initial_ff_outer_dropout=transformer_initial_ff_outer_dropout,
            transformer_pre_norm=transformer_pre_norm,
            transformer_normformer=transformer_normformer,
            transformer_post_norm=transformer_post_norm,
@@ -525,7 +549,9 @@ class ViT(nn.Module):
            transformer_ff_linear_module_up=transformer_ff_linear_module_up,
            transformer_ff_linear_module_down=transformer_ff_linear_module_down,
            transformer_msa_scaling=transformer_msa_scaling,
-
+            transformer_ff_dropout=transformer_ff_dropout,
+            transformer_ff_inner_dropout=transformer_ff_inner_dropout,
+            transformer_ff_outer_dropout=transformer_ff_outer_dropout,
            transformer_msa_dropout=transformer_msa_dropout,
            transformer_stochastic_depth=transformer_stochastic_depth,
            transformer_checkpoint_ff=transformer_checkpoint_ff,
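The remaining transformer and vit hunks are plumbing: the same three feed-forward dropout settings are threaded from ViT through ViTEncoder into TransformerEncoder and TransformerBlock, with the transformer_initial_ff_* values overriding the transformer_ff_* values for the initial feed-forward block whenever they are not None. A hedged sketch of how a caller might set them (every other ViT argument is assumed to keep its default):

from broccoli.vit import ViT

# Hypothetical: only the new keyword arguments are shown.
model = ViT(
    transformer_ff_dropout=0.1,
    transformer_ff_inner_dropout=0.05,
    transformer_ff_outer_dropout=0.1,
    # None means "fall back to the transformer_ff_* value above".
    transformer_initial_ff_inner_dropout=None,
)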
broccoli_ml-7.0.0/broccoli/linear.py (removed)
@@ -1,138 +0,0 @@
-import math
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
-
-
-class SpectralNormLinear(nn.Module):
-    """
-    Inspired by Apple's Spectral Normed Linear Layers
-    (https://github.com/apple/ml-sigma-reparam)
-    """
-
-    def __init__(self, in_features: int, out_features: int, bias: bool = True):
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.use_bias = bias
-
-        self.weights = None
-
-        # Define the bias vector as a learnable parameter if required.
-        if self.use_bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
-        else:
-            # If no bias, register it as None.
-            # This is important so that PyTorch doesn't complain when saving/loading the model.
-            self.register_parameter("bias", None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        weights = torch.empty(self.out_features, self.in_features)
-        stdv = 1.0 / math.sqrt(self.in_features)
-        nn.init.uniform_(weights, a=-stdv, b=stdv)
-        if self.use_bias:
-            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            nn.init.uniform_(self.bias, -bound, bound)
-        self.weights = SigmaReparamTensor(weights)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.linear(x, self.weights(), self.bias)
-
-    def __repr__(self) -> str:
-        # Optional: A nice representation for printing the module.
-        return (
-            f"SpectralNormFeedForward(in_features={self.in_features},"
-            f"out_features={self.out_features}, bias={self.use_bias})"
-        )
-
-
-class AnchoredLinear(nn.Module):
-    """
-    ...
-    """
-
-    def __init__(self, in_features: int, out_features: int, bias: bool = True):
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.use_bias = bias
-
-        self.weights = None
-
-        # Define the bias vector as a learnable parameter if required.
-        if self.use_bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
-        else:
-            # If no bias, register it as None.
-            # This is important so that PyTorch doesn't complain when saving/loading the model.
-            self.register_parameter("bias", None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        weights = torch.empty(self.out_features, self.in_features)
-        stdv = 1.0 / math.sqrt(self.in_features)
-        nn.init.uniform_(weights, a=-stdv, b=stdv)
-        if self.use_bias:
-            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            nn.init.uniform_(self.bias, -bound, bound)
-        self.weights = AnchoredReparamTensor(weights)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.linear(x, self.weights(), self.bias)
-
-    def __repr__(self) -> str:
-        # Optional: A nice representation for printing the module.
-        return (
-            f"AnchoredLinear(in_features={self.in_features},"
-            f"out_features={self.out_features}, bias={self.use_bias})"
-        )
-
-
-class WeightNormedLinear(nn.Module):
-    """
-    ...
-    """
-
-    def __init__(self, in_features: int, out_features: int, bias: bool = True):
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.use_bias = bias
-
-        self.weights = None
-
-        # Define the bias vector as a learnable parameter if required.
-        if self.use_bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
-        else:
-            # If no bias, register it as None.
-            # This is important so that PyTorch doesn't complain when saving/loading the model.
-            self.register_parameter("bias", None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        weights = torch.empty(self.out_features, self.in_features)
-        stdv = 1.0 / math.sqrt(self.in_features)
-        nn.init.uniform_(weights, a=-stdv, b=stdv)
-        if self.use_bias:
-            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            nn.init.uniform_(self.bias, -bound, bound)
-        self.weights = NormReparamTensor(weights)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.linear(x, self.weights(), self.bias)
-
-    def __repr__(self) -> str:
-        return (
-            f"WeightNormedLinear(in_features={self.in_features},"
-            f"out_features={self.out_features}, bias={self.use_bias})"
-        )

The remaining files (LICENSE, README.md, broccoli/__init__.py, broccoli/cnn.py, broccoli/rope.py, broccoli/tensor.py, broccoli/utils.py) are unchanged.