broccoli-ml 13.0.5__py3-none-any.whl → 14.0.0__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -200,26 +200,26 @@ class MHAttention(nn.Module):
200
200
  "`source_size` must be a tuple of 1, 2 or 3 integers"
201
201
  )
202
202
 
203
- q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
204
- k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
203
+ q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
204
+ k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
205
205
 
206
206
  q_util, q_img = (
207
- q[:, : self.utility_tokens, :, :],
208
- q[:, self.utility_tokens :, :, :],
207
+ q[:, :, : self.utility_tokens, :],
208
+ q[:, :, self.utility_tokens :, :],
209
209
  )
210
210
  k_util, k_img = (
211
- k[:, : self.utility_tokens, :, :],
212
- k[:, self.utility_tokens :, :, :],
211
+ k[:, :, : self.utility_tokens, :],
212
+ k[:, :, self.utility_tokens :, :],
213
213
  )
214
214
 
215
215
  q_img = rearrange(
216
216
  q_img,
217
- f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
217
+ f"b h ({spatial_dimension_names}) d -> b h {spatial_dimension_names} d",
218
218
  **spatial_dimension_values,
219
219
  )
220
220
  k_img = rearrange(
221
221
  k_img,
222
- f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
222
+ f"b h ({spatial_dimension_names}) d -> b h {spatial_dimension_names} d",
223
223
  **spatial_dimension_values,
224
224
  )
225
225
 
@@ -230,19 +230,19 @@ class MHAttention(nn.Module):
230
230
 
231
231
  q_img = rearrange(
232
232
  q_img,
233
- f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
233
+ f"b h {spatial_dimension_names} d -> b h ({spatial_dimension_names}) d",
234
234
  )
235
235
  k_img = rearrange(
236
236
  k_img,
237
- f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
237
+ f"b h {spatial_dimension_names} d -> b h ({spatial_dimension_names}) d",
238
238
  )
239
239
 
240
240
  # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
241
- q = torch.cat([q_util, q_img], dim=1)
242
- k = torch.cat([k_util, k_img], dim=1)
241
+ q = torch.cat([q_util, q_img], dim=2)
242
+ k = torch.cat([k_util, k_img], dim=2)
243
243
 
244
- q = rearrange(q, "b t h d -> b t (h d)")
245
- k = rearrange(k, "b t h d -> b t (h d)")
244
+ q = rearrange(q, "b h t d -> b t (h d)")
245
+ k = rearrange(k, "b h t d -> b t (h d)")
246
246
 
247
247
  return q, k
248
248
 
@@ -621,6 +621,7 @@ class EncoderBlock(nn.Module):
621
621
 
622
622
  if self.post_norm:
623
623
  x = self.post_attention_norm(x)
624
+ process_x = x
624
625
  elif self.pre_norm:
625
626
  process_x = self.pre_mlp_norm(x)
626
627
  else:
@@ -638,15 +639,15 @@ class EncoderBlock(nn.Module):
638
639
  def attention_logits(self, x):
639
640
  """
640
641
  Give back the attention scores used in this layer.
642
+ Needs to match what the model actually sees during forward()
643
+ by applying the correct normalisations.
641
644
  """
642
- # Fix: Use the correct attribute name 'pre_attention_norm'
643
645
  if self.pre_norm:
644
- # We must normalize the input before measuring attention logits
645
- # to match what the model actually sees during forward()
646
646
  x = self.pre_attention_norm(x)
647
- return self.attn.attention_logits(x, x, x)
648
- else:
649
- return self.attn.attention_logits(x, x, x)
647
+ elif self.post_norm:
648
+ x = self.input_norm(x)
649
+
650
+ return self.attn.attention_logits(x, x, x)
650
651
 
651
652
  def reset_parameters(self):
652
653
  if self.pre_norm:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 13.0.5
3
+ Version: 14.0.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
4
4
  broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
5
5
  broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
6
6
  broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
7
- broccoli/transformer.py,sha256=-rVhSl5yWDbEbjLysSpjiEy0h01E33TGdtgBnd-zRgA,27952
7
+ broccoli/transformer.py,sha256=eSh_oyVhQad0rI0YPt1617OtH_EPn8XuNYj6BR6zRSM,27875
8
8
  broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
9
9
  broccoli/vit.py,sha256=SYAHGuGVZRa-zYCLK5siKdUauDgHtMYUmQ6W2tUIqKA,22588
10
- broccoli_ml-13.0.5.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
- broccoli_ml-13.0.5.dist-info/METADATA,sha256=JM7ILFgCPGW7lkP_vWfet9KdNMz6xpRcNlChvzQA5xU,1369
12
- broccoli_ml-13.0.5.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
- broccoli_ml-13.0.5.dist-info/RECORD,,
10
+ broccoli_ml-14.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
+ broccoli_ml-14.0.0.dist-info/METADATA,sha256=XAj9Kyl3rn7tH3_Tm1CpnkfQeacS2w6ymjw4_AH73D4,1369
12
+ broccoli_ml-14.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
+ broccoli_ml-14.0.0.dist-info/RECORD,,