broccoli-ml 0.24.3__tar.gz → 0.25.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.24.3
+Version: 0.25.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from .tensor import SigmaReparamTensor, AnchoredReparamTensor
+from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
 
 
 class SpectralNormLinear(nn.Module):
@@ -93,3 +93,46 @@ class AnchoredLinear(nn.Module):
             f"AnchoredLinear(in_features={self.in_features},"
             f"out_features={self.out_features}, bias={self.use_bias})"
         )
+
+
+class ReparamLinear(nn.Module):
+    """
+    ...
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
+        else:
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        nn.init.kaiming_uniform_(weights, a=math.sqrt(5))
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = NormReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        # Optional: A nice representation for printing the module.
+        return (
+            f"AnchoredLinear(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
+        )
@@ -100,3 +100,24 @@ class AnchoredReparamTensor(nn.Module):
 
         # Return the reparameterized tensor.
         return self.scale * (self.nondecay_weight / (norm + 1e-6))
+
+
+class NormReparamTensor(nn.Module):
+    """
+    Reparameterise a tensor as a normalised tensor of weights multiplied by a
+    learnable scaling factor.
+    """
+
+    def __init__(self, init_tensor: torch.Tensor):
+        assert init_tensor.ndim == 2, "Input tensor must be a 2D matrix."
+        super().__init__()
+
+        # Use the gradboard convention of calling something nondecay_* if we should
+        # exclude it from weight decay
+        self.nondecay_weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
+        self.scale = nn.Parameter(
+            torch.linalg.norm(self.nondecay_weight).clone().detach(), requires_grad=True
+        )
+
+    def forward(self) -> torch.Tensor:
+        return self.scale * F.normalize(self.nondecay_weight)
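A hedged usage sketch of the two new classes above (the broccoli.linear import path and the optimizer grouping are assumptions, not part of the diff):

import torch
from broccoli.linear import ReparamLinear  # module path is an assumption

layer = ReparamLinear(in_features=16, out_features=8)
x = torch.randn(4, 16)
y = layer(x)  # shape (4, 8); the weight matrix is rebuilt from NormReparamTensor on every call

# Following the nondecay_* naming convention noted in the diff, parameters whose
# name contains "nondecay" would typically be excluded from weight decay:
decay, no_decay = [], []
for name, param in layer.named_parameters():
    (no_decay if "nondecay" in name else decay).append(param)
optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ]
)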
@@ -236,7 +236,8 @@ class FeedforwardBlock(nn.Module):
         activation_kwargs=None,
         dropout=0.0,
         linear_module=nn.Linear,
-        raw_input=False,
+        pre_norm=True,
+        normformer=False,
     ):
         super().__init__()
 
@@ -253,19 +254,13 @@ class FeedforwardBlock(nn.Module):
            else ratio * output_features
        )
 
-        if raw_input:
-            self.memory_type = AnchoredLinear
-
-        else:
-            self.memory_type = linear_module
-
         self.process = nn.Sequential(
            *[
-                nn.LayerNorm(input_features),
+                nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                linear_module(input_features, self.max_features),
                self.activation,
-                # nn.LayerNorm(ratio * output_features) if raw_input else nn.Identity(),
-                self.memory_type(ratio * output_features, output_features),
+                nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
+                linear_module(ratio * output_features, output_features),
                self.dropout,
            ]
        )
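For orientation, a minimal sketch of the feedforward stack the two new flags select: pre_norm keeps (or drops) the leading LayerNorm, normformer adds a NormFormer-style LayerNorm after the activation, and the second projection is now always linear_module. Dimensions, GELU activation, and dropout rate below are illustrative assumptions; only the flag logic mirrors the hunk above:

import torch
from torch import nn

def ffn_stack(input_features: int, output_features: int, ratio: int,
              pre_norm: bool = True, normformer: bool = False) -> nn.Sequential:
    # Optional pre-LayerNorm, expansion, activation, optional mid-LayerNorm, projection.
    return nn.Sequential(
        nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
        nn.Linear(input_features, ratio * output_features),
        nn.GELU(),  # illustrative; the real block takes the activation as an argument
        nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
        nn.Linear(ratio * output_features, output_features),
        nn.Dropout(0.0),
    )

x = torch.randn(2, 8, 64)
print(ffn_stack(64, 64, ratio=4, normformer=True)(x).shape)  # torch.Size([2, 8, 64])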
@@ -299,6 +294,8 @@ class TransformerBlock(nn.Module):
         identity_probability=0.0,
         causal=False,
         linear_module=nn.Linear,
+        pre_norm=True,
+        normformer=False,
     ):
         super().__init__()
 
@@ -339,6 +336,8 @@ class TransformerBlock(nn.Module):
             activation_kwargs=activation_kwargs,
             dropout=mlp_dropout,
             linear_module=linear_module,
+            pre_norm=pre_norm,
+            normformer=normformer,
         )
 
     @property
@@ -393,6 +392,8 @@ class TransformerEncoder(nn.Module):
         linear_module=nn.Linear,
         bos_tokens=0,
         return_bos_tokens=False,
+        pre_norm=True,
+        normformer=False,
     ):
         if position_embedding_type == "relative":
             assert source_size is not None # TODO: make this a proper exception
@@ -451,6 +452,8 @@ class TransformerEncoder(nn.Module):
                    identity_probability=self.stochastic_depth_probabilities[i],
                    causal=causal,
                    linear_module=linear_module,
+                    pre_norm=pre_norm,
+                    normformer=normformer,
                )
                for i in range(n_layers)
            ]
@@ -117,6 +117,8 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
+        transformer_pre_norm=True,
+        transformer_normformer=False,
         transformer_position_embedding="relative",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -289,6 +291,8 @@ class ViTEncoder(nn.Module):
                linear_module=linear_module,
                bos_tokens=transformer_bos_tokens,
                return_bos_tokens=transformer_return_bos_tokens,
+                pre_norm=transformer_pre_norm,
+                normformer=transformer_normformer,
            )
        else:
            self.transformer = nn.Identity()
@@ -356,6 +360,8 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
+        transformer_pre_norm=True,
+        transformer_normformer=False,
         transformer_position_embedding="relative",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -410,6 +416,8 @@ class ViT(nn.Module):
            pooling_kernel_stride=pooling_kernel_stride,
            pooling_padding=pooling_padding,
            transformer_feedforward_first=transformer_feedforward_first,
+            transformer_pre_norm=transformer_pre_norm,
+            transformer_normformer=transformer_normformer,
            transformer_position_embedding=transformer_position_embedding,
            transformer_embedding_size=transformer_embedding_size,
            transformer_layers=transformer_layers,
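Both new arguments default to values that reproduce the previous behaviour (transformer_pre_norm=True keeps the unconditional pre-LayerNorm, transformer_normformer=False leaves the extra mid-LayerNorm off), and they are threaded from ViT down through ViTEncoder, TransformerEncoder, TransformerBlock and FeedforwardBlock. A quick way to confirm the new surface after upgrading (the broccoli.vit module path is an assumption, not stated in the diff):

# pip install --upgrade broccoli-ml==0.25.0
import inspect

from broccoli.vit import ViT  # module path is an assumption

sig = inspect.signature(ViT.__init__)
print("transformer_pre_norm" in sig.parameters)           # True in 0.25.0
print("transformer_normformer" in sig.parameters)         # True in 0.25.0
print(sig.parameters["transformer_normformer"].default)   # False, i.e. off by default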
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "0.24.3"
+version = "0.25.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}