broccoli-ml 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/linear.py CHANGED
@@ -1,9 +1,56 @@
1
1
  # UNDER CONSTRUCTION
2
2
 
3
+ import math
3
4
  import torch
4
5
  from torch import nn
5
6
  from torch.nn import functional as F
6
7
 
8
+ from .tensor import SigmaReparamTensor
9
+
10
+
11
class SpectralNormLinear(nn.Module):
    """
    Linear layer whose weight is produced by a ``SigmaReparamTensor``.

    The layer owns a raw learnable weight ``weight_init`` of shape
    ``(out_features, in_features)``. It wraps that weight in a
    ``SigmaReparamTensor`` stored as ``self.weights``. ``forward`` computes
    ``F.linear(x, self.weights(), self.bias)``, so it behaves like
    ``nn.Linear``, except that the weight is obtained by calling the wrapper
    rather than being read directly.

    NOTE(review): the names suggest that SigmaReparamTensor applies a spectral
    (sigma) normalisation / reparameterisation to the wrapped weight. Confirm
    this against ``broccoli/tensor.py``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        bias: If True, add a learnable bias vector of length ``out_features``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = bias

        # Placeholder; replaced by a SigmaReparamTensor in reset_parameters().
        self.weights = None

        # Raw (un-reparameterised) weight, laid out (out, in) like nn.Linear.
        self.weight_init = nn.Parameter(torch.empty(out_features, in_features))

        if self.use_bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            # Register bias as None so the attribute always exists and
            # state_dict save/load behaves like nn.Linear(bias=False).
            self.register_parameter("bias", None)

        self.reset_parameters()

    def _init_raw_parameters(self) -> None:
        """Initialise ``weight_init`` and ``bias`` in place.

        Uses the same scheme as ``torch.nn.Linear``'s default initialisation:
        Kaiming-uniform weights with ``a=sqrt(5)``, and bias drawn from
        ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``.
        """
        nn.init.kaiming_uniform_(self.weight_init, a=math.sqrt(5))
        if self.use_bias:
            # Fan-in must be taken from the raw weight tensor: on the first
            # call (from __init__) self.weights is still None.
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_init)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def reset_parameters(self) -> None:
        """Re-initialise the raw parameters and rebuild the weight wrapper."""
        self._init_raw_parameters()
        # A fresh wrapper is built on every reset, so any state it keeps
        # is rebuilt along with the weights.
        self.weights = SigmaReparamTensor(self.weight_init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``x @ W.T + b``, where ``W`` is ``self.weights()``.

        ``x`` is expected to have ``in_features`` as its last dimension
        (per ``F.linear``).
        """
        return F.linear(x, self.weights(), self.bias)

    def __repr__(self) -> str:
        # Must return a single str; report the real class name.
        return (
            f"{type(self).__name__}(in_features={self.in_features}, "
            f"out_features={self.out_features}, bias={self.use_bias})"
        )
53
+
7
54
 
8
55
  class RandomLinear(nn.Linear):
9
56
  """ """
broccoli/transformer.py CHANGED
@@ -10,6 +10,7 @@ import torch.nn.functional as F
10
10
  from einops import rearrange
11
11
 
12
12
  from .rope import RotaryEmbedding, apply_rotary_emb
13
+ from .linear import SpectralNormLinear
13
14
 
14
15
 
15
16
  class MHAttention(nn.Module):
@@ -235,6 +236,7 @@ class FeedforwardLayer(nn.Module):
235
236
  activation_kwargs=None,
236
237
  dropout=0.0,
237
238
  linear_module=nn.Linear,
239
+ norm_memory=False,
238
240
  ):
239
241
  super().__init__()
240
242
 
@@ -245,19 +247,28 @@ class FeedforwardLayer(nn.Module):
245
247
 
246
248
  self.dropout = nn.Dropout(dropout)
247
249
 
