broccoli-ml 0.35.0__py3-none-any.whl → 0.36.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- broccoli/tensor.py +13 -13
- broccoli/transformer.py +4 -41
- broccoli/vit.py +1 -7
- {broccoli_ml-0.35.0.dist-info → broccoli_ml-0.36.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.35.0.dist-info → broccoli_ml-0.36.0.dist-info}/RECORD +7 -7
- {broccoli_ml-0.35.0.dist-info → broccoli_ml-0.36.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.35.0.dist-info → broccoli_ml-0.36.0.dist-info}/WHEEL +0 -0
broccoli/tensor.py
CHANGED
@@ -77,25 +77,25 @@ class AnchoredReparamTensor(nn.Module):
 
         super().__init__()
 
-        self.
+        self.weight = nn.Parameter(init_tensor, requires_grad=True)
 
         with torch.no_grad():
-            _, sigma, v_transpose = torch.linalg.svd(
-                self.nondecay_weight, full_matrices=False
-            )
+            _, sigma, v_transpose = torch.linalg.svd(self.weight, full_matrices=False)
 
         self.register_buffer("rayleigh_norm", sigma[:1])
         self.register_buffer("initial_right_singular", v_transpose[0])
-        self.
+        self.nondecay_scale = nn.Parameter(
+            sigma[:1].clone().detach(), requires_grad=True
+        )
 
     def _update_rayleigh_norm(self):
         with torch.no_grad():
-            product = self.
+            product = self.weight.mv(self.initial_right_singular)
             normed_product = F.normalize(product, dim=0)
             rayleigh_norm = torch.einsum(
                 "m,mn,n->",
                 normed_product,
-                self.
+                self.weight,
                 self.initial_right_singular,
             )
             self.rayleigh_norm.data.copy_(rayleigh_norm)
@@ -103,7 +103,7 @@ class AnchoredReparamTensor(nn.Module):
     def forward(self):
         if self.training:
            self._update_rayleigh_norm()
-        return self.
+        return self.nondecay_scale * (self.weight / (self.rayleigh_norm + 1e-6))
 
 
 class NormReparamTensor(nn.Module):
@@ -118,11 +118,11 @@ class NormReparamTensor(nn.Module):
 
         # Use the gradboard convention of calling something nondecay_* if we should
         # exclude it from weight decay
-        self.
-        self.
-            torch.linalg.norm(self.
+        self.weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
+        self.nondecay_scale = nn.Parameter(
+            torch.linalg.norm(self.weight).clone().detach(), requires_grad=True
         )
 
     def forward(self) -> torch.Tensor:
-        norm = torch.linalg.norm(self.
-        return self.
+        norm = torch.linalg.norm(self.weight)
+        return self.nondecay_scale * (self.weight / (norm + 1e-6))
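The tensor.py hunks rename the stored parameter to weight and move the learnable magnitude into a separate nondecay_scale parameter that an optimizer group can exclude from weight decay. Below is a minimal, self-contained sketch of that pattern, mirroring the new NormReparamTensor.forward; it is illustrative only, not the package's actual class.

# Sketch of the scale/direction split introduced in this release.
# Names mirror the diff, but this is not broccoli's implementation.
import torch
import torch.nn as nn


class NormReparam(nn.Module):
    def __init__(self, init_tensor: torch.Tensor):
        super().__init__()
        # Direction parameter (subject to weight decay).
        self.weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
        # Learnable scale, named nondecay_* so an optimizer-group filter
        # can exclude it from weight decay.
        self.nondecay_scale = nn.Parameter(
            torch.linalg.norm(self.weight).clone().detach(), requires_grad=True
        )

    def forward(self) -> torch.Tensor:
        # Normalize the stored tensor (eps avoids division by zero),
        # then rescale by the learnable scalar.
        norm = torch.linalg.norm(self.weight)
        return self.nondecay_scale * (self.weight / (norm + 1e-6))


w = NormReparam(torch.randn(8, 4))()
print(w.shape, torch.linalg.norm(w))  # norm is approximately the learnable scale

The AnchoredReparamTensor variant in the hunk above divides by a running Rayleigh-quotient estimate of the top singular value (rayleigh_norm) instead of the full norm, but the scale/direction split is the same.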
broccoli/transformer.py
CHANGED
@@ -21,45 +21,6 @@ class MHAttention(nn.Module):
     are the same shape.
 
     Assumes bias=False and batch_first=True, as God intended.
-
-    Optionally adds various bells and whistles suggested in the
-    literature, including:
-
-    Noam Shazeer's scaled attention per "Attention is All You Need"
-    (https://arxiv.org/abs/1706.03762).
-
-    Max subtract softmax as discussed in "Attention As An RNN"
-    (https://arxiv.org/abs/2405.13956)
-
-    Log-length scaled softmax per "Overcoming a Theoretical Limitation of
-    Self-Attention" (https://arxiv.org/abs/2202.12172).
-
-    Quiet softmax per
-    https://www.evanmiller.org/attention-is-off-by-one.html
-
-    Args:
-        d_model: ...
-        n_heads: ...
-        dropout: ...
-        causal: should a causal mask be applied to the logits before attention
-            is applied? This is standard when using self-attention. Cannot be
-            True if inputs won't be square (e.g. if sequence length for
-            encoder and decoder are different)
-        sequence_length: ...
-        share_kv: ...
-        linear_module: ...
-        max_subtract: if True, the maximum logit value is subtracted from all
-            logits before performing the softmax operation to create a more
-            numerically stable softmax. This is discussed in "Attention As An
-            RNN" (https://arxiv.org/abs/2405.13956).
-        d_model_scale: ...
-        log_length_scale: if True, multiplies logits by the log length of
-            the decoder sequence before performing the softmax operation, as
-            proposed in "Overcoming a Theoretical Limitation of Self-Attention"
-            (https://arxiv.org/abs/2202.12172).
-        quiet: if True, adds 1 to the denominator of the softmax operation,
-            allowing some tokens to attend to no other tokens as described in
-            https://www.evanmiller.org/attention-is-off-by-one.html.
     """
 
     def __init__(
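The removed docstring described optional softmax variants: max-subtract, log-length scaling, and quiet softmax. As an illustration of what those options mean, and not MHAttention's actual code path, the logit tweaks can be sketched like this:

# Hedged sketch of the softmax variants the removed docstring described.
import math
import torch


def attention_softmax(logits: torch.Tensor,
                      max_subtract: bool = True,
                      log_length_scale: bool = False,
                      quiet: bool = False) -> torch.Tensor:
    # logits: (..., query_len, key_len); for self-attention the two lengths coincide.
    if log_length_scale:
        # Scale logits by the log of the sequence length
        # ("Overcoming a Theoretical Limitation of Self-Attention").
        logits = logits * math.log(logits.size(-1))
    if max_subtract:
        # Subtract the row-wise max for numerical stability ("Attention As An RNN").
        logits = logits - logits.max(dim=-1, keepdim=True).values
    exp = torch.exp(logits)
    denom = exp.sum(dim=-1, keepdim=True)
    if quiet:
        # "Quiet" softmax: add 1 to the denominator so a query may attend to nothing
        # (approximate when combined with max subtraction).
        denom = denom + 1.0
    return exp / denom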
@@ -280,7 +241,7 @@ class FeedforwardBlock(nn.Module):
         elif self.residual_path:
             return x + self.process(x)
         else:
-            return x
+            return self.process(x)
 
 
 class TransformerBlock(nn.Module):
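This FeedforwardBlock fix means the non-residual branch now applies the feedforward stack instead of silently returning its input. A sketch of the corrected branching, where the first branch and the process stack are hypothetical stand-ins for whatever the real class uses:

import torch
import torch.nn as nn


class FFBlock(nn.Module):
    def __init__(self, d_model: int, residual_path: bool = True, identity: bool = False):
        super().__init__()
        # Hypothetical feedforward stack; only the branch structure mirrors the diff.
        self.process = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )
        self.residual_path = residual_path
        self.identity = identity

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.identity:
            return x
        elif self.residual_path:
            return x + self.process(x)
        else:
            # Previously this branch returned x unchanged; after the fix the
            # feedforward is applied even without the residual connection.
            return self.process(x)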
@@ -374,7 +335,9 @@ class TransformerBlock(nn.Module):
         identity_probability = self.identity_probability
 
         # perform the identity operation for some rows in the batch
-
+        dist = torch.distributions.Binomial(x.size(0), identity_probability)
+        identity_count = int(dist.sample().item())
+
         shuffle_indices = torch.randperm(x.size(0), device=x.device)
         unshuffle_indices = torch.argsort(shuffle_indices)
         shuffled = x[shuffle_indices, :, :]
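The new TransformerBlock lines draw the number of identity rows from a Binomial distribution before the existing shuffle. A self-contained sketch of the overall pattern follows; the stochastic_identity helper and the block argument are hypothetical, and only the Binomial/randperm lines mirror the diff.

import torch


def stochastic_identity(x: torch.Tensor, block, identity_probability: float) -> torch.Tensor:
    # Number of rows that take the identity path this step.
    dist = torch.distributions.Binomial(x.size(0), identity_probability)
    identity_count = int(dist.sample().item())

    # Shuffle rows so the identity rows are a random subset, then restore order.
    shuffle_indices = torch.randperm(x.size(0), device=x.device)
    unshuffle_indices = torch.argsort(shuffle_indices)
    shuffled = x[shuffle_indices, :, :]

    skipped = shuffled[:identity_count]           # pass through unchanged
    processed = block(shuffled[identity_count:])  # run the block on the rest
    return torch.cat([skipped, processed], dim=0)[unshuffle_indices]


out = stochastic_identity(torch.randn(16, 10, 32), torch.nn.Identity(), 0.25)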
broccoli/vit.py
CHANGED
@@ -236,13 +236,7 @@ class ViTEncoder(nn.Module):
 
         if pooling_type is None:
             pooling_out_channels = cnn_activation_out_channels
-            self.pool = nn.
-                *[
-                    Rearrange(
-                        f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
-                    ), # for transformer
-                ]
-            )
+            self.pool = nn.Identity()
 
         elif pooling_type == "max":
             pooling_out_channels = cnn_activation_out_channels
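With pooling_type=None the pool is now a plain nn.Identity(), a no-op module that returns its input unchanged, so self.pool exists for every pooling_type. A minimal illustration (arbitrary tensor shape, not the ViT's real feature map):

import torch
import torch.nn as nn

pool = nn.Identity()
feats = torch.randn(2, 64, 7, 7)
assert pool(feats) is feats  # nothing is copied or reshaped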
{broccoli_ml-0.35.0.dist-info → broccoli_ml-0.36.0.dist-info}/RECORD
CHANGED
@@ -7,11 +7,11 @@ broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
 broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
 broccoli/linear.py,sha256=8Y9vD85ZEgNZsIQgO3uRQ3lOQR-JjwvabY8liCrfNCk,4831
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
-broccoli/tensor.py,sha256=
-broccoli/transformer.py,sha256=
+broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
+broccoli/transformer.py,sha256=NH94U6lxHzmDGDHTTtJV2kUs7IcS2iNmFJl44_6KtQ0,15456
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
-broccoli/vit.py,sha256=
-broccoli_ml-0.
-broccoli_ml-0.
-broccoli_ml-0.
-broccoli_ml-0.
+broccoli/vit.py,sha256=05xqIw9xvE5easXcp4wIA1jQ0xUyRIq6h0ZDtbitXi4,17184
+broccoli_ml-0.36.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.36.0.dist-info/METADATA,sha256=csog4ZG1PGeRuFO5QnHdVPgmDYXsGQQJ621JgU0D83w,1257
+broccoli_ml-0.36.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.36.0.dist-info/RECORD,,
{broccoli_ml-0.35.0.dist-info → broccoli_ml-0.36.0.dist-info}/LICENSE
File without changes
{broccoli_ml-0.35.0.dist-info → broccoli_ml-0.36.0.dist-info}/WHEEL
File without changes