broccoli-ml 6.0.0__py3-none-any.whl → 13.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/activation.py +1 -4
- broccoli/linear.py +230 -0
- broccoli/transformer.py +219 -61
- broccoli/vit.py +63 -22
- {broccoli_ml-6.0.0.dist-info → broccoli_ml-13.0.1.dist-info}/METADATA +1 -1
- broccoli_ml-13.0.1.dist-info/RECORD +13 -0
- broccoli_ml-6.0.0.dist-info/RECORD +0 -13
- {broccoli_ml-6.0.0.dist-info → broccoli_ml-13.0.1.dist-info}/LICENSE +0 -0
- {broccoli_ml-6.0.0.dist-info → broccoli_ml-13.0.1.dist-info}/WHEEL +0 -0
broccoli/activation.py
CHANGED

@@ -46,10 +46,7 @@ class GELU(nn.Module):
 
 
 class Swish(nn.Module):
     """
-    Implementation of (beta)
-    (https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
-
-    Halves the incoming parameter count, which should be scaled up before input.
+    Implementation of (beta) Swish
     """
 
     def __init__(self) -> None:
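The docstring removed above appears to describe a GLU-style activation from the GLU-variants paper (https://arxiv.org/abs/2002.05202v1), which halves the incoming feature count by using one half of its input to gate the other. A minimal sketch of that idea — an illustration, not the package's own SwiGLU class:

import torch
from torch import nn
from torch.nn import functional as F

class SwiGLUSketch(nn.Module):
    # One half of the input gates the other, so 2n features in -> n out,
    # which is why the docstring said the input should be scaled up first.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = x.chunk(2, dim=-1)
        return F.silu(gate) * value  # SiLU == Swish with beta = 1

assert SwiGLUSketch()(torch.randn(4, 16)).shape == (4, 8)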
broccoli/linear.py
CHANGED

@@ -1,4 +1,8 @@
 import math
+import random
+import warnings
+from typing import Union, List, Iterable
+
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -136,3 +140,229 @@ class WeightNormedLinear(nn.Module):
             f"WeightNormedLinear(in_features={self.in_features},"
             f"out_features={self.out_features}, bias={self.use_bias})"
         )
+
+
+class RecyclingLinear(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        row_recycling_rate: float = 0.0,
+        column_recycling_rate: float = 0.0,
+        adaptive=False,
+        xglu=False,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.bias = bias
+        self.xglu = xglu
+        self.linear = nn.Linear(in_features, out_features, bias=bias)
+        self.row_recycling_rate = row_recycling_rate
+        self.column_recycling_rate = column_recycling_rate
+        self.adaptive = adaptive
+        self.optimisers = []
+        self.initial_learning_rates = []
+        self._warned_about_registration = False
+
+    def register_optimiser(self, optimiser: torch.optim.Optimizer):
+        self.optimisers.append(optimiser)
+        self.initial_learning_rates.append(self._get_learning_rate(optimiser))
+        if self.initial_learning_rates[-1] == 0.0:
+            warnings.warn(
+                "Learning rate of registered optimiser was 0.0 - make sure "
+                "you haven't initialised a scheduler before registering the "
+                "optimiser",
+                stacklevel=2,
+            )
+
+    def _get_learning_rate(self, optimiser: torch.optim.Optimizer):
+        for group in optimiser.param_groups:
+            for param in group["params"]:
+                if param is self.linear.weight:
+                    return group["lr"]
+
+    def _get_multiplier(self):
+        if not self.adaptive or not self.optimisers:
+            return 1.0
+        else:
+            init = self.initial_learning_rates
+            current = [self._get_learning_rate(o) for o in self.optimisers]
+            pairs = zip(current, init, strict=True)
+            multipliers = [a / b for a, b in pairs if b != 0.0]
+            return min(multipliers) if multipliers else 0.0
+
+    def reset_rows(self, indices, incoming_data=None):
+        """
+        Resets rows.
+        If incoming_data is provided, resets to the centroid (mean) of that data.
+        If not, resets to the mean of existing weights.
+        """
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.numel() == 0:
+            return
+
+        if incoming_data is not None:
+            target_center = self._mean_input_weights(incoming_data)
+        else:
+            target_center = self._mean_value_weights()
+
+        target_center = target_center.expand(idx_tensor.size(0), -1)
+
+        if self.xglu:
+            gate_indices = idx_tensor
+            value_indices = idx_tensor + (self.linear.out_features // 2)
+            self._update_weights(gate_indices, 0, target_center, self.optimisers)
+            self._update_weights(value_indices, 0, target_center, self.optimisers)
+        else:
+            self._update_weights(idx_tensor, 0, target_center, self.optimisers)
+
+    def reset_columns(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.size(0):
+            random_weights = self._random_weights(
+                self.linear.weight.size(0), indices.size(0)
+            )
+            # Make random col weights quiet so they don't introduce loud noise...
+            # ...but not so quiet that FP16 zeros them and ruins symmetry breaking!
+            random_weights *= 0.1
+            self._update_weights(indices, 1, random_weights, self.optimisers)  # dim
+        else:
+            return
+
+    def forward(self, x):
+        if self.training and self.optimisers:
+            self.reset_rows(self.get_reset_indices(0))
+            self.reset_columns(self.get_reset_indices(1))
+        elif self.training and not self._warned_about_registration:
+            warnings.warn(
+                "RecyclingLinear: No optimiser registered. Recycling disabled.",
+                stacklevel=2,
+            )
+            self._warned_about_registration = True
+
+        return self.linear(x)
+
+    def get_reset_indices(self, dim):
+        base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
+        p = base_rate * self._get_multiplier()
+        if dim == 0:
+            if self.xglu:
+                sample_space = self.linear.out_features // 2
+            else:
+                sample_space = self.linear.out_features
+        elif dim == 1:
+            sample_space = self.linear.in_features
+        else:
+            raise ValueError("`dim` must be 0 or 1")
+
+        # Sample the indices
+        probs = torch.rand(sample_space, device=self.linear.weight.device)
+        mask = probs < p
+        if mask.any():
+            return torch.nonzero(mask).squeeze(-1)
+        else:
+            return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)
+
+    def _random_weights(self, rows, columns):
+        device = self.linear.weight.device
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        random_weights = torch.rand(rows, columns, device=device)
+        random_weights -= 0.5  # Range [-0.5, +0.5]
+        random_weights *= 2.0 * stdv  # Range [-stdv, +stdv]
+        return random_weights
+
+    def _mean_input_weights(self, input):
+        reduce_dims = list(range(input.ndim - 1))
+        data_mean = input.detach().mean(dim=reduce_dims, keepdim=True)
+
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        data_norm = data_mean.std() + 1e-6
+        scale_factor = stdv / data_norm
+
+        return data_mean * scale_factor
+
+    def _mean_value_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        if self.xglu:
+            return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
+        else:
+            return self.linear.weight.data.mean(dim=0, keepdim=True)
+
+    def _mean_gate_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)
+
+    def _update_weights(
+        self,
+        indices: Iterable[int],
+        dim: int,
+        data: torch.Tensor,
+        optimisers: Union[
+            List[torch.optim.Optimizer], torch.optim.Optimizer, None
+        ] = None,
+    ):
+        if optimisers is None:
+            optimisers = []
+        if not isinstance(optimisers, list):
+            optimisers = [optimisers]
+
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.numel() == 0:
+            return
+
+        with torch.no_grad():
+            if dim == 0:
+                self.linear.weight.data[idx_tensor] = data
+            elif dim == 1:
+                self.linear.weight.data[:, idx_tensor] = data
+            else:
+                raise ValueError("`dim` must be 0 or 1")
+
+        self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)
+
+    def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
+        """
+        Zeroes out the optimizer state for the given indices in a single operation.
+        """
+        for optimiser in optimisers:
+            if param not in optimiser.state:
+                continue
+            state = optimiser.state[param]
+
+            for _, buffer in state.items():
+                if torch.is_tensor(buffer) and buffer.shape == param.shape:
+                    # Vectorized zeroing
+                    if dim == 0:
+                        buffer[idx_tensor] = 0.0
+                    else:
+                        buffer[:, idx_tensor] = 0.0
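The new RecyclingLinear only recycles weights once an optimiser has been registered; otherwise forward() warns once and behaves like a plain nn.Linear. With adaptive=True, the recycling probability is scaled by the ratio of the current to the initial learning rate. A minimal usage sketch inferred from the diff above (rates and sizes are illustrative):

import torch
from broccoli.linear import RecyclingLinear

layer = RecyclingLinear(64, 128, row_recycling_rate=0.01, adaptive=True)
optimiser = torch.optim.Adam(layer.parameters(), lr=1e-3)

# Register before attaching any LR scheduler: if the learning rate is
# already 0.0 at registration, the class warns that adaptive scaling
# will misbehave.
layer.register_optimiser(optimiser)

layer.train()
for _ in range(100):
    # Rows/columns may be re-initialised at the start of each forward()
    loss = layer(torch.randn(32, 64)).pow(2).mean()
    loss.backward()
    optimiser.step()
    optimiser.zero_grad()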
broccoli/transformer.py
CHANGED

@@ -1,3 +1,4 @@
+import warnings
 import math
 from typing import Optional, Tuple
 
@@ -13,12 +14,19 @@ from .rope import RotaryEmbedding, apply_rotary_emb
 try:
     from flash_attn import flash_attn_func
 
+    print("Using flash-attn.")
     FLASH_ATTN = True
 except ImportError:
     pass
     FLASH_ATTN = False
 
 
+def scale_parameters(torch_module: nn.Module, factor: float):
+    with torch.no_grad():
+        for param in torch_module.parameters():
+            param.mul_(factor)
+
+
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
 ):
@@ -77,9 +85,11 @@ class MHAttention(nn.Module):
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
+        beta=1.0,
     ):
         """
         Args:
@@ -95,10 +105,20 @@ class MHAttention(nn.Module):
         if causal:
             assert seq_len is not None
 
+        self.talking_heads = talking_heads
+
+        if self.talking_heads:
+            self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
+            self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
+        else:
+            self.head_projection = None
+            self.sample_projection = None
+
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
         self.scaling = scaling
+        self.beta = beta
 
         self.head_dim = self.embed_dim // self.n_heads
 
@@ -180,17 +200,26 @@ class MHAttention(nn.Module):
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )
 
-
-
+        q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+        k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+
+        q_util, q_img = (
+            q[:, : self.utility_tokens, :, :],
+            q[:, self.utility_tokens :, :, :],
+        )
+        k_util, k_img = (
+            k[:, : self.utility_tokens, :, :],
+            k[:, self.utility_tokens :, :, :],
+        )
 
         q_img = rearrange(
             q_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )
         k_img = rearrange(
             k_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )
 
@@ -201,17 +230,20 @@ class MHAttention(nn.Module):
 
         q_img = rearrange(
             q_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )
         k_img = rearrange(
             k_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )
 
         # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
         q = torch.cat([q_util, q_img], dim=1)
         k = torch.cat([k_util, k_img], dim=1)
 
+        q = rearrange(q, "b t h d -> b t (h d)")
+        k = rearrange(k, "b t h d -> b t (h d)")
+
        return q, k
 
     def project_qkv(
@@ -242,7 +274,7 @@ class MHAttention(nn.Module):
 
         q, k, v = self.project_qkv(q, k, v)
 
-        if FLASH_ATTN:
+        if FLASH_ATTN and not self.talking_heads:
             # Divide Q/K/V into heads
             q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
             k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
@@ -270,12 +302,22 @@ class MHAttention(nn.Module):
 
         qk_scores *= self.scaling_factor
 
+        if self.talking_heads:
+            qk_scores = torch.einsum(
+                "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
+            )
+
         # Apply mask if causal (must come before softmax)
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))
 
         qk_scores = F.softmax(qk_scores, dim=-1)
 
+        if self.talking_heads:
+            qk_scores = torch.einsum(
+                "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
+            )
+
         qk_scores = self.dropout(qk_scores)
 
         output_with_heads = qk_scores @ v
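The two einsums added above implement talking-heads attention (Shazeer et al., arXiv:2003.02436): attention scores are linearly mixed across the head dimension, once before the softmax and once after it. A standalone sketch of the same contraction, with illustrative sizes:

import torch
from torch import nn
from torch.nn import functional as F

b, h, t = 2, 4, 9                                # batch, heads, tokens
qk_scores = torch.randn(b, h, t, t)              # pre-softmax logits
head_projection = nn.Linear(h, h, bias=False)    # mixes logits
sample_projection = nn.Linear(h, h, bias=False)  # mixes softmax weights

# Output head o becomes a learned combination of all input heads,
# applied independently at every (i, j) position of the score map.
scores = torch.einsum("b h i j, o h -> b o i j", qk_scores, head_projection.weight)
weights = F.softmax(scores, dim=-1)
weights = torch.einsum("b h i j, o h -> b o i j", weights, sample_projection.weight)
assert weights.shape == (b, h, t, t)

Because the per-head score matrix must be materialised, this path cannot use flash-attn, which is why forward() above now checks `FLASH_ATTN and not self.talking_heads`.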
@@ -308,7 +350,14 @@ class MHAttention(nn.Module):
         self.q_proj.reset_parameters()
         self.k_proj.reset_parameters()
         self.v_proj.reset_parameters()
+        scale_parameters(self.v_proj, self.beta)  # per Microsoft DeepNet
         self.out_proj.reset_parameters()
+        scale_parameters(self.out_proj, self.beta)  # per Microsoft DeepNet
+
+        if self.talking_heads:
+            # Initialize close to identity
+            nn.init.eye_(self.head_projection.weight)
+            nn.init.eye_(self.sample_projection.weight)
 
 
 class FeedforwardBlock(nn.Module):
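The `scale_parameters(..., self.beta)` calls tagged "per Microsoft DeepNet" follow the DeepNet scheme (arXiv:2203.00555): shrink the value and output projections by a constant beta at initialisation, and up-weight the residual stream by a constant alpha in forward() (`x = self.alpha * x + processed` below). The diff leaves both at 1.0 by default; the paper's constants for an N-layer encoder-only model are, as a reference point rather than anything this package enforces:

def deepnet_encoder_constants(n_layers: int):
    alpha = (2 * n_layers) ** 0.25    # residual-stream scale
    beta = (8 * n_layers) ** -0.25    # init scale for v_proj / out_proj
    return alpha, beta

alpha, beta = deepnet_encoder_constants(12)
print(round(alpha, 3), round(beta, 3))  # 2.213 0.319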
@@ -324,19 +373,19 @@ class FeedforwardBlock(nn.Module):
         activation=nn.ReLU,
         activation_kwargs=None,
         dropout=0.0,
+        inner_dropout=None,
+        outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
-        pre_norm=True,
         normformer=False,
-        post_norm=True,
-        residual_path=True,
         checkpoint=True,
+        beta=1.0,
     ):
         super().__init__()
 
         self.checkpoint = checkpoint
-        self.
-        self.
+        self.beta = beta
+        self.xglu = activation.__name__.endswith("GLU")
 
         if self.residual_path and (output_features < input_features):
             raise ValueError(
@@ -353,40 +402,76 @@ class FeedforwardBlock(nn.Module):
         else:
             self.activation = activation()
 
-        self.
+        self.inner_dropout = nn.Dropout(
+            inner_dropout if inner_dropout is not None else dropout
+        )
+        self.outer_dropout = nn.Dropout(
+            outer_dropout if outer_dropout is not None else dropout
+        )
 
         self.max_features = (
-            2 * ratio * output_features
-            if
-            else ratio * output_features
+            2 * int(ratio * output_features)
+            if self.xglu
+            else int(ratio * output_features)
+        )
+
+        self.linear_in = linear_module_up(input_features, self.max_features)
+        self.linear_out = linear_module_down(
+            int(ratio * output_features), output_features
         )
 
         self.process = nn.Sequential(
             *[
-
-                linear_module_up(input_features, self.max_features),
+                self.linear_in,
                 self.activation,
-
-
-
+                self.inner_dropout,
+                (
+                    nn.LayerNorm(int(ratio * output_features))
+                    if normformer
+                    else nn.Identity()
+                ),
+                self.linear_out,
+                self.outer_dropout,
             ]
         )
 
+        self.recycling_enabled = False
+        if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
+            self.linear_out, "column_recycling_rate"
+        ):
+            self.recycling_enabled = True
+            self.master_recycling_rate = self.linear_in.row_recycling_rate
+            self.linear_in.row_recycling_rate = 0.0
+            self.linear_out.column_recycling_rate = 0.0
+            if (
+                hasattr(self.linear_in, "column_recycling_rate")
+                and self.linear_in.column_recycling_rate > 0
+            ) or (
+                hasattr(self.linear_out, "row_recycling_rate")
+                and self.linear_out.row_recycling_rate > 0
+            ):
+                raise NotImplementedError(
+                    "At the moment this layer can only support recycling linear "
+                    "layers if the in layer resets only rows and the out layer "
+                    "resets only columns."
+                )
+
         self.reset_parameters()
 
     def forward(self, x):
+        # Recycle weights if using recycling linear layers
+        if self.training and self.recycling_enabled:
+            indices = self.linear_out.get_reset_indices(1)
+            self.linear_in.reset_rows(indices, incoming_data=x)
+            self.linear_out.reset_columns(indices)
+
         if self.checkpoint:
             processed = checkpoint(self.process, x, use_reentrant=False)
         else:
             processed = self.process(x)
 
-
-            return self.layernorm(x + processed)
-        elif self.residual_path:
-            return x + processed
-        else:
-            return processed
+        return processed
 
     def reset_parameters(self):
         if self.post_norm:
@@ -397,8 +482,11 @@ class FeedforwardBlock(nn.Module):
         if hasattr(module, "reset_parameters"):
             module.reset_parameters()
 
+        scale_parameters(self.linear_in, self.beta)  # per Microsoft DeepNet
+        scale_parameters(self.linear_out, self.beta)
+
 
-class TransformerBlock(nn.Module):
+class EncoderBlock(nn.Module):
     """
     Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
     which is also what is seen in e.g.
@@ -415,13 +503,16 @@ class TransformerBlock(nn.Module):
         relative_position_embedding=False,
         source_size=None,
         utility_tokens=0,
+        talking_heads=False,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
         ff_linear_module_up=None,
         ff_linear_module_down=None,
         msa_scaling="d",
-
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         identity_probability=0.0,
         causal=False,
@@ -430,6 +521,8 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
         checkpoint_ff=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -441,15 +534,29 @@ class TransformerBlock(nn.Module):
 
         super().__init__()
 
+        if pre_norm and post_norm:
+            raise ValueError("A transformer cannot be both prenorm and postnorm.")
+
         self.pre_norm = pre_norm
         self.post_norm = post_norm
         self.normformer = normformer
 
+        self.alpha = alpha
+        self.beta = beta
+
         self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
 
-        self.
-
-
+        if self.pre_norm:
+            self.pre_attention_norm = nn.LayerNorm(d_model)
+            self.pre_mlp_norm = nn.LayerNorm(d_model)
+
+        if normformer:
+            self.normformer_norm = nn.LayerNorm(d_model)
+
+        if self.post_norm:
+            self.input_norm = nn.LayerNorm(d_model)
+            self.post_attention_norm = nn.LayerNorm(d_model)
+            self.post_mlp_norm = nn.LayerNorm(d_model)
 
         if relative_position_embedding:
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
@@ -473,7 +580,9 @@ class TransformerBlock(nn.Module):
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
             utility_tokens=utility_tokens,
+            talking_heads=talking_heads,
             scaling=msa_scaling,
+            beta=beta,
         )
 
         # Submodule for the feedforward process
@@ -483,7 +592,9 @@ class TransformerBlock(nn.Module):
             d_model,
             activation=activation,
             activation_kwargs=activation_kwargs,
-            dropout=
+            dropout=ff_dropout,
+            inner_dropout=ff_inner_dropout,
+            outer_dropout=ff_outer_dropout,
             linear_module_up=(
                 ff_linear_module_up
                 if ff_linear_module_up is not None
@@ -494,11 +605,9 @@ class TransformerBlock(nn.Module):
                 if ff_linear_module_down is not None
                 else linear_module
             ),
-            pre_norm=False,  # Handled outside the block
             normformer=normformer,
-            post_norm=False,  # Handled outside the block
-            residual_path=False,  # Handled outside the block
             checkpoint=checkpoint_ff,
+            beta=beta,
         )
 
         self.reset_parameters()
@@ -508,22 +617,34 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance
 
     def forward(self, x):
+        if self.post_norm:
+            x = self.input_norm(x)
 
         if self.pre_norm:
-
-
-
-
-
-
-
-
-
-
-
-
-            x =
-
+            process_x = self.pre_attention_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(self.attn(process_x, process_x, process_x))
+
+        if self.normformer:
+            processed = self.normformer_norm(processed)
+
+        x = self.alpha * x + processed
+
+        if self.post_norm:
+            x = self.post_attention_norm(x)
+        elif self.pre_norm:
+            process_x = self.pre_mlp_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(self.ff(process_x))
+
+        x = self.alpha * x + processed
+
+        if self.post_norm:
+            x = self.post_mlp_norm(x)
 
         return x
 
@@ -531,16 +652,26 @@ class TransformerBlock(nn.Module):
         """
         Give back the attention scores used in this layer.
         """
+        # Fix: Use the correct attribute name 'pre_attention_norm'
         if self.pre_norm:
-
+            # We must normalize the input before measuring attention logits
+            # to match what the model actually sees during forward()
+            x = self.pre_attention_norm(x)
             return self.attn.attention_logits(x, x, x)
         else:
            return self.attn.attention_logits(x, x, x)
 
     def reset_parameters(self):
-        self.
-
-
+        if self.pre_norm:
+            self.pre_attention_norm.reset_parameters()
+            self.pre_mlp_norm.reset_parameters()
+
+        if self.post_norm:
+            self.post_attention_norm.reset_parameters()
+            self.post_mlp_norm.reset_parameters()
+
+        if self.normformer:
+            self.normformer_norm.reset_parameters()
 
         self.attn.reset_parameters()
         self.ff.reset_parameters()
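The rewritten EncoderBlock.forward() above supports pre-norm, post-norm, and norm-free layouts, with the residual stream scaled by alpha. Condensed to the pre-norm case (and omitting drop_path and the optional Normformer norm), the per-block computation is just this sketch, not the class itself:

def pre_norm_block(x, attn, ff, pre_attention_norm, pre_mlp_norm, alpha=1.0):
    # attention sub-layer: normalise, attend, add to the (scaled) residual
    x = alpha * x + attn(pre_attention_norm(x))
    # feedforward sub-layer, same pattern
    x = alpha * x + ff(pre_mlp_norm(x))
    return x

With alpha = 1.0 this reduces to the standard pre-norm transformer block; alpha > 1.0 is the DeepNet-style up-weighting of the residual stream mentioned earlier.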
@@ -566,18 +697,23 @@ class TransformerEncoder(nn.Module):
         activation_kwargs: Optional[dict] = None,
         ff_linear_module_up=None,
         ff_linear_module_down=None,
-
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         stochastic_depth=0.0,
         causal=False,
         linear_module=nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
         msa_scaling="d",
         checkpoint_ff=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -590,10 +726,23 @@ class TransformerEncoder(nn.Module):
         if relative_position_embedding and (source_size is None):
             raise ValueError(
                 "`source_size` for TransformerEncoder cannot be None if"
-                " `
+                " `relative_position_embedding` is True"
+            )
+
+        if absolute_position_embedding and (seq_len is None):
+            raise ValueError(
+                "`seq_len` for TransformerEncoder cannot be None if"
+                " `absolute_position_embedding` is True"
             )
 
         super().__init__()
+
+        if FLASH_ATTN and talking_heads:
+            warnings.warn(
+                "Using talking heads currently prevents using flash attention.",
+                stacklevel=2,
+            )
+
         self.seq_len = seq_len
         self.n_heads = n_heads
         self._utility_tokens = utility_tokens
@@ -605,9 +754,12 @@ class TransformerEncoder(nn.Module):
                 torch.empty(self._utility_tokens, d_model)
             )
             nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
-            self.full_sequence_length = self.seq_len + self._utility_tokens
         else:
             self._utility_token_embedding = None
+
+        if self._utility_tokens and (self.seq_len is not None):
+            self.full_sequence_length = self.seq_len + self._utility_tokens
+        else:
             self.full_sequence_length = self.seq_len
 
         self.d_model = d_model
@@ -619,7 +771,7 @@ class TransformerEncoder(nn.Module):
         else:
             self.absolute_position_embedding = None
 
-        self.mlp_dropout =
+        self.mlp_dropout = ff_dropout
         self.msa_dropout = msa_dropout
         self.stochastic_depth = stochastic_depth
 
@@ -635,20 +787,23 @@ class TransformerEncoder(nn.Module):
 
         self.blocks = nn.ModuleList(
             [
-
+                EncoderBlock(
                     self.full_sequence_length,
                     d_model,
                     n_heads,
                     relative_position_embedding=relative_position_embedding,
                     source_size=source_size,
                     utility_tokens=utility_tokens,
+                    talking_heads=talking_heads,
                     mlp_ratio=mlp_ratio,
                     activation=activation,
                     activation_kwargs=activation_kwargs,
                     ff_linear_module_up=ff_linear_module_up,
                     ff_linear_module_down=ff_linear_module_down,
                     msa_scaling=msa_scaling,
-
+                    ff_dropout=ff_dropout,
+                    ff_inner_dropout=ff_inner_dropout,
+                    ff_outer_dropout=ff_outer_dropout,
                     msa_dropout=msa_dropout,
                     identity_probability=self.stochastic_depth_probabilities[i],
                     causal=causal,
@@ -657,6 +812,8 @@ class TransformerEncoder(nn.Module):
                     post_norm=post_norm,
                     normformer=normformer,
                     checkpoint_ff=checkpoint_ff,
+                    alpha=alpha,
+                    beta=beta,
                 )
                 for i in range(n_layers)
             ]
@@ -677,13 +834,14 @@ class TransformerEncoder(nn.Module):
         x = x
 
         if self.absolute_position_embedding is not None:
-
+            position_embedding = self.absolute_position_embedding(
                 torch.arange(
                     0, self.full_sequence_length, dtype=torch.long, device=x.device
                 ).unsqueeze(
                     0
                 )  # to shape (1, seq_len) to broadcast over batch
             )
+            x += position_embedding
 
         return x
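One behavioural fix worth noting in the TransformerEncoder hunks above: `full_sequence_length` is now computed outside the utility-token branch and guarded on `seq_len is not None`. Previously, requesting utility tokens with `seq_len=None` evaluated `None + int` and raised a TypeError. With illustrative values:

seq_len, utility_tokens = 196, 4  # e.g. a 14x14 patch grid plus 4 utility tokens
if utility_tokens and seq_len is not None:
    full_sequence_length = seq_len + utility_tokens   # 200
else:
    full_sequence_length = seq_len                    # falls back, even when None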
broccoli/vit.py
CHANGED

@@ -158,10 +158,11 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
-
+        transformer_initial_ff_dropout=None,
+        transformer_initial_ff_inner_dropout=None,
+        transformer_initial_ff_outer_dropout=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -172,20 +173,28 @@ class ViTEncoder(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
         transformer_msa_scaling="d",
-
+        transformer_ff_dropout=0.0,
+        transformer_ff_inner_dropout=0.0,
+        transformer_ff_outer_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
         linear_module=nn.Linear,
+        alpha=1.0,
+        beta=1.0,
     ):
         super().__init__()
 
+        self.alpha = alpha
+        self.beta = beta
+
         if cnn_activation_kwargs is not None:
             self.cnn_activation = cnn_activation(**cnn_activation_kwargs)
         else:
@@ -333,17 +342,22 @@ class ViTEncoder(nn.Module):
                 ff_linear_module_up=transformer_ff_linear_module_up,
                 ff_linear_module_down=transformer_ff_linear_module_down,
                 msa_scaling=transformer_msa_scaling,
-
+                ff_dropout=transformer_ff_dropout,
+                ff_inner_dropout=transformer_ff_inner_dropout,
+                ff_outer_dropout=transformer_ff_outer_dropout,
                 msa_dropout=transformer_msa_dropout,
                 stochastic_depth=transformer_stochastic_depth,
                 causal=False,
                 linear_module=linear_module,
                 utility_tokens=transformer_utility_tokens,
+                talking_heads=transformer_talking_heads,
                 return_utility_tokens=transformer_return_utility_tokens,
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 checkpoint_ff=transformer_checkpoint_ff,
+                alpha=self.alpha,
+                beta=self.beta,
             )
         else:
             self.transformer = nn.Identity()
@@ -357,9 +371,21 @@ class ViTEncoder(nn.Module):
                 activation_kwargs=transformer_activation_kwargs,
                 dropout=(
                     # First truthy assigned value
-
-                    if
-                    else
+                    transformer_initial_ff_dropout
+                    if transformer_initial_ff_dropout is not None
+                    else transformer_ff_dropout
+                ),
+                inner_dropout=(
+                    # First truthy assigned value
+                    transformer_initial_ff_inner_dropout
+                    if transformer_initial_ff_inner_dropout is not None
+                    else transformer_ff_inner_dropout
+                ),
+                outer_dropout=(
+                    # First truthy assigned value
+                    transformer_initial_ff_outer_dropout
+                    if transformer_initial_ff_outer_dropout is not None
+                    else transformer_ff_outer_dropout
                 ),
                 linear_module_up=(
                     # First truthy assigned value
@@ -373,16 +399,14 @@ class ViTEncoder(nn.Module):
                     or transformer_ff_linear_module_down
                     or linear_module
                 ),
-                pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
-                post_norm=transformer_post_norm,
-                residual_path=transformer_initial_ff_residual_path,
                 checkpoint=transformer_checkpoint_ff,
+                beta=self.beta,
             )
         else:
             self.initial_ff = nn.Identity()
 
-        self.
+        self.preprocess = nn.Sequential(
             *[
                 batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
                 self.cnn,
@@ -392,19 +416,21 @@ class ViTEncoder(nn.Module):
                     f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                 ),
                 self.pooling_channels_padding,
-
-                self.transformer,
+                nn.LayerNorm(),
             ]
         )
 
         self.reset_parameters()
 
     def forward(self, x):
-
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer(x)
 
     def attention_logits(self, x):
-        x = self.
-
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer.attention_logits(x)
 
     def reset_parameters(self):
         for module in self.encoder:
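In the ViTEncoder hunks above, each of the three new initial-FF dropout knobs overrides its transformer-wide counterpart only when explicitly set. Despite the inherited `# First truthy assigned value` comments, the test is `is not None`, so an explicit 0.0 does win. The same precedence rule, in sketch form:

def resolve_rate(initial_ff_rate, transformer_rate):
    return initial_ff_rate if initial_ff_rate is not None else transformer_rate

assert resolve_rate(None, 0.1) == 0.1  # unset: inherit the shared rate
assert resolve_rate(0.0, 0.1) == 0.0   # explicit 0.0 disables dropout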
@@ -438,10 +464,11 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
-
+        transformer_initial_ff_dropout=None,
+        transformer_initial_ff_inner_dropout=None,
+        transformer_initial_ff_outer_dropout=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -452,13 +479,16 @@ class ViT(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
         transformer_msa_scaling="d",
-
+        transformer_ff_dropout=0.0,
+        transformer_ff_inner_dropout=0.0,
+        transformer_ff_outer_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
@@ -466,6 +496,8 @@ class ViT(nn.Module):
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
         linear_module=nn.Linear,
+        alpha=1.0,
+        beta=1.0,
     ):
 
         super().__init__()
@@ -486,6 +518,9 @@ class ViT(nn.Module):
             "SwiGLU": SwiGLU,
         }[transformer_activation]
 
+        self.alpha = alpha
+        self.beta = beta
+
         self.encoder = ViTEncoder(
             input_size=input_size,
             initial_batch_norm=initial_batch_norm,
@@ -505,10 +540,11 @@ class ViT(nn.Module):
             pooling_kernel_stride=pooling_kernel_stride,
             pooling_padding=pooling_padding,
             transformer_feedforward_first=transformer_feedforward_first,
-            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
             transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
             transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
-
+            transformer_initial_ff_dropout=transformer_initial_ff_dropout,
+            transformer_initial_ff_inner_dropout=transformer_initial_ff_inner_dropout,
+            transformer_initial_ff_outer_dropout=transformer_initial_ff_outer_dropout,
             transformer_pre_norm=transformer_pre_norm,
             transformer_normformer=transformer_normformer,
             transformer_post_norm=transformer_post_norm,
@@ -519,17 +555,22 @@ class ViT(nn.Module):
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
             transformer_utility_tokens=transformer_utility_tokens,
+            transformer_talking_heads=transformer_talking_heads,
             transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
             transformer_ff_linear_module_up=transformer_ff_linear_module_up,
             transformer_ff_linear_module_down=transformer_ff_linear_module_down,
             transformer_msa_scaling=transformer_msa_scaling,
-
+            transformer_ff_dropout=transformer_ff_dropout,
+            transformer_ff_inner_dropout=transformer_ff_inner_dropout,
+            transformer_ff_outer_dropout=transformer_ff_outer_dropout,
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
             transformer_checkpoint_ff=transformer_checkpoint_ff,
             linear_module=linear_module,
+            alpha=alpha,
+            beta=beta,
         )
 
         self.pool = head(

broccoli_ml-13.0.1.dist-info/RECORD
ADDED

@@ -0,0 +1,13 @@
+broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
+broccoli/activation.py,sha256=nrpTOrpg9k23_E4AJWy7VlXXAJCtCJCOR-TonEWJr04,3218
+broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
+broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
+broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
+broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
+broccoli/transformer.py,sha256=3vAQQ75SAyr4-m3e7vSru8M-RpUy2Enp5cVUafaVYMU,28410
+broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
+broccoli/vit.py,sha256=jd4e6MjL2JKB8ynSQssWRh6Hs36RuLj4uWyUNVhIMUY,22472
+broccoli_ml-13.0.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-13.0.1.dist-info/METADATA,sha256=gN9cKQDpwRr8JG_Ilj0ZknlPdxhq62I8xEBExVspKGw,1369
+broccoli_ml-13.0.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-13.0.1.dist-info/RECORD,,

broccoli_ml-6.0.0.dist-info/RECORD
DELETED

@@ -1,13 +0,0 @@
-broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
-broccoli/activation.py,sha256=-Jf30C6iGqWCorC9HEGn2oduWwjeaCAxGLUUYIy1zX8,3438
-broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
-broccoli/linear.py,sha256=Y7s-DzcwsOipRboNHc4HTScw4mJRalNoVFsNcxOB6a4,4872
-broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
-broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=MxIdzoxoWx_IWcq86vDZJIV4tk-dMNivhopZu8zJk90,23293
-broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
-broccoli/vit.py,sha256=9oyh76ulmX5lDPMCDicQhhqm8RYCvJIgAJkDbYRVdi4,20873
-broccoli_ml-6.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
-broccoli_ml-6.0.0.dist-info/METADATA,sha256=Sv8nRPb7oCAeoMe3AAHIYDewAETvb0ZDxN8IKFniVHk,1368
-broccoli_ml-6.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-broccoli_ml-6.0.0.dist-info/RECORD,,

{broccoli_ml-6.0.0.dist-info → broccoli_ml-13.0.1.dist-info}/LICENSE
File without changes

{broccoli_ml-6.0.0.dist-info → broccoli_ml-13.0.1.dist-info}/WHEEL
File without changes