broccoli-ml 9.1.0__tar.gz → 9.2.0__tar.gz

This diff shows the changes between two publicly available versions of a package, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between the package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 9.1.0
3
+ Version: 9.2.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -340,6 +340,7 @@ class FeedforwardBlock(nn.Module):
340
340
  self.checkpoint = checkpoint
341
341
  self.residual_path = residual_path
342
342
  self.post_norm = post_norm
343
+ self.xglu = activation.__name__.endswith("GLU")
343
344
 
344
345
  if self.residual_path and (output_features < input_features):
345
346
  raise ValueError(
@@ -364,27 +365,63 @@ class FeedforwardBlock(nn.Module):
364
365
  )
365
366
 
366
367
  self.max_features = (
367
- 2 * ratio * output_features
368
- if activation.__name__.endswith("GLU")
369
- else ratio * output_features
368
+ 2 * ratio * output_features if self.xglu else ratio * output_features
370
369
  )
371
370
 
371
+ self.linear_in = linear_module_up(input_features, self.max_features)
372
+ self.linear_out = linear_module_down(ratio * output_features, output_features)
373
+
372
374
  self.process = nn.Sequential(
373
375
  *[
374
376
  nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
375
- linear_module_up(input_features, self.max_features),
377
+ self.linear_in,
376
378
  self.activation,
377
379
  self.inner_dropout,
378
380
  nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
379
- linear_module_down(ratio * output_features, output_features),
381
+ self.linear_out,
380
382
  self.outer_dropout,
381
383
  ]
382
384
  )
383
385
 
386
+ self.recycling_enabled = False
387
+ if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
388
+ self.linear_out, "column_recycling_rate"
389
+ ):
390
+ self.recycling_enabled = True
391
+ self.master_recycling_rate = self.linear_in.row_recycling_rate
392
+ self.linear_in.row_recycling_rate = 0.0
393
+ self.linear_out.column_recycling_rate = 0.0
394
+ if hasattr(self.linear_in, "column_recycling_rate") or hasattr(
395
+ self.linear_out, "row_recycling_rate"
396
+ ):
397
+ raise NotImplementedError(
398
+ "At the moment this layer can only support recycling linear "
399
+ "layers if the in layer resets only rows and the out layer "
400
+ "resets only columns."
401
+ )
402
+
384
403
  self.reset_parameters()
385
404
 
386
405
  def forward(self, x):
387
406
 
407
+ # Recycle weights if using recycling linear layers
408
+ if self.training and self.recycling_enabled:
409
+ multiplier = self.linear_in._get_multiplier()
410
+ rate = self.master_recycling_rate * multiplier
411
+ if rate > 0:
412
+ probs = torch.rand(self.linear_out.in_features, device=x.device)
413
+ mask = probs < rate
414
+ if mask.any():
415
+ indices = torch.nonzero(mask).squeeze(-1)
416
+ self.linear_out.reset_columns(indices, self.linear_out.optimisers)
417
+ if self.xglu:
418
+ indices_in = torch.cat(
419
+ [indices, indices + self.linear_out.in_features]
420
+ )
421
+ self.linear_in.reset_rows(indices_in, self.linear_in.optimisers)
422
+ else:
423
+ self.linear_in.reset_rows(indices, self.linear_in.optimisers)
424
+
388
425
  if self.checkpoint:
389
426
  processed = checkpoint(self.process, x, use_reentrant=False)
390
427
  else:
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "broccoli-ml"
3
- version = "9.1.0"
3
+ version = "9.2.0"
4
4
  description = "Some useful Pytorch models, circa 2025"
5
5
  authors = [
6
6
  {name = "Nicholas Bailey"}
File without changes
File without changes
File without changes
File without changes