broccoli-ml 9.5.1__py3-none-any.whl → 9.7.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- broccoli/linear.py +33 -17
- broccoli/transformer.py +38 -2
- broccoli/vit.py +4 -0
- {broccoli_ml-9.5.1.dist-info → broccoli_ml-9.7.0.dist-info}/METADATA +1 -1
- {broccoli_ml-9.5.1.dist-info → broccoli_ml-9.7.0.dist-info}/RECORD +7 -7
- {broccoli_ml-9.5.1.dist-info → broccoli_ml-9.7.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-9.5.1.dist-info → broccoli_ml-9.7.0.dist-info}/WHEEL +0 -0
broccoli/linear.py
CHANGED
@@ -193,7 +193,12 @@ class RecyclingLinear(nn.Module):
         multipliers = [a / b for a, b in pairs if b != 0.0]
         return min(multipliers) if multipliers else 0.0
 
-    def reset_rows(self, indices):
+    def reset_rows(self, indices, incoming_data=None):
+        """
+        Resets rows.
+        If incoming_data is provided, resets to the centroid (mean) of that data.
+        If not, resets to the mean of existing weights.
+        """
         if not torch.is_tensor(indices):
             idx_tensor = torch.as_tensor(
                 list(indices), dtype=torch.long, device=self.linear.weight.device
@@ -201,24 +206,24 @@ class RecyclingLinear(nn.Module):
         else:
             idx_tensor = indices
 
-        if idx_tensor.
-            value_indices = indices
-            centred_value_weights = self._mean_value_weights()
-            centred_value_weights = centred_value_weights.expand(indices.size(0), -1)
-            if self.xglu:
-                gate_indices = indices
-                value_indices = indices + (self.linear.out_features // 2)
-                centred_gate_weights = self._mean_gate_weights()
-                centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
-                self._update_weights(
-                    gate_indices, 0, centred_gate_weights, self.optimisers  # dim
-                )
-                self._update_weights(
-                    value_indices, 0, centred_value_weights, self.optimisers
-                )
-        else:
+        if idx_tensor.numel() == 0:
             return
 
+        if incoming_data is not None:
+            target_center = self._mean_input_weights(incoming_data)
+        else:
+            target_center = self._mean_value_weights()
+
+        target_center = target_center.expand(idx_tensor.size(0), -1)
+
+        if self.xglu:
+            gate_indices = idx_tensor
+            value_indices = idx_tensor + (self.linear.out_features // 2)
+            self._update_weights(gate_indices, 0, target_center, self.optimisers)
+            self._update_weights(value_indices, 0, target_center, self.optimisers)
+        else:
+            self._update_weights(idx_tensor, 0, target_center, self.optimisers)
+
     def reset_columns(self, indices):
         if not torch.is_tensor(indices):
             idx_tensor = torch.as_tensor(
@@ -281,6 +286,17 @@ class RecyclingLinear(nn.Module):
         random_weights *= 2.0 * stdv  # Range [-stdv, +stdv]
         return random_weights
 
+    def _mean_input_weights(self, input):
+        reduce_dims = list(range(input.ndim - 1))
+        data_mean = input.detach().mean(dim=reduce_dims, keepdim=True)
+
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        data_norm = data_mean.std() + 1e-6
+        scale_factor = stdv / data_norm
+
+        return data_mean * scale_factor
+
     def _mean_value_weights(self):
         """
         Only used when self.xglu
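The new `_mean_input_weights` helper derives the replacement row from the incoming activations rather than from the existing weights: it averages away every dimension except the feature dimension and rescales the result to the layer's usual Kaiming-uniform magnitude. A minimal standalone sketch of that calculation outside `RecyclingLinear` (only the formula is taken from the diff; the function name, shapes, and usage below are illustrative):

```python
import math
import torch

def centroid_row(incoming: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Mean of the incoming data, rescaled to the layer's init scale (1/sqrt(fan_in))."""
    # Average over every dimension except the last (feature) one, as in the diff.
    reduce_dims = list(range(incoming.ndim - 1))
    data_mean = incoming.detach().mean(dim=reduce_dims, keepdim=True)

    # Rescale so the centroid has roughly the magnitude of a freshly initialised row.
    stdv = 1.0 / math.sqrt(weight.size(1))
    scale_factor = stdv / (data_mean.std() + 1e-6)
    return data_mean * scale_factor

# Illustrative use: overwrite two rows of a plain nn.Linear with the data centroid.
linear = torch.nn.Linear(16, 8)
x = torch.randn(32, 16)                   # a batch of incoming activations
row = centroid_row(x, linear.weight)      # shape (1, 16)
with torch.no_grad():
    linear.weight[[1, 5]] = row.expand(2, -1)
```

Note that in the diff's `reset_rows`, the same `target_center` is written into both the gate and value halves when `self.xglu` is set.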
broccoli/transformer.py
CHANGED
@@ -1,3 +1,4 @@
+import warnings
 import math
 from typing import Optional, Tuple
 
@@ -78,6 +79,7 @@ class MHAttention(nn.Module):
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
@@ -96,6 +98,15 @@ class MHAttention(nn.Module):
         if causal:
             assert seq_len is not None
 
+        self.talking_heads = talking_heads
+
+        if self.talking_heads:
+            self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
+            self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
+        else:
+            self.head_projection = None
+            self.sample_projection = None
+
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
@@ -243,7 +254,7 @@
 
         q, k, v = self.project_qkv(q, k, v)
 
-        if FLASH_ATTN:
+        if FLASH_ATTN and not self.talking_heads:
             # Divide Q/K/V into heads
             q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
             k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
@@ -271,12 +282,22 @@
 
             qk_scores *= self.scaling_factor
 
+            if self.talking_heads:
+                qk_scores = torch.einsum(
+                    "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
+                )
+
             # Apply mask if causal (must come before softmax)
             if self.causal:
                 qk_scores.masked_fill_(self.mask, float("-inf"))
 
             qk_scores = F.softmax(qk_scores, dim=-1)
 
+            if self.talking_heads:
+                qk_scores = torch.einsum(
+                    "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
+                )
+
             qk_scores = self.dropout(qk_scores)
 
             output_with_heads = qk_scores @ v
@@ -310,6 +331,10 @@ class MHAttention(nn.Module):
         self.k_proj.reset_parameters()
         self.v_proj.reset_parameters()
         self.out_proj.reset_parameters()
+        if self.talking_heads:
+            # Initialize close to identity
+            nn.init.eye_(self.head_projection.weight)
+            nn.init.eye_(self.sample_projection.weight)
 
 
 class FeedforwardBlock(nn.Module):
@@ -411,7 +436,7 @@ class FeedforwardBlock(nn.Module):
         # Recycle weights if using recycling linear layers
         if self.training and self.recycling_enabled:
             indices = self.linear_out.get_reset_indices(1)
-            self.linear_in.reset_rows(indices)
+            self.linear_in.reset_rows(indices, incoming_data=x)
             self.linear_out.reset_columns(indices)
 
         if self.checkpoint:
@@ -453,6 +478,7 @@ class TransformerBlock(nn.Module):
         relative_position_embedding=False,
         source_size=None,
         utility_tokens=0,
+        talking_heads=False,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
@@ -513,6 +539,7 @@
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
             utility_tokens=utility_tokens,
+            talking_heads=talking_heads,
             scaling=msa_scaling,
         )
 
@@ -616,6 +643,7 @@ class TransformerEncoder(nn.Module):
         causal=False,
         linear_module=nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
@@ -644,6 +672,13 @@
         )
 
         super().__init__()
+
+        if FLASH_ATTN and talking_heads:
+            warnings.warn(
+                "Using talking heads currently prevents using flash attention.",
+                stacklevel=2,
+            )
+
         self.seq_len = seq_len
         self.n_heads = n_heads
         self._utility_tokens = utility_tokens
@@ -695,6 +730,7 @@
             relative_position_embedding=relative_position_embedding,
             source_size=source_size,
             utility_tokens=utility_tokens,
+            talking_heads=talking_heads,
             mlp_ratio=mlp_ratio,
             activation=activation,
             activation_kwargs=activation_kwargs,
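The `talking_heads` path follows the talking-heads attention idea (Shazeer et al., 2020): the per-head attention logits are mixed across heads with a learned `n_heads × n_heads` map before the softmax (`head_projection`), and the resulting attention weights are mixed again after it (`sample_projection`), with both maps initialised to the identity so the layer starts out behaving like standard multi-head attention. A self-contained sketch of that score path; the shapes and the `1/sqrt(d)` scaling are assumptions, while the einsum and initialisation mirror the diff:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

b, h, t, d = 2, 4, 5, 8                            # batch, heads, tokens, head dim
q, k, v = (torch.randn(b, h, t, d) for _ in range(3))

head_projection = nn.Linear(h, h, bias=False)      # mixes logits across heads
sample_projection = nn.Linear(h, h, bias=False)    # mixes attention weights across heads
nn.init.eye_(head_projection.weight)               # identity start => vanilla attention
nn.init.eye_(sample_projection.weight)

qk_scores = (q @ k.transpose(-2, -1)) / d**0.5     # (b, h, t, t)

# Mix across the head dimension before the softmax ...
qk_scores = torch.einsum("b h i j, o h -> b o i j", qk_scores, head_projection.weight)
qk_scores = F.softmax(qk_scores, dim=-1)
# ... and again after it.
qk_scores = torch.einsum("b h i j, o h -> b o i j", qk_scores, sample_projection.weight)

output_with_heads = qk_scores @ v                  # (b, h, t, d)
```

Because the score matrix must be materialised for this mixing, the diff skips the flash-attention fast path whenever `talking_heads` is set, and `TransformerEncoder` warns about the combination at construction time.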
broccoli/vit.py
CHANGED
@@ -174,6 +174,7 @@ class ViTEncoder(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
@@ -345,6 +346,7 @@
             causal=False,
             linear_module=linear_module,
             utility_tokens=transformer_utility_tokens,
+            talking_heads=transformer_talking_heads,
             return_utility_tokens=transformer_return_utility_tokens,
             pre_norm=transformer_pre_norm,
             normformer=transformer_normformer,
@@ -472,6 +474,7 @@ class ViT(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
@@ -543,6 +546,7 @@
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
             transformer_utility_tokens=transformer_utility_tokens,
+            transformer_talking_heads=transformer_talking_heads,
             transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
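In `vit.py` the change is pure plumbing: `transformer_talking_heads` defaults to `False` and is forwarded from `ViT` through `ViTEncoder` to `TransformerEncoder` as `talking_heads`. A toy sketch of that forwarding pattern with stand-in classes (the real broccoli constructors take many more arguments; only the flag names come from the diff):

```python
# Stand-ins only; they model how the flag threads through the stack in the diff.
class MHAttentionLike:
    def __init__(self, talking_heads=False):
        self.talking_heads = talking_heads

class TransformerEncoderLike:
    def __init__(self, talking_heads=False):
        self.attention = MHAttentionLike(talking_heads=talking_heads)

class ViTLike:
    def __init__(self, transformer_talking_heads=False):
        self.encoder = TransformerEncoderLike(talking_heads=transformer_talking_heads)

assert ViTLike(transformer_talking_heads=True).encoder.attention.talking_heads is True
```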
{broccoli_ml-9.5.1.dist-info → broccoli_ml-9.7.0.dist-info}/RECORD
CHANGED

@@ -1,13 +1,13 @@
 broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
 broccoli/activation.py,sha256=nrpTOrpg9k23_E4AJWy7VlXXAJCtCJCOR-TonEWJr04,3218
 broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
-broccoli/linear.py,sha256=
+broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
 broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=
+broccoli/transformer.py,sha256=24nBBGBjtW0DuSxA0_FJew_pZwbSwT6iM2S6_wjjaCI,26871
 broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
-broccoli/vit.py,sha256=
-broccoli_ml-9.
-broccoli_ml-9.
-broccoli_ml-9.
-broccoli_ml-9.
+broccoli/vit.py,sha256=DvVpayMIcUhH7Xg6CiYyeedUuuMHrjsGxEdXfnTGa_Q,22428
+broccoli_ml-9.7.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-9.7.0.dist-info/METADATA,sha256=If-lv5EbXZt24tbFQcWtrLHa99e6uxZaM4G-daINb40,1368
+broccoli_ml-9.7.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-9.7.0.dist-info/RECORD,,
{broccoli_ml-9.5.1.dist-info → broccoli_ml-9.7.0.dist-info}/LICENSE
File without changes

{broccoli_ml-9.5.1.dist-info → broccoli_ml-9.7.0.dist-info}/WHEEL
File without changes