gradboard-5.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gradboard/__init__.py +3 -0
- gradboard/cycles.py +319 -0
- gradboard/optimiser.py +48 -0
- gradboard/scheduler.py +214 -0
- gradboard-5.1.0.dist-info/LICENSE +21 -0
- gradboard-5.1.0.dist-info/METADATA +63 -0
- gradboard-5.1.0.dist-info/RECORD +8 -0
- gradboard-5.1.0.dist-info/WHEEL +4 -0
gradboard/__init__.py
ADDED
gradboard/cycles.py
ADDED
@@ -0,0 +1,319 @@
"""
Utilities for generating a range of learning rate schedules.
"""

import math
from typing import Optional, List, Union, Callable


def ascent(step: int, total_steps: int) -> float:
    """
    Get a sequence of numbers evenly spaced between 0 and 1 so that the first
    number is 0 and the last is 1 and there are `total_steps` numbers in
    the sequence.
    """
    return round(step / (total_steps - 0.999), 8)


def triangle(step: int, total_steps: int) -> float:
    """
    Get a triangular sequence of numbers between 0 and 1, going up in half of
    `total_steps` and coming down in the other half, peaking at ~1.
    """
    half = int(math.ceil(total_steps / 2))
    if step < half:
        return 2 * ascent(step, total_steps)
    else:
        return 2 - 2 * ascent(step, total_steps)


def cosine(step: int, total_steps: int) -> float:
    """
    Get a sequence of numbers between 0 and 1 in the shape of a cosine wave with
    wavelength `total_steps`.
    """
    if total_steps < 1:
        raise ValueError(f"total_steps must be >= 1, got {total_steps}")
    angle = (step / (total_steps - 0.999)) * (2 * math.pi)
    return round((math.cos(angle) + 1) / 2, 8)


def quarter_circle(step: int, total_steps: int) -> float:
    """
    Get a sequence of numbers between 0 and 1 in the shape of a quarter-circle with
    radius `total_steps`.
    """
    if total_steps < 1:
        raise ValueError(f"total_steps must be >= 1, got {total_steps}")
    x = 0 if total_steps == 1 else step / (total_steps - 1)
    # x^2 + y^2 = r^2 = 1
    # Therefore y^2 = 1 - x^2
    # Therefore y^2 = (1 + x)(1 - x)
    y_squared = max(1 - x, 0) * (1 + x)
    return math.sqrt(y_squared)


def half_cosine(step: int, total_steps: int) -> float:
    """
    Get a sequence of numbers between 0 and 1 in the shape of the descending
    half of a cosine wave with wavelength 2*`total_steps`.
    """
    return cosine(step, (total_steps * 2) - 1)


def cycloid(step: int, total_steps: int) -> float:
    """
    Get a sequence of numbers between 0 and 1 in the shape of a cycloid with
    circle diameter 1.0 and `total_steps/(2*math.pi)` steps per cycle.
    """
    x = step * (math.pi / (total_steps - 1))

    def fx(t):
        return 0.5 * (t - math.sin(t)) - x

    def fx_prime(t):
        return 0.5 - 0.5 * math.cos(t)

    def fy_prime(t):
        return 0.5 - 0.5 * -math.sin(t)

    angle_estimate = 0.5 * x

    # XXX: 200 iterations is too many! Use a more efficient root finding algorithm
    for _ in range(200):
        if abs(fx_prime(angle_estimate)) > 0.1:
            update = fx(angle_estimate) / fx_prime(angle_estimate)
        else:
            update = fx(angle_estimate) / fy_prime(angle_estimate)
        angle_estimate = angle_estimate - update

    return 0.5 * (1 - math.cos(angle_estimate))


def half_cycloid(step: int, total_steps: int) -> float:
    return cycloid(total_steps + step, 2 * total_steps)


FN_LIBRARY = {
    "ascent": ascent,
    "triangle": triangle,
    "cosine": cosine,
    "half_cosine": half_cosine,
    "quarter_circle": quarter_circle,
    "half_cycloid": half_cycloid,
}


class Cycle:
    def __init__(
        self,
        generating_function: Union[str, Callable],
        training_examples,
        epochs,
        batch_size,
        t_0: Optional[int] = None,
        t_mult: float = 1.0,
        t_scale: float = 1.0,
        low=0.0,
        high=1.0,
        reflect=False,
    ):
        self.training_examples = training_examples
        self.epochs = epochs
        self.batch_size = batch_size
        self.total_steps = int(
            epochs * (math.floor(training_examples / batch_size) + 1)
        )

        self.t_0 = (
            t_0 * (training_examples / batch_size)
            if t_0 is not None
            else self.total_steps
        )
        self.t_mult = t_mult
        self.t_scale = t_scale

        self.low = low
        self.high = high

        self.reflect = reflect

        if not callable(generating_function):
            if generating_function in FN_LIBRARY:
                self._generating_function = FN_LIBRARY[generating_function]
            else:
                raise NotImplementedError(
                    "`generating_function` must be a callable object or one of "
                    '"ascent", "triangle", "cosine", "half_cosine", "quarter_circle" '
                    'or "half_cycloid"'
                )
        else:
            self._generating_function = generating_function

    def _get_window(self, step):
        windows = self._windows()
        cumulative = [
            sum([w[0] for w in windows][: i + 1]) for i in range(len(windows))
        ]
        position = None
        local_step = None
        for i, c in enumerate(cumulative):
            if c > step:
                position = i
                local_step = step if i == 0 else step - cumulative[i - 1]
                break
        window_width, window_height = windows[position]
        return window_width, local_step, window_height

    def _generate(self, step) -> float:
        total_steps, step, scale = self._get_window(step)
        y = self._generating_function(step, total_steps)
        y = y * scale
        y = 1 - y if self.reflect else y
        return y * (self.high - self.low) + self.low

    def __call__(self, n):
        return self._generate(n)

    def __len__(self):
        return self.total_steps

    def _windows(self):
        assert self.t_mult > 0

        # Get tile widths
        widths = [self.t_0]
        while True:
            next_item = widths[-1] * self.t_mult
            if sum(widths) + next_item <= self.total_steps:
                widths.append(next_item)
            else:
                break
        for i in range(1, len(widths)):
            widths[i] = int(widths[i] * (self.total_steps / sum(widths)))
        widths[-1] += self.total_steps - sum(widths)

        # Get tile heights
        heights = [1.0 * self.t_scale**i for i in range(len(widths))]

        return list(zip(widths, heights, strict=True))

    @property
    def stats(self) -> dict:
        """
        Returns the area (as a percentage of the area of a curve where the learning
        rate is constant max_lr), percentage ascent steps and percentage descent
        steps of a learning rate schedule.
        """
        total_area = 0
        max_area = 0
        ascent_steps = 0
        descent_steps = 0
        avg_up_gradient = 0
        avg_down_gradient = 0
        total_gradient = 0
        previous_lr = None
        for s in range(self.total_steps):
            height = self(s)
            total_area += height
            max_area += 1
            if previous_lr is None:
                pass
            elif previous_lr > height:
                descent_steps += 1
                avg_down_gradient += height - previous_lr
                total_gradient += height - previous_lr
            elif previous_lr < height:
                ascent_steps += 1
                avg_up_gradient += height - previous_lr
                total_gradient += height - previous_lr
            else:
                total_gradient += height
            previous_lr = height
        return {
            "area": total_area / max_area,
            "pc_ascent": round(ascent_steps / self.total_steps, 3),
            "pc_descent": round(descent_steps / self.total_steps, 3),
            "avg_up_gradient": round(avg_up_gradient, 3),
            "avg_down_gradient": round(avg_down_gradient, 3),
            "avg_gradient": round(-(self.high - self.low) / self.total_steps, 3),
        }


class CycleProduct(Cycle):
    def __init__(self, cycles: List[Cycle], reflect=False, normalise: bool = False):
        """
        Args:
            normalise: if true, the square root of the product is returned (i.e.
                the geometric mean of the two cycles that were multiplied together)
        """
        main_training_examples = cycles[0].training_examples
        main_batch_size = cycles[0].batch_size

        assert all(c.training_examples == main_training_examples for c in cycles)
        assert all(c.batch_size == main_batch_size for c in cycles)

        self.cycles = cycles
        self.reflect = reflect
        self.normalise = normalise

        def generating_function(step: int, total_steps: int) -> float:
            output = self.cycles[0](step)
            for c in self.cycles[1:]:
                output *= c(step % c.total_steps)
            if self.normalise:
                output = math.sqrt(output)
            return output

        super().__init__(
            generating_function=generating_function,
            training_examples=self.cycles[0].training_examples,
            epochs=self.cycles[0].epochs,
            batch_size=self.cycles[0].batch_size,
            reflect=reflect,
        )


class CycleSequence:
    def __init__(self, cycles: List[Cycle]):
        self.total_steps = sum([c.total_steps for c in cycles])
        self.cycles = cycles

    def _generate(self, step):
        cycle, step = self._get_cycle_and_step(step)
        return self.cycles[cycle](step)

    def _get_cycle_and_step(self, step):
        cycle_lengths = [c.total_steps for c in self.cycles]
        cumulative = [sum(cycle_lengths[: i + 1]) for i in range(len(cycle_lengths))]
        cycle = None
        local_step = None
        for i, c in enumerate(cumulative):
            if c > step:
                cycle = i
                local_step = step if i == 0 else step - cumulative[i - 1]
                break
        return cycle, local_step

    def __call__(self, step):
        return self._generate(step)

    def __len__(self):
        return self.total_steps

    @property
    def stats(self) -> dict:
        """
        Returns the area (as a percentage of the area of a curve where the learning
        rate is constant max_lr), percentage ascent steps and percentage descent
        steps of a learning rate schedule.
        """
        cycle_ratios = [c.total_steps for c in self.cycles]
        cycle_stats = {k: v * cycle_ratios[0] for k, v in self.cycles[0].stats.items()}
        for i, cycle in enumerate(self.cycles):
            if i == 0:
                continue  # We already did the first one, above
            for k, v in cycle.stats.items():
                cycle_stats[k] += cycle_ratios[i] * v

        return {k: v / self.total_steps for k, v in cycle_stats.items()}
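Taken together, the classes above compose as in the following sketch, based on the signatures shown in this file; the dataset size, batch size and epoch counts are illustrative only:

```
from gradboard.cycles import Cycle, CycleSequence

# Illustrative numbers: 50,000 training examples, batch size 128
warmup = Cycle("ascent", training_examples=50_000, epochs=1, batch_size=128)
decay = Cycle("half_cosine", training_examples=50_000, epochs=9, batch_size=128)
schedule = CycleSequence([warmup, decay])

print(len(schedule))                  # total optimiser steps covered by the schedule
print(schedule(0))                    # multiplier at the first step (~0.0, start of the warm-up)
print(schedule(len(schedule) - 1))    # multiplier at the last step (~0.0, end of the half-cosine decay)
print(decay.stats)                    # area under the curve, % ascent/descent steps, etc.
```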
gradboard/optimiser.py
ADDED
@@ -0,0 +1,48 @@
import math
import warnings
import torch
from torch.optim.optimizer import Optimizer
from torch.optim import AdamW

EXCLUDE_FROM_WEIGHT_DECAY = ["nondecay", "bias", "norm", "embedding", "beta"]


def get_optimiser(
    model,
    optimiser=AdamW,
    lr=1e-3,
    weight_decay=1e-2,
    exclude_keywords=EXCLUDE_FROM_WEIGHT_DECAY,
):
    """
    Defaults are from one of the presets from the accompanying repo to Hassani
    et al. (2023) "Escaping the Big Data Paradigm with Compact Transformers",
    https://github.com/SHI-Labs/Compact-Transformers/blob/main/configs/
    pretrained/cct_7-3x1_cifar100_1500epochs.yml
    """
    weight_decay_exclude = []

    for keyword in exclude_keywords:
        weight_decay_exclude += [
            p for name, p in model.named_parameters() if keyword in name.lower()
        ]

    weight_decay_exclude = set(weight_decay_exclude)

    if len(weight_decay_exclude) > 0:
        warnings.warn(
            "Excluded the following parameters from weight decay based on "
            f"exclude keywords: {weight_decay_exclude}",
            stacklevel=2,
        )

    weight_decay_include = set(model.parameters()) - weight_decay_exclude

    return optimiser(
        [
            {"params": list(weight_decay_include)},
            {"params": list(weight_decay_exclude), "weight_decay": 0.0},
        ],
        weight_decay=weight_decay,
        lr=lr,
    )
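A brief usage sketch for `get_optimiser`; the toy model and hyperparameter values below are arbitrary placeholders:

```
import torch
from gradboard.optimiser import get_optimiser

model = torch.nn.Sequential(
    torch.nn.Linear(32, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
)

# Parameters whose names contain an exclude keyword (here the "bias" parameters)
# are placed in a second parameter group with weight_decay=0.0
optimiser = get_optimiser(model, lr=5e-4, weight_decay=5e-2)
```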
gradboard/scheduler.py
ADDED
@@ -0,0 +1,214 @@
"""
Based on Smith (2017) https://arxiv.org/abs/1506.01186
"""

from typing import Optional
import copy
import math

from torch.amp import GradScaler

from .cycles import Cycle


class PASS:
    """
    A self-configuring learning rate scheduler
    """

    def __init__(
        self,
        learning_rate_schedule: Cycle,
        model,
        optimiser,
        scaler: Optional[GradScaler] = None,
        range_test: bool = False,
        cool_point_multiplier: float = 1 / 60,
    ):
        """
        If not using range test, we assume the optimiser has the learning rates
        set as desired.
        """

        self.model = model
        self.optimiser = optimiser
        self.scaler = scaler

        self.learning_rate_schedule = learning_rate_schedule

        self.range_test = range_test

        self.original_param_groups = copy.deepcopy(optimiser.param_groups)

        self.cool_point_multiplier = cool_point_multiplier

        self.original_states = self._saved_states()

        self.range_test_results = []

        self.step_count = 0

        if range_test:
            self.start_range_test()  # sets LR to 1E-7

    @property
    def lr(self):
        """
        Return first lr from self.optimiser.param_groups
        (this is used in learning rate range tests, in which case we can
        assume they are all the same!)
        """
        for group in self.optimiser.param_groups:
            return group["lr"]

    @property
    def in_range_test(self):
        if not self.range_test:
            return False
        elif (len(self.range_test_results) == 0) or (
            not math.isnan(self.range_test_results[-1][1])
        ):
            return True
        else:
            return False

    @property
    def trained(self):
        if not self.range_test:
            return True
        elif math.isnan(self.range_test_results[-1][1]):
            return True
        else:
            return False

    @property
    def finished(self):
        return self.step_count >= len(self.learning_rate_schedule) - 1

    def _saved_states(self):
        saved_states = {
            "model": copy.deepcopy(self.model.state_dict()),
            "optimiser": copy.deepcopy(self.optimiser.state_dict()),
        }
        if self.scaler is not None:
            saved_states["scaler"] = copy.deepcopy(self.scaler.state_dict())
        return saved_states

    def save_states(self):
        self.saved_states = self._saved_states()

    def load_states(self, saved_states):
        self.model.load_state_dict(saved_states["model"])
        self.optimiser.load_state_dict(saved_states["optimiser"])
        if self.scaler is not None:
            self.scaler.load_state_dict(saved_states["scaler"])

    def recover_states(self):
        self.load_states(self.saved_states)

    @property
    def _schedule_multiplier(self):
        return self.learning_rate_schedule(
            min(self.step_count, self.learning_rate_schedule.total_steps)
        )

    def set_all_lr(self, lr):
        for group in self.optimiser.param_groups:
            group["lr"] = lr

    def start_range_test(self):
        self.save_states()
        self.optimiser.load_state_dict(self.original_states["optimiser"])
        if self.scaler is not None:
            self.scaler.load_state_dict(self.original_states["scaler"])
        self.set_all_lr(1e-7)

    def scale_all_lr(self, scaling_factor):
        self.set_all_lr(self.lr * scaling_factor)

    def end_range_test(self):
        self.recover_states()
        self.update_learning_rates()

    def _smoothed_range_test(self, range_test_results):
        range_test_results = sorted(range_test_results, key=lambda x: x[0])
        learning_rates = [t[0] for t in range_test_results]
        losses = [t[1] for t in self.range_test_results]
        losses = losses[:-1] + [10 * max(losses)]
        return list(zip(learning_rates, losses, strict=True))

    def _plot_range_test(self, range_test_results):
        """
        Returns a tuple with x values (learning rates) and y values (losses)
        which can then be passed to e.g. pyplot. We recommend presenting
        the plot with a logarithmic x axis.
        """
        range_test_results = sorted(range_test_results, key=lambda x: x[0])
        learning_rates = [t[0] for t in range_test_results]
        losses = [t[1] for t in range_test_results]
        return learning_rates, losses

    def _apply_range_test_result(self):
        """
        Choose a max learning rate from the smoothed range test results and apply it to all parameter groups.
        """
        range_test_results = self._smoothed_range_test(self.range_test_results)
        minimum = min(range_test_results, key=lambda x: x[1])
        points_left_of_min = [r for r in range_test_results if r[0] < minimum[0]]
        max_left_of_min = max(points_left_of_min, key=lambda x: x[1])
        difference = max_left_of_min[1] - minimum[1]
        max_lr = None
        for p in sorted(points_left_of_min, key=lambda x: x[0]):
            if (max_lr is None) and (p[1] < minimum[1] + 0.2 * difference):
                max_lr = p[0]
            else:
                continue
        self.set_all_lr(max_lr)
        self.original_param_groups = copy.deepcopy(self.optimiser.param_groups)
        print("High LR", max_lr)

    def update_learning_rates(self):
        if not self.finished:
            for original, current in zip(
                self.original_param_groups, self.optimiser.param_groups, strict=True
            ):
                base_lr = original["lr"]
                min_lr = base_lr * self.cool_point_multiplier
                current_lr = min_lr + (base_lr - min_lr) * self._schedule_multiplier
                current["lr"] = current_lr

    def _append_to_range_test(self, loss_item: float):

        lr = self.lr

        self.range_test_results.append((lr, loss_item))

        if math.isnan(loss_item) or (lr >= 1.0):
            self._apply_range_test_result()
            self.end_range_test()
        else:
            # Continue range test, step up learning rate
            self.scale_all_lr(1.05)

    def step(self, loss_item: Optional[float] = None):
        """
        This function manages the process of
        * Doing an initial range test
        * Training for one microcycle using the learning rates from the
          initial range test ("burn in")
        * Doing a second range test to set the learning rate schedule for
          the rest of training
        * Updating learning rates during training according to the macrocycle
        """
        if self.in_range_test:  # True at init unless self.range_test = False
            if not isinstance(loss_item, float):
                raise ValueError(
                    "When using range test functionality, "
                    "`step()` expects a loss item."
                )
            self._append_to_range_test(loss_item)
        elif self.trained and not self.finished:
            self.step_count += 1
            self.update_learning_rates()
        else:
            pass
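A sketch of a training loop driven by `PASS`, combining the `Cycle` and `get_optimiser` helpers shown earlier; the model, the `train_loader` and all sizes are placeholders rather than part of the package:

```
import torch
from gradboard.cycles import Cycle
from gradboard.optimiser import get_optimiser
from gradboard.scheduler import PASS

model = torch.nn.Linear(32, 10)           # stand-in for a real model
optimiser = get_optimiser(model)
schedule = Cycle("half_cosine", training_examples=50_000, epochs=10, batch_size=128)
scheduler = PASS(schedule, model, optimiser, range_test=True)

for epoch in range(10):
    for x, y in train_loader:             # a torch DataLoader, not defined here
        optimiser.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(x), y)
        loss.backward()
        optimiser.step()
        # During the range test the scheduler raises the learning rate by 5% per
        # step and records the loss; once the test ends it restores the saved
        # model/optimiser state and then follows the half-cosine schedule.
        scheduler.step(loss.item())
```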
gradboard-5.1.0.dist-info/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 nicholasbailey87

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
gradboard-5.1.0.dist-info/METADATA
ADDED
@@ -0,0 +1,63 @@
Metadata-Version: 2.3
Name: gradboard
Version: 5.1.0
Summary: Easily snowboard down gnarly loss gradients
License: MIT
Author: Nicholas Bailey
Requires-Python: >=3.8
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Description-Content-Type: text/markdown

# gradboard

Easily snowboard down gnarly loss gradients

## Getting started

You can install gradboard with

```
pip install gradboard
```

PyTorch is a peer dependency of `gradboard`, which means
* You will need to make sure you have PyTorch installed in order to use `gradboard`
* PyTorch will **not** be installed automatically when you install `gradboard`

We take this approach because PyTorch versioning is environment-specific and
we don't know where you will want to use `gradboard`. If we automatically install
PyTorch for you, there's a good chance we would get it wrong!

Therefore, please also make sure you install PyTorch.

## Usage examples

### Decent model training outcomes without tuning hyperparameters

`gradboard` includes

* An implementation of AdamS as proposed in Xie et al. (2023) "On the Overlooked
  Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm
  Perspective" (https://openreview.net/pdf?id=vnGcubtzR1), which in practice
  makes model training more robust to the weight decay setting.
* Utilities for implementing popular learning rate schedules
* An implementation of an automatic max/min learning rate finder based on Smith
  (2017) "Cyclical Learning Rates for Training Neural Networks"
  (https://arxiv.org/abs/1506.01186)
* Sensible defaults

In practice this means that you can train a neural network and get decent performance
right out of the box, just by using the `PASS` (point-and-shoot scheduler), even
for unfamiliar architectures or problem domains.
gradboard-5.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
gradboard/__init__.py,sha256=57AkHusYwLCsusiVnajH5pMFKioRCj-3IjF9qpdOzE0,69
gradboard/cycles.py,sha256=XnXNzCBI3J7OmzZQ3bKItffeDVDMXhTK6qicuYow6_4,10507
gradboard/optimiser.py,sha256=ds7rk67eiOYRsSwJyAWp5nK2vhCtko03FU9vHPUSC6E,1417
gradboard/scheduler.py,sha256=u3ojuGJQ4ZVn5fZ9L49CsFTu0TS-iqijarSEkmIWIfA,7056
gradboard-5.1.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
gradboard-5.1.0.dist-info/METADATA,sha256=b0TWoWtSAJj2RTWleIHD_SquassUmdfR1rvIyXXM-Fo,2246
gradboard-5.1.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
gradboard-5.1.0.dist-info/RECORD,,