broccoli-ml 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/tensor.py CHANGED
@@ -77,25 +77,25 @@ class AnchoredReparamTensor(nn.Module):
 
         super().__init__()
 
-        self.nondecay_weight = nn.Parameter(init_tensor, requires_grad=True)
+        self.weight = nn.Parameter(init_tensor, requires_grad=True)
 
         with torch.no_grad():
-            _, sigma, v_transpose = torch.linalg.svd(
-                self.nondecay_weight, full_matrices=False
-            )
+            _, sigma, v_transpose = torch.linalg.svd(self.weight, full_matrices=False)
 
         self.register_buffer("rayleigh_norm", sigma[:1])
         self.register_buffer("initial_right_singular", v_transpose[0])
-        self.scale = nn.Parameter(sigma[:1].clone().detach(), requires_grad=True)
+        self.nondecay_scale = nn.Parameter(
+            sigma[:1].clone().detach(), requires_grad=True
+        )
 
     def _update_rayleigh_norm(self):
         with torch.no_grad():
-            product = self.nondecay_weight.mv(self.initial_right_singular)
+            product = self.weight.mv(self.initial_right_singular)
             normed_product = F.normalize(product, dim=0)
             rayleigh_norm = torch.einsum(
                 "m,mn,n->",
                 normed_product,
-                self.nondecay_weight,
+                self.weight,
                 self.initial_right_singular,
             )
             self.rayleigh_norm.data.copy_(rayleigh_norm)
@@ -103,7 +103,7 @@ class AnchoredReparamTensor(nn.Module):
     def forward(self):
         if self.training:
             self._update_rayleigh_norm()
-        return self.scale * (self.nondecay_weight / (self.rayleigh_norm + 1e-6))
+        return self.nondecay_scale * (self.weight / (self.rayleigh_norm + 1e-6))
 
 
 class NormReparamTensor(nn.Module):
@@ -118,11 +118,11 @@ class NormReparamTensor(nn.Module):
 
         # Use the gradboard convention of calling something nondecay_* if we should
         # exclude it from weight decay
-        self.nondecay_weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
-        self.scale = nn.Parameter(
-            torch.linalg.norm(self.nondecay_weight).clone().detach(), requires_grad=True
+        self.weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
+        self.nondecay_scale = nn.Parameter(
+            torch.linalg.norm(self.weight).clone().detach(), requires_grad=True
         )
 
     def forward(self) -> torch.Tensor:
-        norm = torch.linalg.norm(self.nondecay_weight)
-        return self.scale * (self.nondecay_weight / (norm + 1e-6))
+        norm = torch.linalg.norm(self.weight)
+        return self.nondecay_scale * (self.weight / (norm + 1e-6))
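The substance of this change is a rename that flips which parameter the gradboard nondecay_ convention protects: the full tensor (formerly nondecay_weight, now weight) becomes eligible for weight decay, while the learned scalar (formerly scale, now nondecay_scale) stays excluded, matching the comment kept in NormReparamTensor above. Below is a minimal sketch of how an optimizer setup might consume that naming convention; the helper name and grouping logic are assumptions for illustration, not part of the package.

from torch import nn


def split_decay_param_groups(model: nn.Module, weight_decay: float):
    """Hypothetical helper: build optimizer parameter groups from the nondecay_ prefix."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Per the gradboard convention quoted in the diff, attributes named
        # nondecay_* should be excluded from weight decay.
        if name.split(".")[-1].startswith("nondecay_"):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# After 0.35.1, an AnchoredReparamTensor's full `weight` matrix lands in the decayed
# group while its scalar `nondecay_scale` does not, e.g.:
# optimizer = torch.optim.AdamW(split_decay_param_groups(model, 0.01), lr=3e-4)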
broccoli/transformer.py CHANGED
@@ -235,7 +235,8 @@ class FeedforwardBlock(nn.Module):
         activation=nn.ReLU,
         activation_kwargs=None,
         dropout=0.0,
-        linear_module=nn.Linear,
+        linear_module_up=nn.Linear,
+        linear_module_down=nn.Linear,
         pre_norm=True,
         normformer=False,
         post_norm=True,
@@ -265,10 +266,10 @@ class FeedforwardBlock(nn.Module):
         self.process = nn.Sequential(
             *[
                 nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
-                linear_module(input_features, self.max_features),
+                linear_module_up(input_features, self.max_features),
                 self.activation,
                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
-                linear_module(ratio * output_features, output_features),
+                linear_module_down(ratio * output_features, output_features),
                 self.dropout,
             ]
         )
@@ -346,7 +347,7 @@ class TransformerBlock(nn.Module):
             bos_tokens=bos_tokens,
         )
 
-        # Submodules for the feedforward process
+        # Submodule for the feedforward process
         self.ff = FeedforwardBlock(
             d_model,
             mlp_ratio,
@@ -354,7 +355,8 @@ class TransformerBlock(nn.Module):
             activation=activation,
             activation_kwargs=activation_kwargs,
             dropout=mlp_dropout,
-            linear_module=linear_module,
+            linear_module_up=linear_module,
+            linear_module_down=linear_module,
             pre_norm=pre_norm,
             normformer=normformer,
             post_norm=post_norm,
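FeedforwardBlock now takes separate factory callables for its two projections instead of a single linear_module, and TransformerBlock simply passes the same module to both. A small sketch of what qualifies as a factory, assuming only the keyword names and call shapes visible in this diff; the bias-free down-projection is an illustrative choice, not something the package prescribes.

from functools import partial

import torch.nn as nn

# Each factory is invoked with an (in_features, out_features) pair, exactly as in
# the nn.Sequential hunk above:
#   linear_module_up(input_features, self.max_features)
#   linear_module_down(ratio * output_features, output_features)
linear_up = nn.Linear                         # ordinary up-projection
linear_down = partial(nn.Linear, bias=False)  # e.g. a bias-free down-projection

# Passed by keyword, mirroring TransformerBlock's call above (all other
# FeedforwardBlock arguments omitted here):
# ff = FeedforwardBlock(..., linear_module_up=linear_up, linear_module_down=linear_down)

# Sanity check that the factories match the expected call shape:
up = linear_up(512, 2048)
down = linear_down(2048, 512)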
broccoli/vit.py CHANGED
@@ -118,7 +118,8 @@ class ViTEncoder(nn.Module):
         pooling_padding=1,
         transformer_feedforward_first=True,
         transformer_initial_ff_residual_path=True,
-        transformer_initial_ff_linear_module=None,
+        transformer_initial_ff_linear_module_up=None,
+        transformer_initial_ff_linear_module_down=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -296,6 +297,7 @@ class ViTEncoder(nn.Module):
                 return_bos_tokens=transformer_return_bos_tokens,
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
+                post_norm=transformer_post_norm,
             )
         else:
             self.transformer = nn.Identity()
@@ -308,13 +310,19 @@ class ViTEncoder(nn.Module):
                 activation=transformer_activation,
                 activation_kwargs=transformer_activation_kwargs,
                 dropout=transformer_mlp_dropout,
-                linear_module=(
-                    transformer_initial_ff_linear_module
-                    if transformer_initial_ff_linear_module is not None
+                linear_module_up=(
+                    transformer_initial_ff_linear_module_up
+                    if transformer_initial_ff_linear_module_up is not None
+                    else linear_module
+                ),
+                linear_module_down=(
+                    transformer_initial_ff_linear_module_down
+                    if transformer_initial_ff_linear_module_down is not None
                     else linear_module
                 ),
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
+                post_norm=transformer_post_norm,
                 residual_path=transformer_initial_ff_residual_path,
             )
         else:
@@ -370,7 +378,8 @@ class ViT(nn.Module):
         pooling_padding=1,
         transformer_feedforward_first=True,
         transformer_initial_ff_residual_path=True,
-        transformer_initial_ff_linear_module=None,
+        transformer_initial_ff_linear_module_up=None,
+        transformer_initial_ff_linear_module_down=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -429,7 +438,8 @@ class ViT(nn.Module):
             pooling_padding=pooling_padding,
             transformer_feedforward_first=transformer_feedforward_first,
             transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
-            transformer_initial_ff_linear_module=transformer_initial_ff_linear_module,
+            transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
+            transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
             transformer_pre_norm=transformer_pre_norm,
             transformer_normformer=transformer_normformer,
             transformer_post_norm=transformer_post_norm,
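ViTEncoder and ViT expose the same split as transformer_initial_ff_linear_module_up and transformer_initial_ff_linear_module_down, each falling back to the encoder-wide linear_module when left at None; the transformer_post_norm flag is now also forwarded to the transformer stack and to the initial feed-forward block. A runnable sketch of just that override-or-fallback rule, using a hypothetical helper name and assuming nothing about the rest of the constructors:

import torch.nn as nn


def resolve_ff_linear_modules(linear_module, up_override=None, down_override=None):
    """Hypothetical helper mirroring the inline fallback logic in the ViTEncoder hunk."""
    linear_module_up = up_override if up_override is not None else linear_module
    linear_module_down = down_override if down_override is not None else linear_module
    return linear_module_up, linear_module_down


# Leaving both overrides at None reproduces the pre-0.35.1 behaviour (one factory for
# both projections); setting either one swaps out just that projection.
up, down = resolve_ff_linear_modules(nn.Linear)
assert up is nn.Linear and down is nn.Linear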
broccoli_ml-0.34.0.dist-info/METADATA → broccoli_ml-0.35.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.34.0
+Version: 0.35.1
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
broccoli_ml-0.34.0.dist-info/RECORD → broccoli_ml-0.35.1.dist-info/RECORD RENAMED
@@ -7,11 +7,11 @@ broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
 broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
 broccoli/linear.py,sha256=8Y9vD85ZEgNZsIQgO3uRQ3lOQR-JjwvabY8liCrfNCk,4831
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
-broccoli/tensor.py,sha256=ks2TRCdS10k2XvxEieh2sj_LzjTNRuiO6gekKFTtziI,4533
-broccoli/transformer.py,sha256=aF452hNruRRBZCqBqb0FM-C2tHH0sB5Fan937Plh7Cc,17106
+broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
+broccoli/transformer.py,sha256=t0gsADJC9UOlwjm7tDKdy0pAZ8l3clTcCnes86zvH-k,17203
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
-broccoli/vit.py,sha256=HEdG9PPra3295Gv7jwzJbtiaRhJMp8QKFwv_dUzF67o,16827
-broccoli_ml-0.34.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
-broccoli_ml-0.34.0.dist-info/METADATA,sha256=Fzkgbyz9r3VyBLP90AD_Vfv9DyDORiEoKsi0ipJDJ8w,1257
-broccoli_ml-0.34.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-broccoli_ml-0.34.0.dist-info/RECORD,,
+broccoli/vit.py,sha256=c-ZRHiLDOoQDJO9OJ51zD9HqaluG33flIwTXQQfms-g,17389
+broccoli_ml-0.35.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.35.1.dist-info/METADATA,sha256=5pQA45ytAkkn0F5il2zuSN0vY7hFtVJvyUi9MXF-0EA,1257
+broccoli_ml-0.35.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.35.1.dist-info/RECORD,,