broccoli-ml 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/tensor.py +13 -13
- broccoli/transformer.py +7 -5
- broccoli/vit.py +16 -6
- {broccoli_ml-0.34.0.dist-info → broccoli_ml-0.35.1.dist-info}/METADATA +1 -1
- {broccoli_ml-0.34.0.dist-info → broccoli_ml-0.35.1.dist-info}/RECORD +7 -7
- {broccoli_ml-0.34.0.dist-info → broccoli_ml-0.35.1.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.34.0.dist-info → broccoli_ml-0.35.1.dist-info}/WHEEL +0 -0
broccoli/tensor.py
CHANGED
@@ -77,25 +77,25 @@ class AnchoredReparamTensor(nn.Module):
|
|
77
77
|
|
78
78
|
super().__init__()
|
79
79
|
|
80
|
-
self.
|
80
|
+
self.weight = nn.Parameter(init_tensor, requires_grad=True)
|
81
81
|
|
82
82
|
with torch.no_grad():
|
83
|
-
_, sigma, v_transpose = torch.linalg.svd(
|
84
|
-
self.nondecay_weight, full_matrices=False
|
85
|
-
)
|
83
|
+
_, sigma, v_transpose = torch.linalg.svd(self.weight, full_matrices=False)
|
86
84
|
|
87
85
|
self.register_buffer("rayleigh_norm", sigma[:1])
|
88
86
|
self.register_buffer("initial_right_singular", v_transpose[0])
|
89
|
-
self.
|
87
|
+
self.nondecay_scale = nn.Parameter(
|
88
|
+
sigma[:1].clone().detach(), requires_grad=True
|
89
|
+
)
|
90
90
|
|
91
91
|
def _update_rayleigh_norm(self):
|
92
92
|
with torch.no_grad():
|
93
|
-
product = self.
|
93
|
+
product = self.weight.mv(self.initial_right_singular)
|
94
94
|
normed_product = F.normalize(product, dim=0)
|
95
95
|
rayleigh_norm = torch.einsum(
|
96
96
|
"m,mn,n->",
|
97
97
|
normed_product,
|
98
|
-
self.
|
98
|
+
self.weight,
|
99
99
|
self.initial_right_singular,
|
100
100
|
)
|
101
101
|
self.rayleigh_norm.data.copy_(rayleigh_norm)
|
@@ -103,7 +103,7 @@ class AnchoredReparamTensor(nn.Module):
|
|
103
103
|
def forward(self):
|
104
104
|
if self.training:
|
105
105
|
self._update_rayleigh_norm()
|
106
|
-
return self.
|
106
|
+
return self.nondecay_scale * (self.weight / (self.rayleigh_norm + 1e-6))
|
107
107
|
|
108
108
|
|
109
109
|
class NormReparamTensor(nn.Module):
|
@@ -118,11 +118,11 @@ class NormReparamTensor(nn.Module):
|
|
118
118
|
|
119
119
|
# Use the gradboard convention of calling something nondecay_* if we should
|
120
120
|
# exclude it from weight decay
|
121
|
-
self.
|
122
|
-
self.
|
123
|
-
torch.linalg.norm(self.
|
121
|
+
self.weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
|
122
|
+
self.nondecay_scale = nn.Parameter(
|
123
|
+
torch.linalg.norm(self.weight).clone().detach(), requires_grad=True
|
124
124
|
)
|
125
125
|
|
126
126
|
def forward(self) -> torch.Tensor:
|
127
|
-
norm = torch.linalg.norm(self.
|
128
|
-
return self.
|
127
|
+
norm = torch.linalg.norm(self.weight)
|
128
|
+
return self.nondecay_scale * (self.weight / (norm + 1e-6))
|
broccoli/transformer.py
CHANGED
@@ -235,7 +235,8 @@ class FeedforwardBlock(nn.Module):
|
|
235
235
|
activation=nn.ReLU,
|
236
236
|
activation_kwargs=None,
|
237
237
|
dropout=0.0,
|
238
|
-
|
238
|
+
linear_module_up=nn.Linear,
|
239
|
+
linear_module_down=nn.Linear,
|
239
240
|
pre_norm=True,
|
240
241
|
normformer=False,
|
241
242
|
post_norm=True,
|
@@ -265,10 +266,10 @@ class FeedforwardBlock(nn.Module):
|
|
265
266
|
self.process = nn.Sequential(
|
266
267
|
*[
|
267
268
|
nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
|
268
|
-
|
269
|
+
linear_module_up(input_features, self.max_features),
|
269
270
|
self.activation,
|
270
271
|
nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
|
271
|
-
|
272
|
+
linear_module_down(ratio * output_features, output_features),
|
272
273
|
self.dropout,
|
273
274
|
]
|
274
275
|
)
|
@@ -346,7 +347,7 @@ class TransformerBlock(nn.Module):
|
|
346
347
|
bos_tokens=bos_tokens,
|
347
348
|
)
|
348
349
|
|
349
|
-
#
|
350
|
+
# Submodule for the feedforward process
|
350
351
|
self.ff = FeedforwardBlock(
|
351
352
|
d_model,
|
352
353
|
mlp_ratio,
|
@@ -354,7 +355,8 @@ class TransformerBlock(nn.Module):
|
|
354
355
|
activation=activation,
|
355
356
|
activation_kwargs=activation_kwargs,
|
356
357
|
dropout=mlp_dropout,
|
357
|
-
|
358
|
+
linear_module_up=linear_module,
|
359
|
+
linear_module_down=linear_module,
|
358
360
|
pre_norm=pre_norm,
|
359
361
|
normformer=normformer,
|
360
362
|
post_norm=post_norm,
|
broccoli/vit.py
CHANGED
@@ -118,7 +118,8 @@ class ViTEncoder(nn.Module):
|
|
118
118
|
pooling_padding=1,
|
119
119
|
transformer_feedforward_first=True,
|
120
120
|
transformer_initial_ff_residual_path=True,
|
121
|
-
|
121
|
+
transformer_initial_ff_linear_module_up=None,
|
122
|
+
transformer_initial_ff_linear_module_down=None,
|
122
123
|
transformer_pre_norm=True,
|
123
124
|
transformer_normformer=False,
|
124
125
|
transformer_post_norm=False,
|
@@ -296,6 +297,7 @@ class ViTEncoder(nn.Module):
|
|
296
297
|
return_bos_tokens=transformer_return_bos_tokens,
|
297
298
|
pre_norm=transformer_pre_norm,
|
298
299
|
normformer=transformer_normformer,
|
300
|
+
post_norm=transformer_post_norm,
|
299
301
|
)
|
300
302
|
else:
|
301
303
|
self.transformer = nn.Identity()
|
@@ -308,13 +310,19 @@ class ViTEncoder(nn.Module):
|
|
308
310
|
activation=transformer_activation,
|
309
311
|
activation_kwargs=transformer_activation_kwargs,
|
310
312
|
dropout=transformer_mlp_dropout,
|
311
|
-
|
312
|
-
|
313
|
-
if
|
313
|
+
linear_module_up=(
|
314
|
+
transformer_initial_ff_linear_module_up
|
315
|
+
if transformer_initial_ff_linear_module_up is not None
|
316
|
+
else linear_module
|
317
|
+
),
|
318
|
+
linear_module_down=(
|
319
|
+
transformer_initial_ff_linear_module_down
|
320
|
+
if transformer_initial_ff_linear_module_down is not None
|
314
321
|
else linear_module
|
315
322
|
),
|
316
323
|
pre_norm=transformer_pre_norm,
|
317
324
|
normformer=transformer_normformer,
|
325
|
+
post_norm=transformer_post_norm,
|
318
326
|
residual_path=transformer_initial_ff_residual_path,
|
319
327
|
)
|
320
328
|
else:
|
@@ -370,7 +378,8 @@ class ViT(nn.Module):
|
|
370
378
|
pooling_padding=1,
|
371
379
|
transformer_feedforward_first=True,
|
372
380
|
transformer_initial_ff_residual_path=True,
|
373
|
-
|
381
|
+
transformer_initial_ff_linear_module_up=None,
|
382
|
+
transformer_initial_ff_linear_module_down=None,
|
374
383
|
transformer_pre_norm=True,
|
375
384
|
transformer_normformer=False,
|
376
385
|
transformer_post_norm=False,
|
@@ -429,7 +438,8 @@ class ViT(nn.Module):
|
|
429
438
|
pooling_padding=pooling_padding,
|
430
439
|
transformer_feedforward_first=transformer_feedforward_first,
|
431
440
|
transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
|
432
|
-
|
441
|
+
transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
|
442
|
+
transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
|
433
443
|
transformer_pre_norm=transformer_pre_norm,
|
434
444
|
transformer_normformer=transformer_normformer,
|
435
445
|
transformer_post_norm=transformer_post_norm,
|
@@ -7,11 +7,11 @@ broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
|
|
7
7
|
broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
|
8
8
|
broccoli/linear.py,sha256=8Y9vD85ZEgNZsIQgO3uRQ3lOQR-JjwvabY8liCrfNCk,4831
|
9
9
|
broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
|
10
|
-
broccoli/tensor.py,sha256=
|
11
|
-
broccoli/transformer.py,sha256=
|
10
|
+
broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
|
11
|
+
broccoli/transformer.py,sha256=t0gsADJC9UOlwjm7tDKdy0pAZ8l3clTcCnes86zvH-k,17203
|
12
12
|
broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
|
13
|
-
broccoli/vit.py,sha256=
|
14
|
-
broccoli_ml-0.
|
15
|
-
broccoli_ml-0.
|
16
|
-
broccoli_ml-0.
|
17
|
-
broccoli_ml-0.
|
13
|
+
broccoli/vit.py,sha256=c-ZRHiLDOoQDJO9OJ51zD9HqaluG33flIwTXQQfms-g,17389
|
14
|
+
broccoli_ml-0.35.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
+
broccoli_ml-0.35.1.dist-info/METADATA,sha256=5pQA45ytAkkn0F5il2zuSN0vY7hFtVJvyUi9MXF-0EA,1257
|
16
|
+
broccoli_ml-0.35.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
+
broccoli_ml-0.35.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|