broccoli-ml 0.26.0__tar.gz → 0.28.0__tar.gz

This diff compares the contents of two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.26.0
3
+ Version: 0.28.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -238,6 +238,7 @@ class FeedforwardBlock(nn.Module):
238
238
  linear_module=nn.Linear,
239
239
  pre_norm=True,
240
240
  normformer=False,
241
+ raw_input=False,
241
242
  ):
242
243
  super().__init__()
243
244
 
@@ -246,6 +247,11 @@ class FeedforwardBlock(nn.Module):
246
247
  else:
247
248
  self.activation = activation()
248
249
 
250
+ if raw_input:
251
+ self.memory_type = AnchoredLinear
252
+ else:
253
+ self.memory_type = nn.Linear
254
+
249
255
  self.dropout = nn.Dropout(dropout)
250
256
 
251
257
  self.max_features = (
@@ -260,7 +266,7 @@ class FeedforwardBlock(nn.Module):
260
266
  linear_module(input_features, self.max_features),
261
267
  self.activation,
262
268
  nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
263
- linear_module(ratio * output_features, output_features),
269
+ self.memory_type(ratio * output_features, output_features),
264
270
  self.dropout,
265
271
  ]
266
272
  )
@@ -371,9 +377,11 @@ class TransformerBlock(nn.Module):
371
377
  process_x = process_x + self.attn(process_x, process_x, process_x)
372
378
  norm_process_x = self.layer_norm_1(process_x)
373
379
  process_x = process_x + self.ff(process_x)
374
- x = self.layer_norm_2(
375
- torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
376
- )
380
+
381
+ # Always post norm as eventually we reach the classification head!
382
+ x = self.layer_norm_2(
383
+ torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
384
+ )
377
385
 
378
386
  return x
379
387
 
@@ -308,6 +308,7 @@ class ViTEncoder(nn.Module):
308
308
  linear_module=linear_module,
309
309
  pre_norm=transformer_pre_norm,
310
310
  normformer=transformer_normformer,
311
+ raw_input=not cnn,
311
312
  )
312
313
  else:
313
314
  self.initial_ff = nn.Identity()
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "broccoli-ml"
3
- version = "0.26.0"
3
+ version = "0.28.0"
4
4
  description = "Some useful Pytorch models, circa 2025"
5
5
  authors = [
6
6
  {name = "Nicholas Bailey"}
File without changes
File without changes