broccoli_ml-6.0.1-py3-none-any.whl → broccoli_ml-9.0.0-py3-none-any.whl

broccoli/linear.py CHANGED
@@ -1,4 +1,7 @@
  import math
+ import random
+ from typing import Union, List, Iterable
+
  import torch
  from torch import nn
  from torch.nn import functional as F
@@ -136,3 +139,136 @@ class WeightNormedLinear(nn.Module):
          f"WeightNormedLinear(in_features={self.in_features},"
          f"out_features={self.out_features}, bias={self.use_bias})"
      )
+
+
+ class RecyclingLinear(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = True,
+         row_recycling_rate: float = 0.0,
+         column_recycling_rate: float = 0.0,
+     ):
+         super().__init__()
+         self.linear = nn.Linear(in_features, out_features, bias=bias)
+         self.row_recycling_rate = row_recycling_rate
+         self.column_recycling_rate = column_recycling_rate
+         self.optimisers = []
+
+     def register_optimiser(self, optimiser: torch.optim.Optimizer):
+         self.optimisers.append(optimiser)
+
+     def forward(self, x):
+         if self.training and self.optimisers:
+
+             if self.row_recycling_rate > 0:
+                 probs = torch.rand(self.linear.out_features, device=x.device)
+                 mask = probs < self.row_recycling_rate
+                 if mask.any():
+                     # nonzero returns [N, 1], squeeze to get [N]
+                     indices = torch.nonzero(mask).squeeze(-1)
+                     self.reset_rows(indices, self.optimisers)
+
+             if self.column_recycling_rate > 0:
+                 probs = torch.rand(self.linear.in_features, device=x.device)
+                 mask = probs < self.column_recycling_rate
+                 if mask.any():
+                     indices = torch.nonzero(mask).squeeze(-1)
+                     self.reset_columns(indices, self.optimisers)
+
+         return self.linear(x)
+
+     def reset_rows(
+         self,
+         indices: Iterable[int],
+         optimisers: Union[
+             List[torch.optim.Optimizer], torch.optim.Optimizer, None
+         ] = None,
+     ):
+         """
+         Update some of the weight rows to be equal to the mean of all weight rows.
+         """
+         if optimisers is None:
+             optimisers = []
+         if not isinstance(optimisers, list):
+             optimisers = [optimisers]
+
+         device = self.linear.weight.device
+         idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
+
+         if idx_tensor.numel() == 0:
+             return
+
+         with torch.no_grad():
+             # Calculate mean of all rows including the rows to be reset
+             mean_vector = self.linear.weight.data.mean(
+                 dim=0, keepdim=True
+             )  # [1, in_features]
+             update_data = mean_vector.expand(idx_tensor.size(0), -1)
+             self.linear.weight.data[idx_tensor] = update_data
+
+             if self.linear.bias is not None:
+                 self.linear.bias.data[idx_tensor] = 0.0
+
+         self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=0)
+         if self.linear.bias is not None:
+             self._reset_optim_state(self.linear.bias, idx_tensor, optimisers, dim=0)
+
+     def reset_columns(
+         self,
+         indices: Iterable[int],
+         optimisers: Union[
+             List[torch.optim.Optimizer], torch.optim.Optimizer, None
+         ] = None,
+     ):
+         """
+         Update some of the weight columns to be random as though reinitialised.
+         """
+         if optimisers is None:
+             optimisers = []
+         if not isinstance(optimisers, list):
+             optimisers = [optimisers]
+
+         device = self.linear.weight.device
+         idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
+
+         if idx_tensor.numel() == 0:
+             return
+
+         with torch.no_grad():
+             # 1. Generate Random Columns
+             # Shape: [out_features, N_indices]
+             weights = self.linear.weight.data
+             stdv = 1.0 / math.sqrt(weights.size(1))
+
+             # Generate [Rows, N] block
+             random_weights = torch.rand(
+                 weights.size(0), idx_tensor.size(0), device=device
+             )
+             random_weights = (random_weights - 0.5) * 2.0 * stdv
+
+             # 2. Update Weights (One-shot)
+             # We assign into the columns specified by idx_tensor
+             self.linear.weight.data[:, idx_tensor] = random_weights
+
+         # 3. Update Optimizers
+         # Bias is untouched by column resets (bias is shape [Out], cols are [In])
+         self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=1)
+
+     def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
+         """
+         Zeroes out the optimizer state for the given indices in a single operation.
+         """
+         for optimiser in optimisers:
+             if param not in optimiser.state:
+                 continue
+             state = optimiser.state[param]
+
+             for _, buffer in state.items():
+                 if torch.is_tensor(buffer) and buffer.shape == param.shape:
+                     # Vectorized zeroing
+                     if dim == 0:
+                         buffer[idx_tensor] = 0.0
+                     else:
+                         buffer[:, idx_tensor] = 0.0
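
A minimal usage sketch for the new RecyclingLinear layer, based only on the API visible in the diff above; the layer sizes, recycling rates, optimiser choice, and training step are illustrative assumptions, not part of the package:

    import torch
    from broccoli.linear import RecyclingLinear

    layer = RecyclingLinear(
        in_features=64,
        out_features=32,
        row_recycling_rate=0.01,     # each training forward pass resets ~1% of output rows
        column_recycling_rate=0.01,  # ...and ~1% of input columns
    )
    optimiser = torch.optim.Adam(layer.parameters(), lr=1e-3)
    layer.register_optimiser(optimiser)  # recycled weights also get their optimiser state zeroed

    layer.train()  # recycling only runs in training mode with an optimiser registered
    x = torch.randn(8, 64)
    loss = layer(x).square().mean()
    loss.backward()
    optimiser.step()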
broccoli/transformer.py CHANGED
@@ -325,6 +325,8 @@ class FeedforwardBlock(nn.Module):
          activation=nn.ReLU,
          activation_kwargs=None,
          dropout=0.0,
+         inner_dropout=None,
+         outer_dropout=None,
          linear_module_up=nn.Linear,
          linear_module_down=nn.Linear,
          pre_norm=True,
@@ -354,7 +356,12 @@ class FeedforwardBlock(nn.Module):
          else:
              self.activation = activation()

-         self.dropout = nn.Dropout(dropout)
+         self.inner_dropout = nn.Dropout(
+             inner_dropout if inner_dropout is not None else dropout
+         )
+         self.outer_dropout = nn.Dropout(
+             outer_dropout if outer_dropout is not None else dropout
+         )

          self.max_features = (
              2 * ratio * output_features
@@ -367,9 +374,10 @@
              nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
              linear_module_up(input_features, self.max_features),
              self.activation,
+             self.inner_dropout,
              nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
              linear_module_down(ratio * output_features, output_features),
-             self.dropout,
+             self.outer_dropout,
          ]
      )
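
For orientation, a self-contained sketch of where the two dropout sites now sit in the feedforward stack. This assumes pre_norm=True, normformer=False, and a plain (non-gated) activation, and it simplifies away the 2 * ratio widening tied to self.max_features, so it approximates the shape of the module rather than reproducing the library's exact internals:

    import torch
    from torch import nn

    def feedforward_stack(input_features, output_features, ratio,
                          dropout=0.0, inner_dropout=None, outer_dropout=None):
        hidden = ratio * output_features
        return nn.Sequential(
            nn.LayerNorm(input_features),
            nn.Linear(input_features, hidden),
            nn.ReLU(),
            # inner dropout: the new site, after the activation
            nn.Dropout(inner_dropout if inner_dropout is not None else dropout),
            nn.Linear(hidden, output_features),
            # outer dropout: the old single `dropout` site, after the down-projection
            nn.Dropout(outer_dropout if outer_dropout is not None else dropout),
        )

    x = torch.randn(4, 16)
    y = feedforward_stack(16, 16, ratio=4, dropout=0.1, inner_dropout=0.0)(x)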
@@ -422,7 +430,9 @@ class TransformerBlock(nn.Module):
          ff_linear_module_up=None,
          ff_linear_module_down=None,
          msa_scaling="d",
-         mlp_dropout=0.0,
+         ff_dropout=0.0,
+         ff_inner_dropout=0.0,
+         ff_outer_dropout=0.0,
          msa_dropout=0.0,
          identity_probability=0.0,
          causal=False,
@@ -484,7 +494,9 @@
              d_model,
              activation=activation,
              activation_kwargs=activation_kwargs,
-             dropout=mlp_dropout,
+             dropout=ff_dropout,
+             inner_dropout=ff_inner_dropout,
+             outer_dropout=ff_outer_dropout,
              linear_module_up=(
                  ff_linear_module_up
                  if ff_linear_module_up is not None
@@ -567,7 +579,9 @@ class TransformerEncoder(nn.Module):
          activation_kwargs: Optional[dict] = None,
          ff_linear_module_up=None,
          ff_linear_module_down=None,
-         mlp_dropout=0.0,
+         ff_dropout=0.0,
+         ff_inner_dropout=0.0,
+         ff_outer_dropout=0.0,
          msa_dropout=0.0,
          stochastic_depth=0.0,
          causal=False,
@@ -591,7 +605,13 @@
          if relative_position_embedding and (source_size is None):
              raise ValueError(
                  "`source_size` for TransformerEncoder cannot be None if"
-                 " `position_embedding_type` is relative"
+                 " `relative_position_embedding` is True"
+             )
+
+         if absolute_position_embedding and (seq_len is None):
+             raise ValueError(
+                 "`seq_len` for TransformerEncoder cannot be None if"
+                 " `absolute_position_embedding` is True"
              )

          super().__init__()
@@ -606,9 +626,12 @@
                  torch.empty(self._utility_tokens, d_model)
              )
              nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
-             self.full_sequence_length = self.seq_len + self._utility_tokens
          else:
              self._utility_token_embedding = None
+
+         if self._utility_tokens and (self.seq_len is not None):
+             self.full_sequence_length = self.seq_len + self._utility_tokens
+         else:
              self.full_sequence_length = self.seq_len

          self.d_model = d_model
@@ -620,7 +643,7 @@
          else:
              self.absolute_position_embedding = None

-         self.mlp_dropout = mlp_dropout
+         self.mlp_dropout = ff_dropout
          self.msa_dropout = msa_dropout
          self.stochastic_depth = stochastic_depth

@@ -649,7 +672,9 @@
                  ff_linear_module_up=ff_linear_module_up,
                  ff_linear_module_down=ff_linear_module_down,
                  msa_scaling=msa_scaling,
-                 mlp_dropout=mlp_dropout,
+                 ff_dropout=ff_dropout,
+                 ff_inner_dropout=ff_inner_dropout,
+                 ff_outer_dropout=ff_outer_dropout,
                  msa_dropout=msa_dropout,
                  identity_probability=self.stochastic_depth_probabilities[i],
                  causal=causal,
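
Since mlp_dropout is renamed rather than aliased, passing the old keyword to TransformerBlock or TransformerEncoder now fails with a TypeError; this is a breaking change, consistent with the major version jump. A hypothetical migration helper (the function and its name are illustrative, not part of the package):

    def migrate_kwargs(kwargs: dict) -> dict:
        # Translate the 6.x keyword to its 9.x spelling; everything else passes through.
        kwargs = dict(kwargs)
        if "mlp_dropout" in kwargs:
            kwargs["ff_dropout"] = kwargs.pop("mlp_dropout")
        return kwargs

    old = {"mlp_dropout": 0.1, "msa_dropout": 0.1}
    assert migrate_kwargs(old) == {"ff_dropout": 0.1, "msa_dropout": 0.1}

Note that TransformerBlock and TransformerEncoder default ff_inner_dropout and ff_outer_dropout to 0.0 rather than None, and always forward them; since FeedforwardBlock falls back to the shared dropout rate only on None, that fall-back applies only when constructing FeedforwardBlock directly.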
broccoli/vit.py CHANGED
@@ -161,7 +161,9 @@ class ViTEncoder(nn.Module):
          transformer_initial_ff_residual_path=True,
          transformer_initial_ff_linear_module_up=None,
          transformer_initial_ff_linear_module_down=None,
-         transformer_initial_ff_mlp_dropout=None,
+         transformer_initial_ff_dropout=None,
+         transformer_initial_ff_inner_dropout=None,
+         transformer_initial_ff_outer_dropout=None,
          transformer_pre_norm=True,
          transformer_normformer=False,
          transformer_post_norm=False,
@@ -178,7 +180,9 @@
          transformer_ff_linear_module_up=None,
          transformer_ff_linear_module_down=None,
          transformer_msa_scaling="d",
-         transformer_mlp_dropout=0.0,
+         transformer_ff_dropout=0.0,
+         transformer_ff_inner_dropout=0.0,
+         transformer_ff_outer_dropout=0.0,
          transformer_msa_dropout=0.1,
          transformer_stochastic_depth=0.1,
          transformer_checkpoint_ff=True,
@@ -333,7 +337,9 @@
              ff_linear_module_up=transformer_ff_linear_module_up,
              ff_linear_module_down=transformer_ff_linear_module_down,
              msa_scaling=transformer_msa_scaling,
-             mlp_dropout=transformer_mlp_dropout,
+             ff_dropout=transformer_ff_dropout,
+             ff_inner_dropout=transformer_ff_inner_dropout,
+             ff_outer_dropout=transformer_ff_outer_dropout,
              msa_dropout=transformer_msa_dropout,
              stochastic_depth=transformer_stochastic_depth,
              causal=False,
@@ -357,9 +363,21 @@
              activation_kwargs=transformer_activation_kwargs,
              dropout=(
                  # First truthy assigned value
-                 transformer_initial_ff_mlp_dropout
-                 if transformer_initial_ff_mlp_dropout is not None
-                 else transformer_mlp_dropout
+                 transformer_initial_ff_dropout
+                 if transformer_initial_ff_dropout is not None
+                 else transformer_ff_dropout
+             ),
+             inner_dropout=(
+                 # First truthy assigned value
+                 transformer_initial_ff_inner_dropout
+                 if transformer_initial_ff_inner_dropout is not None
+                 else transformer_ff_inner_dropout
+             ),
+             outer_dropout=(
+                 # First truthy assigned value
+                 transformer_initial_ff_outer_dropout
+                 if transformer_initial_ff_outer_dropout is not None
+                 else transformer_ff_outer_dropout
              ),
              linear_module_up=(
                  # First truthy assigned value
@@ -441,7 +459,9 @@ class ViT(nn.Module):
          transformer_initial_ff_residual_path=True,
          transformer_initial_ff_linear_module_up=None,
          transformer_initial_ff_linear_module_down=None,
-         transformer_initial_ff_mlp_dropout=None,
+         transformer_initial_ff_dropout=None,
+         transformer_initial_ff_inner_dropout=None,
+         transformer_initial_ff_outer_dropout=None,
          transformer_pre_norm=True,
          transformer_normformer=False,
          transformer_post_norm=False,
@@ -458,7 +478,9 @@
          transformer_ff_linear_module_up=None,
          transformer_ff_linear_module_down=None,
          transformer_msa_scaling="d",
-         transformer_mlp_dropout=0.0,
+         transformer_ff_dropout=0.0,
+         transformer_ff_inner_dropout=0.0,
+         transformer_ff_outer_dropout=0.0,
          transformer_msa_dropout=0.1,
          transformer_stochastic_depth=0.1,
          transformer_checkpoint_ff=True,
@@ -508,7 +530,9 @@
              transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
              transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
              transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
-             transformer_initial_ff_mlp_dropout=transformer_initial_ff_mlp_dropout,
+             transformer_initial_ff_dropout=transformer_initial_ff_dropout,
+             transformer_initial_ff_inner_dropout=transformer_initial_ff_inner_dropout,
+             transformer_initial_ff_outer_dropout=transformer_initial_ff_outer_dropout,
              transformer_pre_norm=transformer_pre_norm,
              transformer_normformer=transformer_normformer,
              transformer_post_norm=transformer_post_norm,
@@ -525,7 +549,9 @@
              transformer_ff_linear_module_up=transformer_ff_linear_module_up,
              transformer_ff_linear_module_down=transformer_ff_linear_module_down,
              transformer_msa_scaling=transformer_msa_scaling,
-             transformer_mlp_dropout=transformer_mlp_dropout,
+             transformer_ff_dropout=transformer_ff_dropout,
+             transformer_ff_inner_dropout=transformer_ff_inner_dropout,
+             transformer_ff_outer_dropout=transformer_ff_outer_dropout,
              transformer_msa_dropout=transformer_msa_dropout,
              transformer_stochastic_depth=transformer_stochastic_depth,
              transformer_checkpoint_ff=transformer_checkpoint_ff,
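
One subtlety in the override chain above: despite the "First truthy assigned value" comments, the actual test is `is not None`, so an explicit 0.0 for a transformer_initial_ff_* argument is respected rather than falling through. A minimal sketch of the rule (the helper function is illustrative):

    def resolve(initial_value, general_value):
        # `is not None`, not truthiness: an explicit 0.0 still wins
        return initial_value if initial_value is not None else general_value

    assert resolve(None, 0.1) == 0.1  # inherit the encoder-wide setting
    assert resolve(0.0, 0.1) == 0.0   # explicit zero overrides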
broccoli_ml-{6.0.1 → 9.0.0}.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: broccoli-ml
- Version: 6.0.1
+ Version: 9.0.0
  Summary: Some useful Pytorch models, circa 2025
  License: MIT
  Author: Nicholas Bailey
broccoli_ml-{6.0.1 → 9.0.0}.dist-info/RECORD CHANGED
@@ -1,13 +1,13 @@
  broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
  broccoli/activation.py,sha256=-Jf30C6iGqWCorC9HEGn2oduWwjeaCAxGLUUYIy1zX8,3438
  broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
- broccoli/linear.py,sha256=Y7s-DzcwsOipRboNHc4HTScw4mJRalNoVFsNcxOB6a4,4872
+ broccoli/linear.py,sha256=XaGHZguvK-7hvtIt07zo8uQZBQvS7oMD2K9nPvyYJLE,9769
  broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
  broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
- broccoli/transformer.py,sha256=4Zd_orIXzJQU6jmbiebBqmSZ73GRT4MMKSIu65r7seg,23324
+ broccoli/transformer.py,sha256=Rozh0hExHjwGvvKbMeZfLoB95dDKyDn3X6o1Ms26aAI,24241
  broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
- broccoli/vit.py,sha256=9oyh76ulmX5lDPMCDicQhhqm8RYCvJIgAJkDbYRVdi4,20873
- broccoli_ml-6.0.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
- broccoli_ml-6.0.1.dist-info/METADATA,sha256=09PQOiSQWjnNVWE0Iw_R5YUD5ewDBwvurNV8VugM7N0,1368
- broccoli_ml-6.0.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- broccoli_ml-6.0.1.dist-info/RECORD,,
+ broccoli/vit.py,sha256=sC6K3FK3a8ojOgvNWSWhuZHBtnFrrTQbsDdlagcKJH4,22224
+ broccoli_ml-9.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+ broccoli_ml-9.0.0.dist-info/METADATA,sha256=ecQ2BRxtzmNSO2CMAp2rcNRq9L37urE_pKdsPf-jJKs,1368
+ broccoli_ml-9.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ broccoli_ml-9.0.0.dist-info/RECORD,,