gradboard-0.1.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of gradboard might be problematic.

@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 nicholasbailey87
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,62 @@
+ Metadata-Version: 2.3
+ Name: gradboard
+ Version: 0.1.1
+ Summary: Easily snowboard down gnarly loss gradients
+ License: MIT
+ Author: Nicholas Bailey
+ Requires-Python: >=3.11
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Requires-Dist: numpy (>=2.0.2,<3.0.0)
+ Requires-Dist: scipy (>=1.15.3,<2.0.0)
+ Description-Content-Type: text/markdown
+
+ # gradboard
+ ![snowboarder](snowboarder.png "Image of a snowboarder")
+
+ Easily snowboard down gnarly loss gradients
+
+ ## Getting started
+
+ You can install gradboard with
+
+ ```
+ pip install gradboard
+ ```
+
+ PyTorch is a peer dependency of `gradboard`, which means:
+
+ * You will need PyTorch installed in order to use `gradboard`
+ * PyTorch will **not** be installed automatically when you install `gradboard`
+
+ We take this approach because PyTorch versioning is environment-specific and
+ we don't know where you will want to use `gradboard`. If we automatically installed
+ PyTorch for you, there's a good chance we would get it wrong!
+
+ Therefore, please make sure you also install PyTorch.
+
+ ## Usage examples
+
+ ### Decent model training outcomes without tuning hyperparameters
+
+ `gradboard` includes
+
+ * An implementation of AdamS as proposed in Xie et al. (2023) "On the Overlooked
+   Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm
+   Perspective" (https://openreview.net/pdf?id=vnGcubtzR1), which in practice
+   makes model training more robust to the weight decay setting.
+ * Utilities for implementing popular learning rate schedules
+ * An implementation of an automatic max/min learning rate finder based on Smith
+   (2017) "Cyclical Learning Rates for Training Neural Networks"
+   (https://arxiv.org/abs/1506.01186)
+ * Sensible defaults
+
+ In practice this means that you can train a neural network and get decent performance
+ right out of the box, just by using `PASS` (the point-and-shoot scheduler), even
+ for unfamiliar architectures or problem domains.
@@ -0,0 +1,45 @@
+ # gradboard
+ ![snowboarder](snowboarder.png "Image of a snowboarder")
+
+ Easily snowboard down gnarly loss gradients
+
+ ## Getting started
+
+ You can install gradboard with
+
+ ```
+ pip install gradboard
+ ```
+
+ PyTorch is a peer dependency of `gradboard`, which means:
+
+ * You will need PyTorch installed in order to use `gradboard`
+ * PyTorch will **not** be installed automatically when you install `gradboard`
+
+ We take this approach because PyTorch versioning is environment-specific and
+ we don't know where you will want to use `gradboard`. If we automatically installed
+ PyTorch for you, there's a good chance we would get it wrong!
+
+ Therefore, please make sure you also install PyTorch.
+
+ ## Usage examples
+
+ ### Decent model training outcomes without tuning hyperparameters
+
+ `gradboard` includes
+
+ * An implementation of AdamS as proposed in Xie et al. (2023) "On the Overlooked
+   Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm
+   Perspective" (https://openreview.net/pdf?id=vnGcubtzR1), which in practice
+   makes model training more robust to the weight decay setting.
+ * Utilities for implementing popular learning rate schedules
+ * An implementation of an automatic max/min learning rate finder based on Smith
+   (2017) "Cyclical Learning Rates for Training Neural Networks"
+   (https://arxiv.org/abs/1506.01186)
+ * Sensible defaults
+
+ In practice this means that you can train a neural network and get decent performance
+ right out of the box, just by using `PASS` (the point-and-shoot scheduler), even
+ for unfamiliar architectures or problem domains.
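
To make the quick start concrete, here is a minimal sketch of an end-to-end training loop built from the modules in this release. Only the class names and the relative import `.cycles` are visible in this diff, so the module paths `gradboard.optimisers` and `gradboard.schedulers`, along with the toy model and data, are assumptions for illustration rather than the package's documented API.

```python
import torch
from torch import nn

from gradboard.cycles import Cycle                      # learning rate schedules
from gradboard.optimisers import AdamS, get_optimiser   # assumed module name
from gradboard.schedulers import PASS                   # assumed module name

# Toy regression problem standing in for a real model and dataset
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))
loss_fn = nn.MSELoss()

# One descending half-cosine over 10 epochs of 2,000 examples, batch size 50
schedule = Cycle("half_cosine", training_examples=2_000, epochs=10, batch_size=50)

# AdamS with weight decay switched off for biases, norms, embeddings, etc.
optimiser = get_optimiser(model, optimiser=AdamS)

# PASS starts a learning rate range test immediately (range_test=True is the default)
scheduler = PASS(schedule, model, optimiser)

while not scheduler.finished:
    x = torch.randn(50, 10)
    y = x.sum(dim=1, keepdim=True)
    optimiser.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimiser.step()
    # During the range test this records (lr, loss) and raises the lr by 5%;
    # once the loss diverges to NaN it restores the saved weights, picks
    # max_lr and cool_point, and then follows the half-cosine schedule
    scheduler.step(loss.item())
```
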
@@ -0,0 +1,277 @@
+ """
+ Utilities for generating a range of learning rate schedules.
+ """
+
+ import math
+ from typing import Optional, List, Union, Callable
+
+
+ def ascent(step: int, total_steps: int) -> float:
+     """
+     Get a sequence of numbers evenly spaced between 0 and 1 so that the first
+     number is 0 and the last is approximately 1 and there are `total_steps`
+     numbers in the sequence.
+     """
+     return round(step / (total_steps - 0.999), 8)
+
+
+ def triangle(step: int, total_steps: int) -> float:
+     """
+     Get a triangular sequence of numbers between 0 and 1, going up in half of
+     `total_steps` and coming down in the other half, peaking at ~1.
+     """
+     half = int(math.ceil(total_steps / 2))
+     if step < half:
+         return 2 * ascent(step, total_steps)
+     else:
+         return 2 - 2 * ascent(step, total_steps)
+
+
+ def cosine(step: int, total_steps: int) -> float:
+     """
+     Get a sequence of numbers between 0 and 1 in the shape of a cosine wave with
+     wavelength `total_steps`.
+     """
+     assert total_steps != 0
+     angle = (step / (total_steps - 0.999)) * (2 * math.pi)
+     return round((math.cos(angle) + 1) / 2, 8)
+
+
+ def half_cosine(step: int, total_steps: int) -> float:
+     """
+     Get a sequence of numbers between 0 and 1 in the shape of the descending
+     half of a cosine wave with wavelength 2*`total_steps`.
+     """
+     return cosine(step, (total_steps * 2) - 1)
+
+
+ def cycloid(step: int, total_steps: int) -> float:
+     """
+     Get a sequence of numbers between 0 and 1 in the shape of a cycloid with
+     circle diameter 1.0 and `total_steps/(2*math.pi)` steps per cycle.
+     """
+     x = step * (math.pi / (total_steps - 1))
+
+     def fx(t):
+         return 0.5 * (t - math.sin(t)) - x
+
+     def fx_prime(t):
+         return 0.5 - 0.5 * math.cos(t)
+
+     def fy_prime(t):
+         return 0.5 - 0.5 * -math.sin(t)
+
+     angle_estimate = 0.5 * x
+
+     # XXX: 200 iterations is too many! Use a more efficient root finding algorithm
+     for _ in range(200):
+         if abs(fx_prime(angle_estimate)) > 0.1:
+             update = fx(angle_estimate) / fx_prime(angle_estimate)
+         else:
+             update = fx(angle_estimate) / fy_prime(angle_estimate)
+         angle_estimate = angle_estimate - update
+
+     return 0.5 * (1 - math.cos(angle_estimate))
+
+
+ def half_cycloid(step: int, total_steps: int) -> float:
+     """
+     Get the descending half of a cycloid arch: a sequence of numbers going
+     from ~1 down to 0 over `total_steps` steps.
+     """
+     return cycloid(total_steps + step, 2 * total_steps)
+
+
+ class Cycle:
+     def __init__(
+         self,
+         generating_function: Union[str, Callable],
+         training_examples,
+         epochs,
+         batch_size,
+         t_0: Optional[int] = None,
+         t_mult: float = 1.0,
+         t_scale: float = 1.0,
+         low=0.0,
+         high=1.0,
+         reflect=False,
+     ):
+         self.training_examples = training_examples
+         self.epochs = epochs
+         self.batch_size = batch_size
+         self.total_steps = int(
+             epochs * (math.floor(training_examples / batch_size) + 1)
+         )
+
+         self.t_0 = (
+             t_0 * (training_examples / batch_size)
+             if t_0 is not None
+             else self.total_steps
+         )
+         self.t_mult = t_mult
+         self.t_scale = t_scale
+
+         self.low = low
+         self.high = high
+
+         self.reflect = reflect
+
+         if callable(generating_function):
+             self._generating_function = generating_function
+         elif generating_function == "ascent":
+             self._generating_function = ascent
+         elif generating_function == "triangle":
+             self._generating_function = triangle
+         elif generating_function == "cosine":
+             self._generating_function = cosine
+         elif generating_function == "half_cosine":
+             self._generating_function = half_cosine
+         elif generating_function == "half_cycloid":
+             self._generating_function = half_cycloid
+         else:
+             raise NotImplementedError(
+                 "`generating_function` must be a callable object or one of "
+                 '"ascent", "triangle", "cosine", "half_cosine" or "half_cycloid"'
+             )
+
+     def _get_window(self, step):
+         # Find which restart window `step` falls into, and the step index
+         # local to that window
+         windows = self._windows()
+         cumulative = [
+             sum([w[0] for w in windows][: i + 1]) for i in range(len(windows))
+         ]
+         position = None
+         local_step = None
+         for i, c in enumerate(cumulative):
+             if c > step:
+                 position = i
+                 local_step = step if i == 0 else step - cumulative[i - 1]
+                 break
+         window_width, window_height = windows[position]
+         return window_width, local_step, window_height
+
+     def _generate(self, step) -> float:
+         total_steps, step, scale = self._get_window(step)
+         y = self._generating_function(step, total_steps)
+         y = y * scale
+         y = 1 - y if self.reflect else y
+         return y * (self.high - self.low) + self.low
+
+     def __call__(self, n):
+         return self._generate(n)
+
+     def __len__(self):
+         return self.total_steps
+
+     def _windows(self):
+         assert self.t_mult > 0
+
+         # Get tile widths
+         widths = [self.t_0]
+         while True:
+             next_item = widths[-1] * self.t_mult
+             if sum(widths) + next_item <= self.total_steps:
+                 widths.append(next_item)
+             else:
+                 break
+         for i in range(1, len(widths)):
+             widths[i] = int(widths[i] * (self.total_steps / sum(widths)))
+         widths[-1] += self.total_steps - sum(widths)
+
+         # Get tile heights
+         heights = [1.0 * self.t_scale**i for i in range(len(widths))]
+
+         return list(zip(widths, heights, strict=True))
+
+     def stats(self) -> dict:
+         """
+         Returns a dict with the area under the schedule (as a fraction of the
+         area of a schedule held constant at its maximum), the fraction of
+         ascent and descent steps, and the average up/down gradients of a
+         learning rate schedule.
+         """
+         total_area = 0
+         max_area = 0
+         ascent_steps = 0
+         descent_steps = 0
+         total_up_gradient = 0
+         total_down_gradient = 0
+         total_gradient = 0
+         previous_lr = None
+         for s in range(self.total_steps):
+             height = self(s)
+             total_area += height
+             max_area += 1
+             if previous_lr is None:
+                 pass
+             elif previous_lr > height:
+                 descent_steps += 1
+                 total_down_gradient += height - previous_lr
+                 total_gradient += height - previous_lr
+             elif previous_lr < height:
+                 ascent_steps += 1
+                 total_up_gradient += height - previous_lr
+                 total_gradient += height - previous_lr
+             else:
+                 total_gradient += height
+             previous_lr = height
+         return {
+             "area": total_area / max_area,
+             "pc_ascent": round(ascent_steps / self.total_steps, 3),
+             "pc_descent": round(descent_steps / self.total_steps, 3),
+             "avg_up_gradient": round(
+                 total_up_gradient / ascent_steps if ascent_steps > 0 else 0.0, 3
+             ),
+             "avg_down_gradient": round(
+                 total_down_gradient / descent_steps if descent_steps > 0 else 0.0, 3
+             ),
+             "avg_gradient": round(-(self.high - self.low) / self.total_steps, 3),
+         }
+
+
+ class CycleProduct(Cycle):
+     def __init__(self, cycles: List[Cycle], reflect=False):
+         main_training_examples = cycles[0].training_examples
+         main_batch_size = cycles[0].batch_size
+
+         assert all(c.training_examples == main_training_examples for c in cycles)
+         assert all(c.batch_size == main_batch_size for c in cycles)
+
+         self.cycles = cycles
+         self.reflect = reflect
+
+         def generating_function(step: int, total_steps: int) -> float:
+             output = self.cycles[0](step)
+             for c in self.cycles[1:]:
+                 output *= c(step % c.total_steps)
+             return output
+
+         super().__init__(
+             generating_function=generating_function,
+             training_examples=self.cycles[0].training_examples,
+             epochs=self.cycles[0].epochs,
+             batch_size=self.cycles[0].batch_size,
+             reflect=reflect,
+         )
+
+
+ class CycleSequence:
+     def __init__(self, cycles: List[Cycle]):
+         self.total_steps = sum([c.total_steps for c in cycles])
+         self.cycles = cycles
+
+     def _generate(self, step):
+         cycle, step = self._get_cycle_and_step(step)
+         return self.cycles[cycle](step)
+
+     def _get_cycle_and_step(self, step):
+         cycle_lengths = [c.total_steps for c in self.cycles]
+         cumulative = [sum(cycle_lengths[: i + 1]) for i in range(len(cycle_lengths))]
+         cycle = None
+         local_step = None
+         for i, c in enumerate(cumulative):
+             if c > step:
+                 cycle = i
+                 local_step = step if i == 0 else step - cumulative[i - 1]
+                 break
+         return cycle, local_step
+
+     def __call__(self, step):
+         return self._generate(step)
+
+     def __len__(self):
+         return self.total_steps
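
A small sketch of how the pieces in this module compose. The import path `gradboard.cycles` is inferred from the relative import `.cycles` used by the scheduler module later in this diff; the dataset sizes are arbitrary placeholders.

```python
from gradboard.cycles import Cycle, CycleSequence  # assumed import path

# Linear warmup for 1 epoch, then a descending half-cosine for 9 epochs,
# over a dataset of 10,000 examples with batch size 100
warmup = Cycle("ascent", training_examples=10_000, epochs=1, batch_size=100)
decay = Cycle("half_cosine", training_examples=10_000, epochs=9, batch_size=100)
schedule = CycleSequence([warmup, decay])

# Schedules map a step index to a value between `low` and `high`
multipliers = [schedule(step) for step in range(len(schedule))]
print(multipliers[0], multipliers[100], multipliers[-1])  # ~0, ~1, ~0

# Cycle.stats() summarises a single cycle: area under the curve, the fraction
# of ascending/descending steps, and average up/down gradients
print(decay.stats())
```
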
@@ -0,0 +1,163 @@
+ import math
+ import torch
+ from torch.optim.optimizer import Optimizer
+ from torch.optim import AdamW
+
+
+ class AdamS(Optimizer):
+     r"""
+     Implements Adam with stable weight decay (AdamS) as proposed in
+     "On the Overlooked Pitfalls of Weight Decay and How to Mitigate Them:
+     A Gradient-Norm Perspective" (https://openreview.net/pdf?id=vnGcubtzR1).
+
+     This implementation is adapted from the repo
+     http://github.com/zeke-xie/stable-weight-decay-regularization/
+     blob/master/swd_optim/adams.py (MIT license ca. July 2025)
+
+     Arguments:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups
+         lr (float, optional): learning rate (default: 1e-3)
+         betas (Tuple[float, float], optional): coefficients used for computing
+             running averages of gradient and its square (default: (0.9, 0.999))
+         eps (float, optional): term added to the denominator to improve
+             numerical stability (default: 1e-8)
+         weight_decay (float, optional): weight decay coefficient (default: 1e-4)
+     """
+
+     def __init__(
+         self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-4
+     ):
+         if not 0.0 <= lr:
+             raise ValueError("Invalid learning rate: {}".format(lr))
+         if not 0.0 <= eps:
+             raise ValueError("Invalid epsilon value: {}".format(eps))
+         if not 0.0 <= betas[0] < 1.0:
+             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+         if not 0.0 <= betas[1] < 1.0:
+             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+         if not 0.0 <= weight_decay:
+             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+         defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}
+         super().__init__(params, defaults)
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         """Performs a single optimization step.
+
+         Arguments:
+             closure (callable, optional): A closure that reevaluates the model
+                 and returns the loss.
+         """
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         param_size = 0
+         exp_avg_sq_hat_sum = 0.0
+
+         for group in self.param_groups:
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+                 param_size += p.numel()
+
+                 # Perform optimization step
+                 grad = p.grad
+                 if grad.is_sparse:
+                     raise RuntimeError("AdamS does not support sparse gradients")
+
+                 state = self.state[p]
+
+                 # State initialization
+                 if len(state) == 0:
+                     state["step"] = 0
+                     # Exponential moving average of gradient values
+                     state["exp_avg"] = torch.zeros_like(
+                         p, memory_format=torch.preserve_format
+                     )
+                     # Exponential moving average of squared gradient values
+                     state["exp_avg_sq"] = torch.zeros_like(
+                         p, memory_format=torch.preserve_format
+                     )
+
+                 beta1, beta2 = group["betas"]
+                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+
+                 state["step"] += 1
+                 bias_correction2 = 1 - beta2 ** state["step"]
+
+                 # Decay the first and second moment running average coefficient
+                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                 exp_avg_sq_hat = exp_avg_sq / bias_correction2
+
+                 exp_avg_sq_hat_sum += exp_avg_sq_hat.sum()
+
+         # Calculate the sqrt of the mean of all elements in exp_avg_sq_hat
+         exp_avg_mean_sqrt = math.sqrt(exp_avg_sq_hat_sum / param_size)
+
+         for group in self.param_groups:
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+
+                 state = self.state[p]
+
+                 # Perform stable weight decay
+                 if group["weight_decay"] != 0:
+                     p.data.mul_(
+                         1 - group["weight_decay"] * group["lr"] / exp_avg_mean_sqrt
+                     )
+
+                 beta1, beta2 = group["betas"]
+                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                 bias_correction1 = 1 - beta1 ** state["step"]
+                 bias_correction2 = 1 - beta2 ** state["step"]
+
+                 exp_avg_sq_hat = exp_avg_sq / bias_correction2
+
+                 denom = exp_avg_sq_hat.sqrt().add(group["eps"])
+
+                 step_size = group["lr"] / bias_correction1
+                 p.addcdiv_(exp_avg, denom, value=-step_size)
+
+                 # Make sure internal tensors are still leaf tensors
+                 # state['exp_avg'] = state['exp_avg'].detach()
+                 # state['exp_avg_sq'] = state['exp_avg_sq'].detach()
+
+         return loss
+
+
+ def get_optimiser(model, optimiser=AdamW, lr=7e-4, weight_decay=5e-2):
+     """
+     Build an optimiser with weight decay disabled for parameters (biases,
+     norms, embeddings, etc.) that usually should not be decayed.
+
+     Defaults are from one of the presets from the accompanying repo to Hassani
+     et al. (2023) "Escaping the Big Data Paradigm with Compact Transformers",
+     https://github.com/SHI-Labs/Compact-Transformers/blob/main/configs/
+     pretrained/cct_7-3x1_cifar100_1500epochs.yml
+     """
+     weight_decay_exclude = []
+     for keyword in [
+         "bias",
+         "norm",
+         "embedding",
+         "swiglu_beta",
+         "sigma",
+         "scale",
+         "input_query",
+         "reentrant_query",
+     ]:
+         weight_decay_exclude += [
+             p for name, p in model.named_parameters() if keyword in name.lower()
+         ]
+     weight_decay_exclude = set(weight_decay_exclude)
+     weight_decay_include = set(model.parameters()) - weight_decay_exclude
+     return optimiser(
+         [
+             {"params": list(weight_decay_include)},
+             {"params": list(weight_decay_exclude), "weight_decay": 0.0},
+         ],
+         weight_decay=weight_decay,
+         lr=lr,
+     )
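
As a sketch of how `get_optimiser` groups parameters, here is a hypothetical example (the module path `gradboard.optimisers` is an assumption, and the tiny model is a placeholder). Parameters whose names contain keywords such as `bias` or `norm` end up in a group with weight decay disabled:

```python
import torch
from torch import nn

from gradboard.optimisers import AdamS, get_optimiser  # assumed module path


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_in = nn.Linear(16, 32)
        self.norm = nn.LayerNorm(32)
        self.linear_out = nn.Linear(32, 1)

    def forward(self, x):
        return self.linear_out(self.norm(torch.relu(self.linear_in(x))))


model = TinyNet()
optimiser = get_optimiser(model, optimiser=AdamS, lr=7e-4, weight_decay=5e-2)

# Two parameter groups: weights with weight decay 5e-2, and biases plus the
# LayerNorm parameters with weight decay 0.0
for group in optimiser.param_groups:
    print(len(group["params"]), group["weight_decay"])

# A single optimisation step works like any other torch optimiser
loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
optimiser.step()
```
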
@@ -0,0 +1,212 @@
+ """
+ Based on Smith (2017) https://arxiv.org/abs/1506.01186
+ """
+
+ from typing import Optional
+ import copy
+ import math
+
+ from scipy.ndimage import gaussian_filter1d
+
+ from torch.amp import GradScaler
+
+ from .cycles import Cycle
+
+
+ class PASS:
+     """
+     A self-configuring learning rate scheduler
+     """
+
+     def __init__(
+         self,
+         learning_rate_schedule: Cycle,
+         model,
+         optimiser,
+         scaler: Optional[GradScaler] = None,
+         range_test: bool = True,
+         max_lr: Optional[float] = None,
+         cool_point: Optional[float] = None,
+     ):
+         # Either run a range test, or supply both max_lr and cool_point by hand
+         assert (max_lr is not None) == (cool_point is not None)
+         assert ((max_lr is not None) and (cool_point is not None)) != range_test
+
+         self.model = model
+         self.optimiser = optimiser
+         self.scaler = scaler
+
+         self.learning_rate_schedule = learning_rate_schedule
+
+         self.range_test = range_test
+
+         self.max_lr = max_lr
+         self.cool_point = cool_point
+
+         self.original_states = self._saved_states()
+
+         self.range_test_results = []
+
+         self.step_count = 0
+
+         if range_test:
+             self.start_range_test()  # sets LR to 1E-7
+
+     @property
+     def lr(self):
+         """
+         Return the first lr from self.optimiser.param_groups
+         (we assume they are all the same!)
+         """
+         for group in self.optimiser.param_groups:
+             return group["lr"]
+
+     @property
+     def in_range_test(self):
+         if not self.range_test:
+             return False
+         elif (len(self.range_test_results) == 0) or (
+             not math.isnan(self.range_test_results[-1][1])
+         ):
+             return True
+         else:
+             return False
+
+     @property
+     def trained(self):
+         if not self.range_test:
+             return True
+         elif math.isnan(self.range_test_results[-1][1]):
+             return True
+         else:
+             return False
+
+     @property
+     def finished(self):
+         return self.step_count >= len(self.learning_rate_schedule) - 1
+
+     def _saved_states(self):
+         saved_states = {
+             "model": copy.deepcopy(self.model.state_dict()),
+             "optimiser": copy.deepcopy(self.optimiser.state_dict()),
+         }
+         if self.scaler is not None:
+             saved_states["scaler"] = copy.deepcopy(self.scaler.state_dict())
+         return saved_states
+
+     def save_states(self):
+         self.saved_states = self._saved_states()
+
+     def load_states(self, saved_states):
+         self.model.load_state_dict(saved_states["model"])
+         self.optimiser.load_state_dict(saved_states["optimiser"])
+         if self.scaler is not None:
+             self.scaler.load_state_dict(saved_states["scaler"])
+
+     def recover_states(self):
+         self.load_states(self.saved_states)
+
+     @property
+     def _schedule_lr(self):
+         return (
+             self.learning_rate_schedule(
+                 min(self.step_count, self.learning_rate_schedule.total_steps)
+             )
+             * (self.max_lr - self.cool_point)
+             + self.cool_point
+         )
+
+     def set_lr(self, lr):
+         for group in self.optimiser.param_groups:
+             group["lr"] = lr
+
+     def scale_lr(self, scaling_factor):
+         self.set_lr(self.lr * scaling_factor)
+
+     def start_range_test(self):
+         self.save_states()
+         self.optimiser.load_state_dict(self.original_states["optimiser"])
+         if self.scaler is not None:
+             self.scaler.load_state_dict(self.original_states["scaler"])
+         self.set_lr(1e-7)
+
+     def end_range_test(self):
+         self.recover_states()
+         self.update_learning_rates()
+
+     def _smoothed_range_test(self, range_test_results):
+         range_test_results = sorted(range_test_results, key=lambda x: x[0])
+         learning_rates = [t[0] for t in range_test_results]
+         losses = [t[1] for t in range_test_results]
+         # The final loss (at the highest learning rate) is NaN, which is what
+         # ended the range test, so replace it with a large finite value
+         losses = losses[:-1] + [10 * max(losses[:-1])]
+         smoothed_losses = gaussian_filter1d(losses, 3)
+         return list(zip(learning_rates, smoothed_losses, strict=True))
+
+     def _plot_range_test(self, range_test_results):
+         """
+         Returns a tuple with x values (learning rates) and y values (losses)
+         which can then be passed to e.g. pyplot. We recommend presenting
+         the plot with a logarithmic x axis.
+         """
+         range_test_results = sorted(range_test_results, key=lambda x: x[0])
+         learning_rates = [t[0] for t in range_test_results]
+         losses = [t[1] for t in range_test_results]
+         return learning_rates, losses
+
+     def _apply_range_test_result(self):
+         """
+         Pick max_lr and cool_point from the smoothed range test curve: find the
+         learning rate at which the smoothed loss has fallen halfway from its
+         pre-minimum peak to its minimum, then set max_lr to three times that
+         learning rate and cool_point to a third of it.
+         """
+         range_test_results = self._smoothed_range_test(self.range_test_results)
+
+         minimum = min(range_test_results, key=lambda x: x[1])
+         points_left_of_min = [p for p in range_test_results if p[0] < minimum[0]]
+         highest_point_left_of_min = max(points_left_of_min, key=lambda x: x[1])
+         halfway = (highest_point_left_of_min[1] + minimum[1]) / 2
+         for r in range_test_results:
+             if r[1] < halfway:
+                 self.max_lr = r[0] * 3
+                 self.cool_point = r[0] / 3
+                 print("High LR", self.max_lr)
+                 print("Cool point", self.cool_point)
+                 break
+
+     def update_learning_rates(self):
+         if self.finished:
+             pass
+         else:
+             self.set_lr(self._schedule_lr)
+
+     def _append_to_range_test(self, loss_item: float):
+         self.range_test_results.append((self.lr, loss_item))
+
+         if math.isnan(loss_item):
+             self._apply_range_test_result()
+             self.end_range_test()
+         else:
+             # Continue range test, step up learning rate
+             self.scale_lr(1.05)
+
+     def step(self, loss_item: float):
+         """
+         This function manages the process of
+         * Doing an initial range test
+         * Training for one microcycle using the learning rates from the
+           initial range test ("burn in")
+         * Doing a second range test to set the learning rate schedule for
+           the rest of training
+         * Updating learning rates during training according to the macrocycle
+         """
+         if self.in_range_test:  # True at init unless self.range_test = False
+             assert self.step_count == 0  # No weight updates yet
+             self._append_to_range_test(loss_item)
+         elif self.trained and not self.finished:
+             self.step_count += 1
+             self.update_learning_rates()
+         else:
+             pass
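
If you already know a good learning rate range, a range test is unnecessary. The sketch below skips it and drives the schedule directly between `cool_point` and `max_lr`; the module paths `gradboard.schedulers` and `gradboard.optimisers`, the toy model and the data are assumptions for illustration.

```python
import torch
from torch import nn

from gradboard.cycles import Cycle
from gradboard.optimisers import AdamS, get_optimiser  # assumed module path
from gradboard.schedulers import PASS                  # assumed module path

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
optimiser = get_optimiser(model, optimiser=AdamS)
schedule = Cycle("triangle", training_examples=1_000, epochs=5, batch_size=50)

# No range test: supply max_lr and cool_point yourself
scheduler = PASS(
    schedule, model, optimiser, range_test=False, max_lr=3e-3, cool_point=3e-4
)

loss_fn = nn.MSELoss()
while not scheduler.finished:
    x = torch.randn(50, 10)
    y = x.sum(dim=1, keepdim=True)
    optimiser.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimiser.step()
    scheduler.step(loss.item())  # rescales the schedule into [cool_point, max_lr]
```
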
@@ -0,0 +1,39 @@
+ [project]
+ name = "gradboard"
+ version = "0.1.1"
+ description = "Easily snowboard down gnarly loss gradients"
+ authors = [
+     {name = "Nicholas Bailey"}
+ ]
+ license = {text = "MIT"}
+ readme = "README.md"
+ requires-python = ">=3.11"
+ dependencies = [
+     "numpy (>=2.0.2,<3.0.0)",
+     "scipy (>=1.15.3,<2.0.0)"
+ ]
+
+ [tool.poetry]
+
+ [tool.poetry.group.dev.dependencies]
+ black = "^25.1.0"
+ flake8 = "7.3.0"
+ pytest = "^8.4.1"
+ pytest-cov = "^6.2.1"
+
+ [tool.black]
+ line-length = 88
+ target-version = ['py312']
+ include = '\.pyi?$'
+ extend-exclude = '''
+ # A regex preceded with ^/ will apply only to files and directories
+ # in the root of the project.
+ (
+   ^/foo.py     # exclude a file named foo.py in the root of the project
+   | .*_pb2.py  # exclude autogenerated Protocol Buffer files anywhere in the project
+ )
+ '''
+
+ [build-system]
+ requires = ["poetry-core>=2.0.0,<3.0.0"]
+ build-backend = "poetry.core.masonry.api"