broccoli-ml 0.24.3__tar.gz → 0.26.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.24.3
+Version: 0.26.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from .tensor import SigmaReparamTensor, AnchoredReparamTensor
+from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
 
 
 class SpectralNormLinear(nn.Module):
@@ -93,3 +93,46 @@ class AnchoredLinear(nn.Module):
             f"AnchoredLinear(in_features={self.in_features},"
             f"out_features={self.out_features}, bias={self.use_bias})"
         )
+
+
+class ReparamLinear(nn.Module):
+    """
+    ...
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
+        else:
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        nn.init.kaiming_uniform_(weights, a=math.sqrt(5))
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = NormReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        # Optional: A nice representation for printing the module.
+        return (
+            f"AnchoredLinear(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
+        )
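A quick usage sketch (not from the package itself) may help here: ReparamLinear is intended as a drop-in replacement for nn.Linear whose weight matrix is rebuilt by NormReparamTensor on every forward call. The import path below is an assumption; adjust it to wherever the class is actually exported.

    import torch

    # Hypothetical import path; the class is added alongside AnchoredLinear above.
    from broccoli_ml.linear import ReparamLinear

    layer = ReparamLinear(in_features=32, out_features=64, bias=True)
    x = torch.randn(8, 32)
    y = layer(x)    # weight comes from NormReparamTensor, bias is a plain Parameter
    print(y.shape)  # torch.Size([8, 64])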
@@ -100,3 +100,24 @@ class AnchoredReparamTensor(nn.Module):
 
         # Return the reparameterized tensor.
         return self.scale * (self.nondecay_weight / (norm + 1e-6))
+
+
+class NormReparamTensor(nn.Module):
+    """
+    Reparameterise a tensor as a normalised tensor of weights multiplied by a
+    learnable scaling factor.
+    """
+
+    def __init__(self, init_tensor: torch.Tensor):
+        assert init_tensor.ndim == 2, "Input tensor must be a 2D matrix."
+        super().__init__()
+
+        # Use the gradboard convention of calling something nondecay_* if we should
+        # exclude it from weight decay
+        self.nondecay_weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
+        self.scale = nn.Parameter(
+            torch.linalg.norm(self.nondecay_weight).clone().detach(), requires_grad=True
+        )
+
+    def forward(self) -> torch.Tensor:
+        return self.scale * F.normalize(self.nondecay_weight)
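For orientation, the reparameterisation can be written out in plain PyTorch (the snippet below is an illustration, not the package's own API). The effective tensor is the learnable scalar `scale` times `F.normalize` of the stored weight; since `F.normalize` defaults to unit L2 norm along dim=1, every row of the effective weight ends up with the same norm, equal to `scale`, which is initialised to the Frobenius norm of the initial tensor.

    import torch
    from torch.nn import functional as F

    W = torch.randn(64, 32)
    scale = torch.linalg.norm(W)    # Frobenius norm; what `scale` is initialised to
    W_eff = scale * F.normalize(W)  # F.normalize defaults to p=2, dim=1 (unit-norm rows)
    print(W_eff.norm(dim=1))        # every row of the effective weight has norm == scale

Because only `scale` carries the overall magnitude, filtering parameters on the `nondecay_` prefix (the convention named in the comment) keeps weight decay from acting on the direction tensor.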
@@ -236,7 +236,8 @@ class FeedforwardBlock(nn.Module):
         activation_kwargs=None,
         dropout=0.0,
         linear_module=nn.Linear,
-        raw_input=False,
+        pre_norm=True,
+        normformer=False,
     ):
         super().__init__()
 
@@ -253,19 +254,13 @@ class FeedforwardBlock(nn.Module):
             else ratio * output_features
         )
 
-        if raw_input:
-            self.memory_type = AnchoredLinear
-
-        else:
-            self.memory_type = linear_module
-
         self.process = nn.Sequential(
             *[
-                nn.LayerNorm(input_features),
+                nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                 linear_module(input_features, self.max_features),
                 self.activation,
-                # nn.LayerNorm(ratio * output_features) if raw_input else nn.Identity(),
-                self.memory_type(ratio * output_features, output_features),
+                nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
+                linear_module(ratio * output_features, output_features),
                 self.dropout,
             ]
        )
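The resulting block layout is easier to see spelled out with plain nn modules. The sketch below uses made-up sizes and a GELU stand-in for `self.activation`, and mirrors only the ordering the two new flags control: `pre_norm` gates the LayerNorm on the block input, `normformer` gates an extra LayerNorm after the activation, in the spirit of NormFormer.

    import torch
    from torch import nn

    d_in, ratio, d_out = 256, 4, 256
    pre_norm, normformer = True, True  # the FeedforwardBlock defaults are True / False

    ffn = nn.Sequential(
        nn.LayerNorm(d_in) if pre_norm else nn.Identity(),             # pre-norm on the input
        nn.Linear(d_in, ratio * d_out),                                 # expansion (self.max_features in the package)
        nn.GELU(),                                                      # stand-in for self.activation
        nn.LayerNorm(ratio * d_out) if normformer else nn.Identity(),   # NormFormer-style post-activation norm
        nn.Linear(ratio * d_out, d_out),                                # projection back down
        nn.Dropout(0.0),
    )
    print(ffn(torch.randn(8, d_in)).shape)  # torch.Size([8, 256])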
@@ -299,12 +294,17 @@ class TransformerBlock(nn.Module):
         identity_probability=0.0,
         causal=False,
         linear_module=nn.Linear,
+        pre_norm=True,
+        normformer=False,
     ):
         super().__init__()
 
+        self.pre_norm = pre_norm
+
         self.identity_probability = identity_probability
 
-        self.layer_norm = nn.LayerNorm(d_model)
+        self.layer_norm_1 = nn.LayerNorm(d_model)
+        self.layer_norm_2 = nn.LayerNorm(d_model)
 
         if position_embedding_type == "relative":
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
@@ -339,6 +339,8 @@ class TransformerBlock(nn.Module):
             activation_kwargs=activation_kwargs,
             dropout=mlp_dropout,
             linear_module=linear_module,
+            pre_norm=pre_norm,
+            normformer=normformer,
         )
 
     @property
@@ -359,12 +361,19 @@ class TransformerBlock(nn.Module):
         identity_x = shuffled[:identity_count, :, :]
         process_x = shuffled[identity_count:, :, :]
 
-        norm_process_x = self.layer_norm(process_x)
-        process_x = process_x + self.attn(
-            norm_process_x, norm_process_x, norm_process_x
-        )
-        process_x = process_x + self.ff(process_x)
-        x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
+        if self.pre_norm:
+            norm_process_x = self.layer_norm_1(process_x)
+            process_x = process_x + self.attn(
+                norm_process_x, norm_process_x, norm_process_x
+            )
+            process_x = process_x + self.ff(process_x)
+        else:  # post-norm
+            process_x = process_x + self.attn(process_x, process_x, process_x)
+            norm_process_x = self.layer_norm_1(process_x)
+            process_x = process_x + self.ff(process_x)
+        x = self.layer_norm_2(
+            torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
+        )
 
         return x
 
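The branch above follows the usual pre-norm / post-norm split. A conventional minimal version of the two arrangements, with made-up module names and without the stochastic-depth handling of the forward pass above, looks like this:

    import torch
    from torch import nn

    d_model, pre_norm = 64, True
    attn = nn.MultiheadAttention(d_model, num_heads=4)
    ff = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))
    norm_1, norm_2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

    x = torch.randn(10, 2, d_model)  # (sequence, batch, features)
    if pre_norm:
        h = norm_1(x)
        x = x + attn(h, h, h)[0]     # normalise before the attention residual
        x = x + ff(x)
    else:                            # post-norm
        x = x + attn(x, x, x)[0]
        x = norm_1(x)                # normalise after the attention residual
        x = x + ff(x)
    x = norm_2(x)                    # final LayerNorm, applied in both cases here
    print(x.shape)                   # torch.Size([10, 2, 64])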
@@ -393,6 +402,8 @@ class TransformerEncoder(nn.Module):
         linear_module=nn.Linear,
         bos_tokens=0,
         return_bos_tokens=False,
+        pre_norm=True,
+        normformer=False,
     ):
         if position_embedding_type == "relative":
             assert source_size is not None  # TODO: make this a proper exception
@@ -451,6 +462,8 @@ class TransformerEncoder(nn.Module):
                     identity_probability=self.stochastic_depth_probabilities[i],
                     causal=causal,
                     linear_module=linear_module,
+                    pre_norm=pre_norm,
+                    normformer=normformer,
                 )
                 for i in range(n_layers)
             ]
@@ -117,6 +117,8 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
+        transformer_pre_norm=True,
+        transformer_normformer=False,
         transformer_position_embedding="relative",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -289,6 +291,8 @@ class ViTEncoder(nn.Module):
                 linear_module=linear_module,
                 bos_tokens=transformer_bos_tokens,
                 return_bos_tokens=transformer_return_bos_tokens,
+                pre_norm=transformer_pre_norm,
+                normformer=transformer_normformer,
             )
         else:
             self.transformer = nn.Identity()
@@ -302,7 +306,8 @@ class ViTEncoder(nn.Module):
                 activation_kwargs=transformer_activation_kwargs,
                 dropout=transformer_mlp_dropout,
                 linear_module=linear_module,
-                raw_input=not cnn,
+                pre_norm=transformer_pre_norm,
+                normformer=transformer_normformer,
             )
         else:
             self.initial_ff = nn.Identity()
@@ -356,6 +361,8 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
+        transformer_pre_norm=True,
+        transformer_normformer=False,
         transformer_position_embedding="relative",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -410,6 +417,8 @@ class ViT(nn.Module):
             pooling_kernel_stride=pooling_kernel_stride,
             pooling_padding=pooling_padding,
             transformer_feedforward_first=transformer_feedforward_first,
+            transformer_pre_norm=transformer_pre_norm,
+            transformer_normformer=transformer_normformer,
             transformer_position_embedding=transformer_position_embedding,
             transformer_embedding_size=transformer_embedding_size,
             transformer_layers=transformer_layers,
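Downstream, the two flags are simply threaded from ViT through ViTEncoder and TransformerEncoder into every TransformerBlock and FeedforwardBlock. A hedged usage sketch follows; the import path is a guess, and every other constructor argument is assumed to keep the defaults shown in the diff.

    from broccoli_ml.vit import ViT  # hypothetical import path

    model = ViT(
        transformer_pre_norm=True,    # pre-norm residual blocks (the default)
        transformer_normformer=True,  # extra LayerNorm after the feedforward activation
    )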
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "0.24.3"
+version = "0.26.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}