broccoli-ml 9.5.1__py3-none-any.whl → 9.7.0__py3-none-any.whl

This diff shows the differences in content between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
broccoli/linear.py CHANGED
@@ -193,7 +193,12 @@ class RecyclingLinear(nn.Module):
193
193
  multipliers = [a / b for a, b in pairs if b != 0.0]
194
194
  return min(multipliers) if multipliers else 0.0
195
195
 
196
- def reset_rows(self, indices):
196
+ def reset_rows(self, indices, incoming_data=None):
197
+ """
198
+ Resets rows.
199
+ If incoming_data is provided, resets to the centroid (mean) of that data.
200
+ If not, resets to the mean of existing weights.
201
+ """
197
202
  if not torch.is_tensor(indices):
198
203
  idx_tensor = torch.as_tensor(
199
204
  list(indices), dtype=torch.long, device=self.linear.weight.device
@@ -201,24 +206,24 @@ class RecyclingLinear(nn.Module):
201
206
  else:
202
207
  idx_tensor = indices
203
208
 
204
- if idx_tensor.size(0):
205
- value_indices = indices
206
- centred_value_weights = self._mean_value_weights()
207
- centred_value_weights = centred_value_weights.expand(indices.size(0), -1)
208
- if self.xglu:
209
- gate_indices = indices
210
- value_indices = indices + (self.linear.out_features // 2)
211
- centred_gate_weights = self._mean_gate_weights()
212
- centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
213
- self._update_weights(
214
- gate_indices, 0, centred_gate_weights, self.optimisers # dim
215
- )
216
- self._update_weights(
217
- value_indices, 0, centred_value_weights, self.optimisers
218
- )
219
- else:
209
+ if idx_tensor.numel() == 0:
220
210
  return
221
211
 
212
+ if incoming_data is not None:
213
+ target_center = self._mean_input_weights(incoming_data)
214
+ else:
215
+ target_center = self._mean_value_weights()
216
+
217
+ target_center = target_center.expand(idx_tensor.size(0), -1)
218
+
219
+ if self.xglu:
220
+ gate_indices = idx_tensor
221
+ value_indices = idx_tensor + (self.linear.out_features // 2)
222
+ self._update_weights(gate_indices, 0, target_center, self.optimisers)
223
+ self._update_weights(value_indices, 0, target_center, self.optimisers)
224
+ else:
225
+ self._update_weights(idx_tensor, 0, target_center, self.optimisers)
226
+
222
227
  def reset_columns(self, indices):
223
228
  if not torch.is_tensor(indices):
224
229
  idx_tensor = torch.as_tensor(
@@ -281,6 +286,17 @@ class RecyclingLinear(nn.Module):
281
286
  random_weights *= 2.0 * stdv # Range [-stdv, +stdv]
282
287
  return random_weights
283
288
 
289
+ def _mean_input_weights(self, input):
290
+ reduce_dims = list(range(input.ndim - 1))
291
+ data_mean = input.detach().mean(dim=reduce_dims, keepdim=True)
292
+
293
+ weights = self.linear.weight.data
294
+ stdv = 1.0 / math.sqrt(weights.size(1))
295
+ data_norm = data_mean.std() + 1e-6
296
+ scale_factor = stdv / data_norm
297
+
298
+ return data_mean * scale_factor
299
+
284
300
  def _mean_value_weights(self):
285
301
  """
286
302
  Only used when self.xglu
broccoli/transformer.py CHANGED
@@ -1,3 +1,4 @@
1
+ import warnings
1
2
  import math
2
3
  from typing import Optional, Tuple
3
4
 
@@ -78,6 +79,7 @@ class MHAttention(nn.Module):
78
79
  seq_len=None,
79
80
  linear_module: nn.Module = nn.Linear,
80
81
  utility_tokens=0,
82
+ talking_heads=False,
81
83
  rotary_embedding=None,
82
84
  source_size=None,
83
85
  scaling="d",
@@ -96,6 +98,15 @@ class MHAttention(nn.Module):
96
98
  if causal:
97
99
  assert seq_len is not None
98
100
 
101
+ self.talking_heads = talking_heads
102
+
103
+ if self.talking_heads:
104
+ self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
105
+ self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
106
+ else:
107
+ self.head_projection = None
108
+ self.sample_projection = None
109
+
99
110
  self.embed_dim = embed_dim
100
111
  self.n_heads = n_heads
101
112
  assert embed_dim % n_heads == 0
@@ -243,7 +254,7 @@ class MHAttention(nn.Module):
243
254
 
244
255
  q, k, v = self.project_qkv(q, k, v)
245
256
 
246
- if FLASH_ATTN:
257
+ if FLASH_ATTN and not self.talking_heads:
247
258
  # Divide Q/K/V into heads
248
259
  q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
249
260
  k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
@@ -271,12 +282,22 @@ class MHAttention(nn.Module):
271
282
 
272
283
  qk_scores *= self.scaling_factor
273
284
 
285
+ if self.talking_heads:
286
+ qk_scores = torch.einsum(
287
+ "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
288
+ )
289
+
274
290
  # Apply mask if causal (must come before softmax)
275
291
  if self.causal:
276
292
  qk_scores.masked_fill_(self.mask, float("-inf"))
277
293
 
278
294
  qk_scores = F.softmax(qk_scores, dim=-1)
279
295
 
296
+ if self.talking_heads:
297
+ qk_scores = torch.einsum(
298
+ "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
299
+ )
300
+
280
301
  qk_scores = self.dropout(qk_scores)
281
302
 
282
303
  output_with_heads = qk_scores @ v
@@ -310,6 +331,10 @@ class MHAttention(nn.Module):
310
331
  self.k_proj.reset_parameters()
311
332
  self.v_proj.reset_parameters()
312
333
  self.out_proj.reset_parameters()
334
+ if self.talking_heads:
335
+ # Initialize close to identity
336
+ nn.init.eye_(self.head_projection.weight)
337
+ nn.init.eye_(self.sample_projection.weight)
313
338
 
314
339
 
315
340
  class FeedforwardBlock(nn.Module):
@@ -411,7 +436,7 @@ class FeedforwardBlock(nn.Module):
411
436
  # Recycle weights if using recycling linear layers
412
437
  if self.training and self.recycling_enabled:
413
438
  indices = self.linear_out.get_reset_indices(1)
414
- self.linear_in.reset_rows(indices)
439
+ self.linear_in.reset_rows(indices, incoming_data=x)
415
440
  self.linear_out.reset_columns(indices)
416
441
 
417
442
  if self.checkpoint:
@@ -453,6 +478,7 @@ class TransformerBlock(nn.Module):
453
478
  relative_position_embedding=False,
454
479
  source_size=None,
455
480
  utility_tokens=0,
481
+ talking_heads=False,
456
482
  mlp_ratio=4,
457
483
  activation: nn.Module = nn.ReLU,
458
484
  activation_kwargs: Optional[dict] = None,
@@ -513,6 +539,7 @@ class TransformerBlock(nn.Module):
513
539
  rotary_embedding=self.rotary_embedding,
514
540
  source_size=source_size,
515
541
  utility_tokens=utility_tokens,
542
+ talking_heads=talking_heads,
516
543
  scaling=msa_scaling,
517
544
  )
518
545
 
@@ -616,6 +643,7 @@ class TransformerEncoder(nn.Module):
616
643
  causal=False,
617
644
  linear_module=nn.Linear,
618
645
  utility_tokens=0,
646
+ talking_heads=False,
619
647
  return_utility_tokens=False,
620
648
  pre_norm=True,
621
649
  post_norm=False,
@@ -644,6 +672,13 @@ class TransformerEncoder(nn.Module):
644
672
  )
645
673
 
646
674
  super().__init__()
675
+
676
+ if FLASH_ATTN and talking_heads:
677
+ warnings.warn(
678
+ "Using talking heads currently prevents using flash attention.",
679
+ stacklevel=2,
680
+ )
681
+
647
682
  self.seq_len = seq_len
648
683
  self.n_heads = n_heads
649
684
  self._utility_tokens = utility_tokens
@@ -695,6 +730,7 @@ class TransformerEncoder(nn.Module):
695
730
  relative_position_embedding=relative_position_embedding,
696
731
  source_size=source_size,
697
732
  utility_tokens=utility_tokens,
733
+ talking_heads=talking_heads,
698
734
  mlp_ratio=mlp_ratio,
699
735
  activation=activation,
700
736
  activation_kwargs=activation_kwargs,
broccoli/vit.py CHANGED
@@ -174,6 +174,7 @@ class ViTEncoder(nn.Module):
174
174
  transformer_heads=4,
175
175
  transformer_mlp_ratio=2,
176
176
  transformer_utility_tokens=0,
177
+ transformer_talking_heads=False,
177
178
  transformer_return_utility_tokens=False,
178
179
  transformer_activation: nn.Module = SquaredReLU,
179
180
  transformer_activation_kwargs: Optional[dict] = None,
@@ -345,6 +346,7 @@ class ViTEncoder(nn.Module):
345
346
  causal=False,
346
347
  linear_module=linear_module,
347
348
  utility_tokens=transformer_utility_tokens,
349
+ talking_heads=transformer_talking_heads,
348
350
  return_utility_tokens=transformer_return_utility_tokens,
349
351
  pre_norm=transformer_pre_norm,
350
352
  normformer=transformer_normformer,
@@ -472,6 +474,7 @@ class ViT(nn.Module):
472
474
  transformer_heads=4,
473
475
  transformer_mlp_ratio=2,
474
476
  transformer_utility_tokens=0,
477
+ transformer_talking_heads=False,
475
478
  transformer_return_utility_tokens=False,
476
479
  transformer_activation: nn.Module = SquaredReLU,
477
480
  transformer_activation_kwargs: Optional[dict] = None,
@@ -543,6 +546,7 @@ class ViT(nn.Module):
543
546
  transformer_heads=transformer_heads,
544
547
  transformer_mlp_ratio=transformer_mlp_ratio,
545
548
  transformer_utility_tokens=transformer_utility_tokens,
549
+ transformer_talking_heads=transformer_talking_heads,
546
550
  transformer_return_utility_tokens=transformer_return_utility_tokens,
547
551
  transformer_activation=transformer_activation,
548
552
  transformer_activation_kwargs=transformer_activation_kwargs,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 9.5.1
3
+ Version: 9.7.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -1,13 +1,13 @@
1
1
  broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
2
2
  broccoli/activation.py,sha256=nrpTOrpg9k23_E4AJWy7VlXXAJCtCJCOR-TonEWJr04,3218
3
3
  broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
4
- broccoli/linear.py,sha256=i4U7ZC4ZWEH82YpDasx0Qs1pc3gkyL-3ajuyKCbsGTM,12649
4
+ broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
5
5
  broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
6
6
  broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
7
- broccoli/transformer.py,sha256=ULk-QQX3hAI14-aCKhp9QSebzX4KUjlisEGup2Eycck,25565
7
+ broccoli/transformer.py,sha256=24nBBGBjtW0DuSxA0_FJew_pZwbSwT6iM2S6_wjjaCI,26871
8
8
  broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
9
- broccoli/vit.py,sha256=sC6K3FK3a8ojOgvNWSWhuZHBtnFrrTQbsDdlagcKJH4,22224
10
- broccoli_ml-9.5.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
- broccoli_ml-9.5.1.dist-info/METADATA,sha256=HXRWnuc_-Gs_g37_RP3-POTLmi7sZamlzYv5SJEun1Y,1368
12
- broccoli_ml-9.5.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
- broccoli_ml-9.5.1.dist-info/RECORD,,
9
+ broccoli/vit.py,sha256=DvVpayMIcUhH7Xg6CiYyeedUuuMHrjsGxEdXfnTGa_Q,22428
10
+ broccoli_ml-9.7.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
+ broccoli_ml-9.7.0.dist-info/METADATA,sha256=If-lv5EbXZt24tbFQcWtrLHa99e6uxZaM4G-daINb40,1368
12
+ broccoli_ml-9.7.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
+ broccoli_ml-9.7.0.dist-info/RECORD,,