broccoli-ml 9.0.0__tar.gz → 9.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-9.0.0 → broccoli_ml-9.1.0}/PKG-INFO +1 -1
- {broccoli_ml-9.0.0 → broccoli_ml-9.1.0}/broccoli/linear.py +44 -4
- {broccoli_ml-9.0.0 → broccoli_ml-9.1.0}/pyproject.toml +1 -1
- {broccoli_ml-9.0.0 → broccoli_ml-9.1.0}/LICENSE +0 -0
- {broccoli_ml-9.0.0 → broccoli_ml-9.1.0}/README.md +0 -0
- {broccoli_ml-9.0.0 → broccoli_ml-9.1.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-9.0.0 → broccoli_ml-9.1.0}/broccoli/activation.py +0 -0
- {broccoli_ml-9.0.0 → broccoli_ml-9.1.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-9.0.0 → broccoli_ml-9.1.0}/broccoli/rope.py +0 -0
- {broccoli_ml-9.0.0 → broccoli_ml-9.1.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-9.0.0 → broccoli_ml-9.1.0}/broccoli/transformer.py +0 -0
- {broccoli_ml-9.0.0 → broccoli_ml-9.1.0}/broccoli/utils.py +0 -0
- {broccoli_ml-9.0.0 → broccoli_ml-9.1.0}/broccoli/vit.py +0 -0
broccoli/linear.py

@@ -1,5 +1,6 @@
 import math
 import random
+import warnings
 from typing import Union, List, Iterable
 
 import torch
@@ -149,34 +150,73 @@ class RecyclingLinear(nn.Module):
         bias: bool = True,
         row_recycling_rate: float = 0.0,
         column_recycling_rate: float = 0.0,
+        adaptive=False,
     ):
         super().__init__()
         self.linear = nn.Linear(in_features, out_features, bias=bias)
         self.row_recycling_rate = row_recycling_rate
         self.column_recycling_rate = column_recycling_rate
+        self.adaptive = adaptive
         self.optimisers = []
+        self.initial_learning_rates = []
+        self._warned_about_registration = False
 
     def register_optimiser(self, optimiser: torch.optim.Optimizer):
         self.optimisers.append(optimiser)
+        self.initial_learning_rates.append(self._get_learning_rate(optimiser))
+        if self.initial_learning_rates[-1] == 0.0:
+            warnings.warn(
+                "Learning rate of registered optimiser was 0.0 - make sure "
+                "you haven't initialised a scheduler before registering the "
+                "optimiser",
+                stacklevel=2,
+            )
+
+    def _get_learning_rate(self, optimiser: torch.optim.Optimizer):
+        for group in optimiser.param_groups:
+            for param in group["params"]:
+                if param is self.linear.weight:
+                    return group["lr"]
+
+    def _get_multiplier(self):
+        if not self.adaptive or not self.optimisers:
+            return 1.0
+        else:
+            init = self.initial_learning_rates
+            current = [self._get_learning_rate(o) for o in self.optimisers]
+            pairs = zip(current, init, strict=True)
+            multipliers = [a / b for a, b in pairs if b != 0.0]
+            return min(multipliers) if multipliers else 0.0
 
     def forward(self, x):
+        multiplier = self._get_multiplier()
+        col_recycling_rate = self.column_recycling_rate * multiplier
+        row_recycling_rate = self.row_recycling_rate * multiplier
+
         if self.training and self.optimisers:
 
-            if self.row_recycling_rate > 0:
+            if row_recycling_rate > 0:
                 probs = torch.rand(self.linear.out_features, device=x.device)
-                mask = probs < self.row_recycling_rate
+                mask = probs < row_recycling_rate
                 if mask.any():
                     # nonzero returns [N, 1], squeeze to get [N]
                     indices = torch.nonzero(mask).squeeze(-1)
                     self.reset_rows(indices, self.optimisers)
 
-            if self.column_recycling_rate > 0:
+            if col_recycling_rate > 0:
                 probs = torch.rand(self.linear.in_features, device=x.device)
-                mask = probs < self.column_recycling_rate
+                mask = probs < col_recycling_rate
                 if mask.any():
                     indices = torch.nonzero(mask).squeeze(-1)
                     self.reset_columns(indices, self.optimisers)
 
+        elif self.training and not self._warned_about_registration:
+            warnings.warn(
+                "RecyclingLinear: No optimiser registered. Recycling disabled.",
+                stacklevel=2,
+            )
+            self._warned_about_registration = True
+
         return self.linear(x)
 
     def reset_rows(
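The substance of 9.1.0 is the adaptive mode. `_get_learning_rate` finds the param group holding `self.linear.weight` by identity, `_get_multiplier` divides each registered optimiser's current learning rate by the rate snapshotted at registration, and `forward` scales both recycling rates by the smallest of those ratios. Net effect: with `adaptive=True`, weight recycling anneals in lockstep with the learning-rate schedule (and `zip(..., strict=True)` makes this Python 3.10+ only). A minimal sketch of the arithmetic, using a plain `nn.Linear` and illustrative numbers rather than anything from the package:

import torch
import torch.nn as nn

# Stand-in for what RecyclingLinear does internally when adaptive=True.
layer = nn.Linear(16, 16)
opt = torch.optim.SGD(layer.parameters(), lr=1e-3)
initial_lr = opt.param_groups[0]["lr"]    # snapshotted by register_optimiser()

sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100)
for _ in range(50):                       # run half the cosine schedule
    opt.step()
    sched.step()

current_lr = opt.param_groups[0]["lr"]
multiplier = current_lr / initial_lr      # what _get_multiplier() computes
print(multiplier)                         # ~0.5 at the cosine midpoint

# forward() would then recycle with, e.g.:
row_rate = 0.01 * multiplier              # row_recycling_rate * multiplier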
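The two new warnings encode an ordering contract for callers: register the optimiser while its param group still holds the base learning rate, i.e. before constructing a scheduler (schedulers step once on construction, so a warmup schedule starting at zero would make the snapshot 0.0), and register it before training begins at all, or recycling is silently skipped. A hedged usage sketch; the import path and positional arguments are inferred from this diff, not from package documentation:

import torch
from broccoli.linear import RecyclingLinear  # path inferred from the diff

layer = RecyclingLinear(
    128, 64,
    row_recycling_rate=0.01,
    column_recycling_rate=0.01,
    adaptive=True,
)
opt = torch.optim.AdamW(layer.parameters(), lr=3e-4)

layer.register_optimiser(opt)  # register FIRST: snapshots the base lr (3e-4)

# LambdaLR steps once on construction; this warmup sets lr to 0.0 at step 0,
# so registering after this line would snapshot 0.0 and trigger the warning.
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: min(1.0, step / 100))

layer.train()
y = layer(torch.randn(32, 128))  # recycling runs inside forward(); while the
                                 # warmup lr is 0.0 the adaptive multiplier is
                                 # 0.0 too, so no weights are recycled yet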