250
+ self.max_features = (
251
+ 2 * ratio * output_features
252
+ if activation.__name__.endswith("GLU")
253
+ else ratio * output_features
254
+ )
255
+
256
+ if norm_memory:
257
+ self.memory_type = SpectralNormLinear
258
+ self.bias_memories = False
259
+ else:
260
+ self.memory_type = linear_module
261
+ self.bias_memories = True
262
+
248
263
  self.process = nn.Sequential(
249
264
  *[
250
265
  nn.LayerNorm(input_features),
251
- linear_module(
252
- input_features,
253
- (
254
- 2 * ratio * input_features
255
- if activation.__name__.endswith("GLU")
256
- else ratio * input_features
257
- ),
258
- ),
266
+ linear_module(input_features, self.max_features),
259
267
  self.activation,
260
- linear_module(ratio * input_features, output_features),
268
+ nn.LayerNorm(self.max_features),
269
+ self.memory_type(
270
+ ratio * output_features, output_features, bias=self.bias_memories
271
+ ),
261
272
  self.dropout,
262
273
  ]
263
274
  )
broccoli/vit.py CHANGED
@@ -82,6 +82,7 @@ class ViTEncoder(nn.Module):
82
82
  pooling_kernel_stride=2,
83
83
  pooling_padding=1,
84
84
  intermediate_feedforward_layer=True,
85
+ norm_intermediate_ff_memory=True,
85
86
  transformer_position_embedding="relative", # absolute or relative
86
87
  transformer_embedding_size=256,
87
88
  transformer_layers=7,
@@ -263,6 +264,7 @@ class ViTEncoder(nn.Module):
263
264
  activation_kwargs=transformer_activation_kwargs,
264
265
  dropout=transformer_mlp_dropout,
265
266
  linear_module=linear_module,
267
+ norm_memory=norm_intermediate_ff_memory,
266
268
  )
267
269
  elif pooling_out_channels < transformer_embedding_size:
268
270
  self.intermediate_feedforward_layer = nn.Identity()
@@ -326,6 +328,7 @@ class CCT(nn.Module):
326
328
  pooling_kernel_stride=2,
327
329
  pooling_padding=1,
328
330
  intermediate_feedforward_layer=True,
331
+ norm_intermediate_ff_memory=True,
329
332
  transformer_position_embedding="relative", # absolute or relative
330
333
  transformer_embedding_size=256,
331
334
  transformer_layers=7,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.10.0
3
+ Version: 0.12.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -5,13 +5,13 @@ broccoli/assets/cifar100_eigenvectors_size_2.pt,sha256=DjXDOXMeuMpIqNuGhX9z-OWYV
5
5
  broccoli/assets/cifar100_eigenvectors_size_3.pt,sha256=gL6k0xtXYiYP6ZSvEiMBdJ7kIkT0AngTpDJHFQqwgxA,7173
6
6
  broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
7
7
  broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
8
- broccoli/linear.py,sha256=0XYCi3ckTEKwAgBOMUSJP2HsnrroOH8eyrhRdpANG2w,1298
8
+ broccoli/linear.py,sha256=9ZwqC6kkgkr0uPoEjdi_Uq1QFHb4wCXzuU1r2pDreXM,2910
9
9
  broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
10
10
  broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
11
- broccoli/transformer.py,sha256=RSZpbHs_K4ts5os6lWxcGDI3p0zreRwQNnk6mV8HJnk,15930
11
+ broccoli/transformer.py,sha256=niooSyrG9kZPk60IUPa-ZevEGaUQ8MI6AUxuOInAoqc,16265
12
12
  broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
13
- broccoli/vit.py,sha256=uXqMIvAVY4PuA-Fv1YxU7L3_74fR19GtNu1caQeMr6k,15185
14
- broccoli_ml-0.10.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
- broccoli_ml-0.10.0.dist-info/METADATA,sha256=zRP8Hn-Q0uMDvNhInEuwcm-sVwG5zE3vsWZBSuVqDmE,1257
16
- broccoli_ml-0.10.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
- broccoli_ml-0.10.0.dist-info/RECORD,,
13
+ broccoli/vit.py,sha256=J1-59oROlpbIEZ-grrXMhuluM3cLoJIAvixnIRSmgOs,15326
14
+ broccoli_ml-0.12.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
+ broccoli_ml-0.12.0.dist-info/METADATA,sha256=ApiCIlFOB8rrdnUwLXgYv_0GPsQx85L3Uq2rUh3JhjQ,1257
16
+ broccoli_ml-0.12.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
+ broccoli_ml-0.12.0.dist-info/RECORD,,