broccoli-ml 9.1.0__tar.gz → 9.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 9.1.0
3
+ Version: 9.2.1
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -340,6 +340,7 @@ class FeedforwardBlock(nn.Module):
340
340
  self.checkpoint = checkpoint
341
341
  self.residual_path = residual_path
342
342
  self.post_norm = post_norm
343
+ self.xglu = activation.__name__.endswith("GLU")
343
344
 
344
345
  if self.residual_path and (output_features < input_features):
345
346
  raise ValueError(
@@ -364,27 +365,67 @@ class FeedforwardBlock(nn.Module):
364
365
  )
365
366
 
366
367
  self.max_features = (
367
- 2 * ratio * output_features
368
- if activation.__name__.endswith("GLU")
369
- else ratio * output_features
368
+ 2 * ratio * output_features if self.xglu else ratio * output_features
370
369
  )
371
370
 
371
+ self.linear_in = linear_module_up(input_features, self.max_features)
372
+ self.linear_out = linear_module_down(ratio * output_features, output_features)
373
+
372
374
  self.process = nn.Sequential(
373
375
  *[
374
376
  nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
375
- linear_module_up(input_features, self.max_features),
377
+ self.linear_in,
376
378
  self.activation,
377
379
  self.inner_dropout,
378
380
  nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
379
- linear_module_down(ratio * output_features, output_features),
381
+ self.linear_out,
380
382
  self.outer_dropout,
381
383
  ]
382
384
  )
383
385
 
386
+ self.recycling_enabled = False
387
+ if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
388
+ self.linear_out, "column_recycling_rate"
389
+ ):
390
+ self.recycling_enabled = True
391
+ self.master_recycling_rate = self.linear_in.row_recycling_rate
392
+ self.linear_in.row_recycling_rate = 0.0
393
+ self.linear_out.column_recycling_rate = 0.0
394
+ if (
395
+ hasattr(self.linear_in, "column_recycling_rate")
396
+ and self.linear_in.column_recycling_rate > 0
397
+ ) or (
398
+ hasattr(self.linear_out, "row_recycling_rate")
399
+ and self.linear_out.row_recycling_rate > 0
400
+ ):
401
+ raise NotImplementedError(
402
+ "At the moment this layer can only support recycling linear "
403
+ "layers if the in layer resets only rows and the out layer "
404
+ "resets only columns."
405
+ )
406
+
384
407
  self.reset_parameters()
385
408
 
386
409
  def forward(self, x):
387
410
 
411
+ # Recycle weights if using recycling linear layers
412
+ if self.training and self.recycling_enabled:
413
+ multiplier = self.linear_in._get_multiplier()
414
+ rate = self.master_recycling_rate * multiplier
415
+ if rate > 0:
416
+ probs = torch.rand(self.linear_out.in_features, device=x.device)
417
+ mask = probs < rate
418
+ if mask.any():
419
+ indices = torch.nonzero(mask).squeeze(-1)
420
+ self.linear_out.reset_columns(indices, self.linear_out.optimisers)
421
+ if self.xglu:
422
+ indices_in = torch.cat(
423
+ [indices, indices + self.linear_out.in_features]
424
+ )
425
+ self.linear_in.reset_rows(indices_in, self.linear_in.optimisers)
426
+ else:
427
+ self.linear_in.reset_rows(indices, self.linear_in.optimisers)
428
+
388
429
  if self.checkpoint:
389
430
  processed = checkpoint(self.process, x, use_reentrant=False)
390
431
  else:
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "broccoli-ml"
3
- version = "9.1.0"
3
+ version = "9.2.1"
4
4
  description = "Some useful Pytorch models, circa 2025"
5
5
  authors = [
6
6
  {name = "Nicholas Bailey"}
File without changes
File without changes
File without changes
File without changes