kostyl-toolkit 0.1.35__py3-none-any.whl → 0.1.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kostyl/ml/configs/hyperparams.py +21 -5
- kostyl/ml/configs/training_settings.py +17 -6
- kostyl/ml/dist_utils.py +52 -30
- kostyl/ml/lightning/callbacks/checkpoint.py +10 -10
- kostyl/ml/lightning/extensions/custom_module.py +0 -5
- kostyl/ml/lightning/extensions/pretrained_model.py +6 -4
- kostyl/ml/lightning/loggers/tb_logger.py +2 -2
- kostyl/ml/lightning/utils.py +58 -0
- kostyl/ml/registry_uploader.py +56 -29
- kostyl/ml/schedulers/__init__.py +13 -1
- kostyl/ml/schedulers/base.py +9 -7
- kostyl/ml/schedulers/cosine.py +53 -24
- kostyl/ml/schedulers/cosine_with_plateu.py +277 -0
- kostyl/ml/schedulers/linear.py +36 -11
- kostyl/utils/logging.py +68 -53
- {kostyl_toolkit-0.1.35.dist-info → kostyl_toolkit-0.1.37.dist-info}/METADATA +1 -1
- {kostyl_toolkit-0.1.35.dist-info → kostyl_toolkit-0.1.37.dist-info}/RECORD +18 -17
- {kostyl_toolkit-0.1.35.dist-info → kostyl_toolkit-0.1.37.dist-info}/WHEEL +1 -1
- kostyl/ml/lightning/training_utils.py +0 -241
kostyl/ml/schedulers/cosine.py
CHANGED
@@ -2,7 +2,6 @@ from typing import Any
 from typing import override
 
 import numpy as np
-import numpy.typing as npt
 import torch
 
 from .base import BaseScheduler
@@ -29,18 +28,24 @@ class _CosineSchedulerCore(BaseScheduler):
         if freeze_ratio is not None:
             if not (0 < freeze_ratio < 1):
                 raise ValueError(f"Freeze ratio must be in (0, 1), got {freeze_ratio}.")
+        pre_annealing_ratio = (warmup_ratio if warmup_ratio is not None else 0) + (
+            freeze_ratio if freeze_ratio is not None else 0
+        )
+        if pre_annealing_ratio > 1:
+            raise ValueError(
+                "The sum of warmup_ratio and freeze_ratio must <= 1, got "
+                f"{pre_annealing_ratio}."
+            )
 
         self.param_name = param_name
         self.num_iters = num_iters
         self.base_value = base_value
         self.final_value = final_value
-
         self.warmup_ratio = warmup_ratio
         self.warmup_value = warmup_value
-
         self.freeze_ratio = freeze_ratio
 
-        self.
+        self.scheduled_values: np.ndarray = np.array([], dtype=np.float64)
         self.current_value_ = self.base_value
         return
 
@@ -63,31 +68,29 @@ class _CosineSchedulerCore(BaseScheduler):
             warmup_iters = 0
             warmup_schedule = np.array([], dtype=np.float64)
 
+        # Create cosine annealing schedule
         cosine_annealing_iters = self.num_iters - warmup_iters - freeze_iters
-        if cosine_annealing_iters
-
-
-
-
-
-
-        )
+        if cosine_annealing_iters > 0:
+            iters = np.arange(cosine_annealing_iters)
+            cosine_annealing_schedule = self.final_value + 0.5 * (
+                self.base_value - self.final_value
+            ) * (1 + np.cos(np.pi * iters / len(iters)))
+        else:
+            cosine_annealing_schedule = np.array([], dtype=np.float64)
 
         # Concatenate all parts of the schedule
-        self.
-        (freeze_schedule, warmup_schedule,
+        self.scheduled_values = np.concatenate(
+            (freeze_schedule, warmup_schedule, cosine_annealing_schedule)
         )
-
-        if len(self.scheduler_values) != self.num_iters:
-            raise ValueError(
-                f"Scheduler length ({len(self.scheduler_values)}) does not match num_iters ({self.num_iters})."
-            )
+        self._verify()
         return
 
     @override
-    def
-
-
+    def _verify(self) -> None:
+        if len(self.scheduled_values) != self.num_iters:
+            raise ValueError(
+                f"Scheduler length ({len(self.scheduled_values)}) does not match num_iters ({self.num_iters})."
+            )
         return
 
     @override
@@ -95,13 +98,13 @@ class _CosineSchedulerCore(BaseScheduler):
         raise NotImplementedError
 
     def _get_value(self, it: int) -> float:
-        if len(self.
+        if len(self.scheduled_values) == 0:
             self._create_scheduler()
 
         if it >= self.num_iters:
             value: float = self.final_value
         else:
-            value: float = self.
+            value: float = self.scheduled_values[it]
         self.current_value_ = value
         return value
 
@@ -163,6 +166,21 @@ class CosineScheduler(_CosineSchedulerCore):
         self.param_group_field = param_group_field
         return
 
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.__dict__.update(state_dict)
+        self.scheduled_values = np.array([], dtype=np.float64)
+        return
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        state = {
+            k: v
+            for k, v in self.__dict__.items()
+            if k not in ["scheduled_values", "optimizer"]
+        }
+        return state
+
     @override
     def step(self, it: int) -> None:
         value = self._get_value(it)
@@ -209,3 +227,14 @@ class CosineParamScheduler(_CosineSchedulerCore):
         """
         value = self._get_value(it)
         return value
+
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.__dict__.update(state_dict)
+        self.scheduled_values = np.array([], dtype=np.float64)
+        return
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        state = {k: v for k, v in self.__dict__.items() if k != "scheduled_values"}
+        return state
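The state_dict/load_state_dict pair added above keeps checkpoints small by dropping the precomputed schedule array and rebuilding it lazily after a restore. A minimal usage sketch follows; the constructor arguments and import path are assumptions inferred from the core signature visible in this diff, not a copy of the package's documented API:

    from kostyl.ml.schedulers.cosine import CosineParamScheduler

    # Hypothetical arguments; only the save/restore behaviour is taken from the diff.
    sched = CosineParamScheduler(
        param_name="weight_decay", num_iters=1000, base_value=0.05, final_value=0.5
    )
    sched.step(10)
    state = sched.state_dict()       # "scheduled_values" is excluded from the state

    restored = CosineParamScheduler(
        param_name="weight_decay", num_iters=1000, base_value=0.05, final_value=0.5
    )
    restored.load_state_dict(state)  # schedule array reset to empty
    restored.step(11)                # schedule rebuilt on the next value lookup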
kostyl/ml/schedulers/cosine_with_plateu.py
ADDED

@@ -0,0 +1,277 @@
+from typing import Any
+from typing import override
+
+import numpy as np
+import torch
+
+from .base import BaseScheduler
+
+
+class _CosineWithPlateauSchedulerCore(BaseScheduler):
+    """Core cosine with plateau scheduler logic."""
+
+    def __init__(
+        self,
+        param_name: str,
+        num_iters: int,
+        base_value: float,
+        final_value: float,
+        plateau_ratio: float,
+        warmup_value: float | None = None,
+        warmup_ratio: float | None = None,
+        freeze_ratio: float | None = None,
+    ) -> None:
+        if warmup_ratio is not None:
+            if not (0 < warmup_ratio < 1):
+                raise ValueError(f"Warmup ratio must be in (0, 1), got {warmup_ratio}.")
+        if (warmup_value is None) != (warmup_ratio is None):
+            raise ValueError(
+                "Both warmup_ratio and warmup_value must be provided or neither."
+            )
+        if freeze_ratio is not None:
+            if not (0 < freeze_ratio < 1):
+                raise ValueError(f"Freeze ratio must be in (0, 1), got {freeze_ratio}.")
+        if not (0 < plateau_ratio < 1):
+            raise ValueError(f"Plateau ratio must be in (0, 1), got {plateau_ratio}.")
+
+        pre_annealing_ratio = (
+            plateau_ratio
+            + (warmup_ratio if warmup_ratio is not None else 0)
+            + (freeze_ratio if freeze_ratio is not None else 0)
+        )
+        if pre_annealing_ratio > 1:
+            raise ValueError(
+                "The sum of plateau_ratio, warmup_ratio, and freeze_ratio must <= 1, got "
+                f"{pre_annealing_ratio}."
+            )
+
+        self.param_name = param_name
+        self.num_iters = num_iters
+        self.base_value = base_value
+        self.final_value = final_value
+        self.cosine_annealing_ratio = 1 - pre_annealing_ratio
+        self.plateau_ratio = plateau_ratio
+        self.warmup_ratio = warmup_ratio
+        self.warmup_value = warmup_value
+        self.freeze_ratio = freeze_ratio
+
+        self.scheduled_values: np.ndarray = np.array([], dtype=np.float64)
+        self.current_value_ = self.base_value
+        return
+
+    def _create_scheduler(self) -> None:
+        # Create freeze schedule
+        if self.freeze_ratio is not None:
+            freeze_iters = int(self.num_iters * self.freeze_ratio)
+            freeze_schedule = np.zeros(freeze_iters, dtype=np.float64)
+        else:
+            freeze_iters = 0
+            freeze_schedule = np.array([], dtype=np.float64)
+
+        # Create linear warmup schedule
+        if self.warmup_ratio is not None and self.warmup_value is not None:
+            warmup_iters = int(self.num_iters * self.warmup_ratio)
+            warmup_schedule = np.linspace(
+                self.warmup_value, self.base_value, warmup_iters, dtype=np.float64
+            )
+        else:
+            warmup_iters = 0
+            warmup_schedule = np.array([], dtype=np.float64)
+
+        # Create cosine annealing schedule
+        if self.cosine_annealing_ratio > 0:
+            cosine_annealing_iters = int(self.num_iters * self.cosine_annealing_ratio)
+            iters = np.arange(cosine_annealing_iters)
+            cosine_annealing_schedule = self.final_value + 0.5 * (
+                self.base_value - self.final_value
+            ) * (1 + np.cos(np.pi * iters / len(iters)))
+        else:
+            cosine_annealing_iters = 0
+            cosine_annealing_schedule = np.array([], dtype=np.float64)
+
+        plateau_iters = (
+            self.num_iters - warmup_iters - freeze_iters - cosine_annealing_iters
+        )
+        if plateau_iters > 0:
+            plateau_schedule = np.full(plateau_iters, self.base_value, dtype=np.float64)
+        else:
+            plateau_schedule = np.array([], dtype=np.float64)
+
+        # Concatenate all parts of the schedule
+        self.scheduled_values = np.concatenate(
+            (
+                freeze_schedule,
+                warmup_schedule,
+                plateau_schedule,
+                cosine_annealing_schedule,
+            )
+        )
+        self._verify()
+        return
+
+    @override
+    def _verify(self) -> None:
+        if len(self.scheduled_values) != self.num_iters:
+            raise ValueError(
+                f"Scheduler length ({len(self.scheduled_values)}) does not match num_iters ({self.num_iters})."
+            )
+        return
+
+    @override
+    def step(self, it: int) -> None | float:
+        raise NotImplementedError
+
+    def _get_value(self, it: int) -> float:
+        if len(self.scheduled_values) == 0:
+            self._create_scheduler()
+
+        if it >= self.num_iters:
+            value: float = self.final_value
+        else:
+            value: float = self.scheduled_values[it]
+        self.current_value_ = value
+        return value
+
+    @override
+    def current_value(self) -> dict[str, float]:
+        return {self.param_name: self.current_value_}
+
+
+class CosineWithPlateuScheduler(_CosineWithPlateauSchedulerCore):
+    """
+    Applies a cosine schedule with plateau to an optimizer param-group field.
+
+    Schedule phases: freeze (0) → warmup → plateau (base_value) → cosine annealing to final_value.
+    The plateau phase maintains the base_value before cosine annealing begins.
+    """
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        param_group_field: str,
+        num_iters: int,
+        base_value: float,
+        final_value: float,
+        plateau_ratio: float,
+        warmup_value: float | None = None,
+        warmup_ratio: float | None = None,
+        freeze_ratio: float | None = None,
+        multiplier_field: str | None = None,
+        skip_if_zero: bool = False,
+        apply_if_field: str | None = None,
+        ignore_if_field: str | None = None,
+    ) -> None:
+        """
+        Configure cosine scheduling for matching optimizer groups.
+
+        Args:
+            optimizer: Optimizer whose param groups are updated in-place.
+            param_group_field: Name of the field that receives the scheduled value.
+            num_iters: Number of scheduler iterations before clamping at ``final_value``.
+            base_value: Value maintained during plateau phase and used as cosine start.
+            final_value: Value approached as iterations progress during cosine annealing.
+            plateau_ratio: Fraction of iterations to maintain ``base_value`` before cosine annealing.
+            warmup_ratio: Optional fraction of iterations to linearly ramp from ``warmup_value`` to ``base_value``.
+            warmup_value: Starting value for the warmup ramp.
+            freeze_ratio: Optional fraction of iterations to keep the value frozen at zero at the beginning.
+            multiplier_field: Optional per-group multiplier applied to the scheduled value.
+            skip_if_zero: Leave groups untouched when their target field equals zero.
+            apply_if_field: Require this flag to be present in a param group before updating.
+            ignore_if_field: Skip groups that declare this flag.
+
+        """
+        self.apply_if_field = apply_if_field
+        self.ignore_if_field = ignore_if_field
+        self.optimizer = optimizer
+        self.multiplier_field = multiplier_field
+        self.skip_if_zero = skip_if_zero
+        super().__init__(
+            param_name=param_group_field,
+            num_iters=num_iters,
+            base_value=base_value,
+            final_value=final_value,
+            plateau_ratio=plateau_ratio,
+            warmup_ratio=warmup_ratio,
+            warmup_value=warmup_value,
+            freeze_ratio=freeze_ratio,
+        )
+        self.param_group_field = param_group_field
+        return
+
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.__dict__.update(state_dict)
+        self.scheduled_values = np.array([], dtype=np.float64)
+        return
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        state = {
+            k: v
+            for k, v in self.__dict__.items()
+            if k not in ["scheduled_values", "optimizer"]
+        }
+        return state
+
+    @override
+    def step(self, it: int) -> None:
+        value = self._get_value(it)
+        for pg in self.optimizer.param_groups:
+            if self.param_group_field not in pg:
+                raise ValueError(
+                    f"Parameter group field '{self.param_group_field}' not found in optimizer parameter groups."
+                )
+
+            if (self.apply_if_field is not None) and (self.apply_if_field not in pg):
+                continue
+
+            if (self.ignore_if_field is not None) and (self.ignore_if_field in pg):
+                continue
+
+            if self.skip_if_zero and pg[self.param_group_field] == 0:
+                continue
+
+            if self.multiplier_field is not None:
+                if self.multiplier_field not in pg:
+                    multiplier = 1.0
+                else:
+                    multiplier = pg[self.multiplier_field]
+                pg[self.param_group_field] = value * multiplier
+            else:
+                pg[self.param_group_field] = value
+        return
+
+
+class CosineWithPlateauParamScheduler(_CosineWithPlateauSchedulerCore):
+    """
+    Standalone cosine scheduler with plateau for non-optimizer parameters.
+
+    Schedule phases: freeze (0) → warmup → plateau (base_value) → cosine annealing to final_value.
+    The plateau phase maintains the base_value before cosine annealing begins.
+    """
+
+    @override
+    def step(self, it: int) -> float:
+        """
+        Computes the value corresponding to the given iteration step.
+
+        Args:
+            it: The current iteration index used for value computation.
+
+        Returns:
+            The computed value for the provided iteration step as a float.
+
+        """
+        value = self._get_value(it)
+        return value
+
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.__dict__.update(state_dict)
+        self.scheduled_values = np.array([], dtype=np.float64)
+        return
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        state = {k: v for k, v in self.__dict__.items() if k != "scheduled_values"}
+        return state
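For orientation, a brief sketch of how the new optimizer-facing scheduler might be wired up; the model, optimizer, and ratio values below are placeholders, and the import path assumes the module name listed in this diff (kostyl/ml/schedulers/cosine_with_plateu.py). Phases run freeze (optional) → warmup → plateau at base_value → cosine annealing to final_value:

    import torch

    from kostyl.ml.schedulers.cosine_with_plateu import CosineWithPlateuScheduler

    model = torch.nn.Linear(16, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    lr_sched = CosineWithPlateuScheduler(
        optimizer=optimizer,
        param_group_field="lr",
        num_iters=10_000,
        base_value=3e-4,
        final_value=3e-6,
        plateau_ratio=0.4,   # hold base_value for 40% of the iterations
        warmup_value=3e-6,
        warmup_ratio=0.1,    # linear ramp over the first 10%
    )

    for it in range(10_000):
        # ... forward / backward pass elided ...
        lr_sched.step(it)    # writes the scheduled value into each matching param group
        optimizer.step()
        optimizer.zero_grad()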
kostyl/ml/schedulers/linear.py
CHANGED
@@ -21,24 +21,23 @@ class _LinearScheduleBase(BaseScheduler):
         self.start_value = start_value
         self.final_value = final_value
 
-        self.
+        self.scheduled_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
         self.current_value_ = self.start_value
         return
 
     def _create_scheduler(self) -> None:
-        self.
+        self.scheduled_values = np.linspace(
             self.start_value, self.final_value, num=self.num_iters, dtype=np.float64
         )
-
-        raise ValueError(
-            f"Scheduler length ({len(self.scheduler_values)}) does not match total_iters ({self.num_iters})."
-        )
+        self._verify()
         return
 
     @override
-    def
-
-
+    def _verify(self) -> None:
+        if len(self.scheduled_values) != self.num_iters:
+            raise ValueError(
+                f"Scheduler length ({len(self.scheduled_values)}) does not match total_iters ({self.num_iters})."
+            )
         return
 
     @override
@@ -46,13 +45,13 @@ class _LinearScheduleBase(BaseScheduler):
         raise NotImplementedError
 
     def _get_value(self, it: int) -> float:
-        if len(self.
+        if len(self.scheduled_values) == 0:
             self._create_scheduler()
 
         if it >= self.num_iters:
             value: float = self.final_value
         else:
-            value: float = self.
+            value: float = self.scheduled_values[it]
         self.current_value_ = value
         return value
 
@@ -105,6 +104,21 @@ class LinearScheduler(_LinearScheduleBase):
         self.param_group_field = param_group_field
         return
 
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.__dict__.update(state_dict)
+        self.scheduled_values = np.array([], dtype=np.float64)
+        return
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        state = {
+            k: v
+            for k, v in self.__dict__.items()
+            if k not in ["scheduled_values", "optimizer"]
+        }
+        return state
+
     @override
     def step(self, it: int) -> None:
         value = self._get_value(it)
@@ -137,6 +151,17 @@
 class LinearParamScheduler(_LinearScheduleBase):
     """LinearParamScheduler adjusts a parameter value using a linear scheduler."""
 
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.__dict__.update(state_dict)
+        self.scheduled_values = np.array([], dtype=np.float64)
+        return
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        state = {k: v for k, v in self.__dict__.items() if k != "scheduled_values"}
+        return state
+
     @override
     def step(self, it: int) -> float:
         """