gradboard 2.5.0.tar.gz → 3.0.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: gradboard
-Version: 2.5.0
+Version: 3.0.0
 Summary: Easily snowboard down gnarly loss gradients
 License: MIT
 Author: Nicholas Bailey
@@ -24,13 +24,12 @@ class PASS:
         model,
         optimiser,
         scaler: Optional[GradScaler] = None,
-        range_test: bool = True,
+        range_test: bool = False,
         max_lr: float = None,
-        cool_point: float = None,
+        cool_point_multiplier: float = 1 / 60,
     ):
         if not range_test:
             assert max_lr is not None
-            assert cool_point is not None
 
         self.model = model
         self.optimiser = optimiser
@@ -40,8 +39,10 @@ class PASS:
 
         self.range_test = range_test
 
+        self.original_param_groups = copy.deepcopy(optimiser.param_groups)
+
         self.max_lr = max_lr
-        self.cool_point = cool_point
+        self.cool_point_multiplier = cool_point_multiplier
 
         self.original_states = self._saved_states()
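
The two hunks above change the constructor contract: range_test now defaults to False, the explicit cool_point argument is replaced by a cool_point_multiplier (default 1 / 60), and the optimiser's param_groups are snapshotted at construction time. A minimal usage sketch under the 3.0.0 signature; the import path and the toy model/optimiser are assumptions for illustration, not taken from the package:

    import torch
    from gradboard import PASS  # assumed import location of the PASS class

    model = torch.nn.Linear(10, 1)
    optimiser = torch.optim.SGD(model.parameters(), lr=0.06)

    # range_test now defaults to False, so max_lr must be supplied explicitly;
    # the old cool_point argument is gone in favour of cool_point_multiplier.
    scheduler = PASS(
        model,
        optimiser,
        max_lr=0.06,
        cool_point_multiplier=1 / 60,
    )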
 
@@ -56,7 +57,8 @@ class PASS:
     def lr(self):
         """
         Return first lr from self.optimiser.param_groups
-        (we assume they are all the same!)
+        (this is used in learning rate range tests, in which case we can
+        assume they are all the same!)
         """
         for group in self.optimiser.param_groups:
             return group["lr"]
@@ -107,28 +109,24 @@ class PASS:
         self.load_states(self.saved_states)
 
     @property
-    def _schedule_lr(self):
-        return (
-            self.learning_rate_schedule(
-                min(self.step_count, self.learning_rate_schedule.total_steps)
-            )
-            * (self.max_lr - self.cool_point)
-            + self.cool_point
+    def _schedule_multiplier(self):
+        return self.learning_rate_schedule(
+            min(self.step_count, self.learning_rate_schedule.total_steps)
         )
 
-    def set_lr(self, lr):
+    def set_all_lr(self, lr):
         for group in self.optimiser.param_groups:
             group["lr"] = lr
 
-    def scale_lr(self, scaling_factor):
-        self.set_lr(self.lr * scaling_factor)
-
     def start_range_test(self):
         self.save_states()
         self.optimiser.load_state_dict(self.original_states["optimiser"])
         if self.scaler is not None:
             self.scaler.load_state_dict(self.original_states["scaler"])
-        self.set_lr(1e-7)
+        self.set_all_lr(1e-7)
+
+    def scale_all_lr(self, scaling_factor):
+        self.set_all_lr(self.lr * scaling_factor)
 
     def end_range_test(self):
         self.recover_states()
@@ -163,21 +161,24 @@ class PASS:
         max_left_of_min = max(points_left_of_min, key=lambda x: x[1])
         difference = max_left_of_min[1] - minimum[1]
         self.max_lr = None
-        self.cool_point = None
         for p in sorted(points_left_of_min, key=lambda x: x[0]):
             if (self.max_lr is None) and (p[1] < minimum[1] + 0.2 * difference):
                 self.max_lr = p[0]
             else:
                 continue
-        self.cool_point = self.max_lr / 60
+        self.set_all_lr(self.max_lr)
+        self.original_param_groups = copy.deepcopy(self.optimiser.param_groups)
         print("High LR", self.max_lr)
-        print("Cool point", self.cool_point)
 
     def update_learning_rates(self):
-        if self.finished:
-            pass
-        else:
-            self.set_lr(self._schedule_lr)
+        if not self.finished:
+            for original, current in zip(
+                self.original_param_groups, self.optimiser.param_groups, strict=True
+            ):
+                base_lr = original["lr"]
+                min_lr = base_lr * self.cool_point_multiplier
+                current_lr = min_lr + (base_lr - min_lr) * self._schedule_multiplier
+                current["lr"] = current_lr
 
     def _append_to_range_test(self, loss_item: float):
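
The rewritten update_learning_rates no longer interpolates between a single max_lr and cool_point; it scales each parameter group relative to the lr captured in original_param_groups, with a floor of base_lr * cool_point_multiplier. A standalone sketch of that arithmetic, using made-up learning rates and schedule multipliers purely to illustrate the formula:

    # Mirrors the per-group arithmetic in the hunk above (illustrative values only).
    def scheduled_lr(base_lr: float, multiplier: float,
                     cool_point_multiplier: float = 1 / 60) -> float:
        min_lr = base_lr * cool_point_multiplier
        return min_lr + (base_lr - min_lr) * multiplier

    # With base_lr = 0.06 the floor is 0.06 / 60 = 0.001:
    print(scheduled_lr(0.06, 1.0))  # ~0.06   (multiplier 1 keeps the base lr)
    print(scheduled_lr(0.06, 0.5))  # ~0.0305 (halfway between floor and base)
    print(scheduled_lr(0.06, 0.0))  # ~0.001  (fully cooled to the floor)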
 
@@ -188,7 +189,7 @@ class PASS:
             self.end_range_test()
         else:
             # Continue range test, step up learning rate
-            self.scale_lr(1.05)
+            self.scale_all_lr(1.05)
 
     def step(self, loss_item: float):
         """
@@ -1,6 +1,6 @@
 [project]
 name = "gradboard"
-version = "2.5.0"
+version = "3.0.0"
 description = "Easily snowboard down gnarly loss gradients"
 authors = [
     {name = "Nicholas Bailey"}
3 files without changes