kostyl-toolkit 0.1.35__py3-none-any.whl → 0.1.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,6 @@ from typing import Any
2
2
  from typing import override
3
3
 
4
4
  import numpy as np
5
- import numpy.typing as npt
6
5
  import torch
7
6
 
8
7
  from .base import BaseScheduler
@@ -29,18 +28,24 @@ class _CosineSchedulerCore(BaseScheduler):
29
28
  if freeze_ratio is not None:
30
29
  if not (0 < freeze_ratio < 1):
31
30
  raise ValueError(f"Freeze ratio must be in (0, 1), got {freeze_ratio}.")
31
+ pre_annealing_ratio = (warmup_ratio if warmup_ratio is not None else 0) + (
32
+ freeze_ratio if freeze_ratio is not None else 0
33
+ )
34
+ if pre_annealing_ratio > 1:
35
+ raise ValueError(
36
+ "The sum of warmup_ratio and freeze_ratio must <= 1, got "
37
+ f"{pre_annealing_ratio}."
38
+ )
32
39
 
33
40
  self.param_name = param_name
34
41
  self.num_iters = num_iters
35
42
  self.base_value = base_value
36
43
  self.final_value = final_value
37
-
38
44
  self.warmup_ratio = warmup_ratio
39
45
  self.warmup_value = warmup_value
40
-
41
46
  self.freeze_ratio = freeze_ratio
42
47
 
43
- self.scheduler_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
48
+ self.scheduled_values: np.ndarray = np.array([], dtype=np.float64)
44
49
  self.current_value_ = self.base_value
45
50
  return
46
51
 
@@ -63,31 +68,29 @@ class _CosineSchedulerCore(BaseScheduler):
63
68
  warmup_iters = 0
64
69
  warmup_schedule = np.array([], dtype=np.float64)
65
70
 
71
+ # Create cosine annealing schedule
66
72
  cosine_annealing_iters = self.num_iters - warmup_iters - freeze_iters
67
- if cosine_annealing_iters <= 0:
68
- raise ValueError("Cosine annealing iters must be > 0.")
69
-
70
- # Create cosine schedule
71
- iters = np.arange(cosine_annealing_iters)
72
- schedule = self.final_value + 0.5 * (self.base_value - self.final_value) * (
73
- 1 + np.cos(np.pi * iters / len(iters))
74
- )
73
+ if cosine_annealing_iters > 0:
74
+ iters = np.arange(cosine_annealing_iters)
75
+ cosine_annealing_schedule = self.final_value + 0.5 * (
76
+ self.base_value - self.final_value
77
+ ) * (1 + np.cos(np.pi * iters / len(iters)))
78
+ else:
79
+ cosine_annealing_schedule = np.array([], dtype=np.float64)
75
80
 
76
81
  # Concatenate all parts of the schedule
77
- self.scheduler_values = np.concatenate(
78
- (freeze_schedule, warmup_schedule, schedule)
82
+ self.scheduled_values = np.concatenate(
83
+ (freeze_schedule, warmup_schedule, cosine_annealing_schedule)
79
84
  )
80
-
81
- if len(self.scheduler_values) != self.num_iters:
82
- raise ValueError(
83
- f"Scheduler length ({len(self.scheduler_values)}) does not match num_iters ({self.num_iters})."
84
- )
85
+ self._verify()
85
86
  return
86
87
 
87
88
  @override
88
- def load_state_dict(self, state_dict: dict[str, Any]) -> None:
89
- super().load_state_dict(state_dict)
90
- self.scheduler_values = np.array([], dtype=np.float64)
89
def _verify(self) -> None:
    """Check that the built schedule has exactly ``num_iters`` entries.

    Raises:
        ValueError: If ``len(self.scheduled_values)`` differs from ``self.num_iters``.

    """
    actual_len = len(self.scheduled_values)
    if actual_len == self.num_iters:
        return
    raise ValueError(
        f"Scheduler length ({actual_len}) does not match num_iters ({self.num_iters})."
    )
92
95
 
93
96
  @override
@@ -95,13 +98,13 @@ class _CosineSchedulerCore(BaseScheduler):
95
98
  raise NotImplementedError
96
99
 
97
100
  def _get_value(self, it: int) -> float:
98
- if len(self.scheduler_values) == 0:
101
+ if len(self.scheduled_values) == 0:
99
102
  self._create_scheduler()
100
103
 
101
104
  if it >= self.num_iters:
102
105
  value: float = self.final_value
103
106
  else:
104
- value: float = self.scheduler_values[it]
107
+ value: float = self.scheduled_values[it]
105
108
  self.current_value_ = value
106
109
  return value
107
110
 
@@ -163,6 +166,21 @@ class CosineScheduler(_CosineSchedulerCore):
163
166
  self.param_group_field = param_group_field
164
167
  return
165
168
 
169
+ @override
170
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Restore attributes from ``state_dict`` and drop the cached schedule.

    Every key of ``state_dict`` is written onto the instance. The schedule
    array is then reset to empty.

    Args:
        state_dict: Mapping of attribute names to values.

    """
    vars(self).update(state_dict)
    self.scheduled_values = np.array([], dtype=np.float64)
174
+
175
+ @override
176
def state_dict(self) -> dict[str, Any]:
    """Return the instance attributes, minus the schedule array and optimizer.

    Returns:
        A new dict of every instance attribute except ``scheduled_values``
        and ``optimizer``.

    """
    excluded = ("scheduled_values", "optimizer")
    snapshot: dict[str, Any] = {}
    for key, val in vars(self).items():
        if key in excluded:
            continue
        snapshot[key] = val
    return snapshot
183
+
166
184
  @override
167
185
  def step(self, it: int) -> None:
168
186
  value = self._get_value(it)
@@ -209,3 +227,14 @@ class CosineParamScheduler(_CosineSchedulerCore):
209
227
  """
210
228
  value = self._get_value(it)
211
229
  return value
230
+
231
+ @override
232
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Restore attributes from ``state_dict`` and drop the cached schedule.

    Args:
        state_dict: Mapping of attribute names to values; every key is
            written onto the instance.

    """
    vars(self).update(state_dict)
    # Reset to an empty array after the update.
    self.scheduled_values = np.array([], dtype=np.float64)
236
+
237
+ @override
238
def state_dict(self) -> dict[str, Any]:
    """Return the instance attributes, minus the schedule array.

    Returns:
        A new dict of every instance attribute except ``scheduled_values``.

    """
    snapshot = dict(vars(self))
    snapshot.pop("scheduled_values", None)
    return snapshot
@@ -0,0 +1,277 @@
1
+ from typing import Any
2
+ from typing import override
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from .base import BaseScheduler
8
+
9
+
10
class _CosineWithPlateauSchedulerCore(BaseScheduler):
    """Core logic for a cosine schedule preceded by a constant plateau.

    The schedule is the concatenation of four phases, in this order: freeze
    (zeros), linear warmup from ``warmup_value`` to ``base_value``, plateau at
    ``base_value``, and cosine annealing from ``base_value`` towards
    ``final_value``. The value array is built lazily on the first
    ``_get_value`` call.
    """

    def __init__(
        self,
        param_name: str,
        num_iters: int,
        base_value: float,
        final_value: float,
        plateau_ratio: float,
        warmup_value: float | None = None,
        warmup_ratio: float | None = None,
        freeze_ratio: float | None = None,
    ) -> None:
        """
        Validate the phase ratios and store the schedule configuration.

        Args:
            param_name: Key under which ``current_value`` reports the value.
            num_iters: Total length of the schedule.
            base_value: Plateau value and starting point of the cosine phase.
            final_value: Value returned once ``it >= num_iters``.
            plateau_ratio: Fraction of iterations for the plateau, in (0, 1).
            warmup_value: Starting value of the linear warmup ramp.
            warmup_ratio: Optional warmup fraction, in (0, 1).
            freeze_ratio: Optional fraction held at zero first, in (0, 1).

        Raises:
            ValueError: If a ratio is outside (0, 1), if only one of
                ``warmup_ratio`` / ``warmup_value`` is given, or if the
                plateau, warmup and freeze ratios sum to more than 1.

        """
        if warmup_ratio is not None:
            if not (0 < warmup_ratio < 1):
                raise ValueError(f"Warmup ratio must be in (0, 1), got {warmup_ratio}.")
        if (warmup_value is None) != (warmup_ratio is None):
            raise ValueError(
                "Both warmup_ratio and warmup_value must be provided or neither."
            )
        if freeze_ratio is not None:
            if not (0 < freeze_ratio < 1):
                raise ValueError(f"Freeze ratio must be in (0, 1), got {freeze_ratio}.")
        if not (0 < plateau_ratio < 1):
            raise ValueError(f"Plateau ratio must be in (0, 1), got {plateau_ratio}.")

        # Share of the schedule consumed before cosine annealing starts.
        pre_annealing_ratio = (
            plateau_ratio
            + (warmup_ratio if warmup_ratio is not None else 0)
            + (freeze_ratio if freeze_ratio is not None else 0)
        )
        if pre_annealing_ratio > 1:
            raise ValueError(
                "The sum of plateau_ratio, warmup_ratio, and freeze_ratio must <= 1, got "
                f"{pre_annealing_ratio}."
            )

        self.param_name = param_name
        self.num_iters = num_iters
        self.base_value = base_value
        self.final_value = final_value
        self.cosine_annealing_ratio = 1 - pre_annealing_ratio
        self.plateau_ratio = plateau_ratio
        self.warmup_ratio = warmup_ratio
        self.warmup_value = warmup_value
        self.freeze_ratio = freeze_ratio

        # Empty array means "not built yet"; see ``_get_value``.
        self.scheduled_values: np.ndarray = np.array([], dtype=np.float64)
        self.current_value_ = self.base_value
        return

    def _create_scheduler(self) -> None:
        """Build ``scheduled_values`` from the four phases and verify its length."""
        # Create freeze schedule
        if self.freeze_ratio is not None:
            freeze_iters = int(self.num_iters * self.freeze_ratio)
            freeze_schedule = np.zeros(freeze_iters, dtype=np.float64)
        else:
            freeze_iters = 0
            freeze_schedule = np.array([], dtype=np.float64)

        # Create linear warmup schedule
        if self.warmup_ratio is not None and self.warmup_value is not None:
            warmup_iters = int(self.num_iters * self.warmup_ratio)
            warmup_schedule = np.linspace(
                self.warmup_value, self.base_value, warmup_iters, dtype=np.float64
            )
        else:
            warmup_iters = 0
            warmup_schedule = np.array([], dtype=np.float64)

        # Create cosine annealing schedule
        if self.cosine_annealing_ratio > 0:
            cosine_annealing_iters = int(self.num_iters * self.cosine_annealing_ratio)
        else:
            cosine_annealing_iters = 0
        if cosine_annealing_iters > 0:
            iters = np.arange(cosine_annealing_iters)
            cosine_annealing_schedule = self.final_value + 0.5 * (
                self.base_value - self.final_value
            ) * (1 + np.cos(np.pi * iters / len(iters)))
        else:
            # A tiny positive ratio can truncate to zero iterations; build the
            # empty phase explicitly rather than dividing an empty array by 0.
            cosine_annealing_schedule = np.array([], dtype=np.float64)

        # The plateau takes whatever is left, so it absorbs the truncation of
        # the other phases' iteration counts.
        plateau_iters = (
            self.num_iters - warmup_iters - freeze_iters - cosine_annealing_iters
        )
        if plateau_iters > 0:
            plateau_schedule = np.full(plateau_iters, self.base_value, dtype=np.float64)
        else:
            plateau_schedule = np.array([], dtype=np.float64)

        # Concatenate all parts of the schedule
        self.scheduled_values = np.concatenate(
            (
                freeze_schedule,
                warmup_schedule,
                plateau_schedule,
                cosine_annealing_schedule,
            )
        )
        self._verify()
        return

    @override
    def _verify(self) -> None:
        """Raise ``ValueError`` unless the schedule has exactly ``num_iters`` entries."""
        if len(self.scheduled_values) != self.num_iters:
            raise ValueError(
                f"Scheduler length ({len(self.scheduled_values)}) does not match num_iters ({self.num_iters})."
            )
        return

    @override
    def step(self, it: int) -> None | float:
        """Advance the schedule; concrete subclasses must implement this."""
        raise NotImplementedError

    def _get_value(self, it: int) -> float:
        """Return the scheduled value for iteration ``it`` and remember it.

        The schedule is built on first use (whenever ``scheduled_values`` is
        empty). Iterations at or beyond ``num_iters`` yield ``final_value``.
        """
        if len(self.scheduled_values) == 0:
            self._create_scheduler()

        if it >= self.num_iters:
            value = self.final_value
        else:
            # Indexing a NumPy array yields ``np.float64``. Convert to a
            # built-in float so NumPy scalars do not leak into
            # ``current_value_`` or into whatever the caller stores it in.
            value = float(self.scheduled_values[it])
        self.current_value_ = value
        return value

    @override
    def current_value(self) -> dict[str, float]:
        """Return the most recently computed value keyed by ``param_name``."""
        return {self.param_name: self.current_value_}
138
+
139
+
140
class CosineWithPlateuScheduler(_CosineWithPlateauSchedulerCore):
    """
    Drive an optimizer param-group field with a plateau-then-cosine schedule.

    Phases, in order: freeze (0) → warmup → plateau (base_value) → cosine
    annealing to final_value. During the plateau the value stays at
    ``base_value``; annealing starts only afterwards.
    """

    # NOTE: the class name is spelled "Plateu" in the public API; it is kept
    # unchanged here so existing imports continue to resolve.

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        param_group_field: str,
        num_iters: int,
        base_value: float,
        final_value: float,
        plateau_ratio: float,
        warmup_value: float | None = None,
        warmup_ratio: float | None = None,
        freeze_ratio: float | None = None,
        multiplier_field: str | None = None,
        skip_if_zero: bool = False,
        apply_if_field: str | None = None,
        ignore_if_field: str | None = None,
    ) -> None:
        """
        Set up plateau + cosine scheduling for the optimizer's param groups.

        Args:
            optimizer: Optimizer whose param groups are modified in-place.
            param_group_field: Param-group key that receives the scheduled value.
            num_iters: Schedule length; later iterations use ``final_value``.
            base_value: Plateau value and starting point of the cosine phase.
            final_value: Value the cosine phase anneals towards.
            plateau_ratio: Fraction of iterations held at ``base_value``.
            warmup_value: Value the linear warmup ramp starts from.
            warmup_ratio: Optional fraction of iterations spent ramping from
                ``warmup_value`` to ``base_value``.
            freeze_ratio: Optional fraction of initial iterations held at zero.
            multiplier_field: Optional param-group key holding a per-group
                factor applied to the scheduled value.
            skip_if_zero: Do not touch groups whose target field equals zero.
            apply_if_field: Only update groups that contain this key.
            ignore_if_field: Never update groups that contain this key.

        """
        # Assignment order is preserved because ``state_dict`` iterates
        # ``__dict__`` in insertion order.
        self.apply_if_field = apply_if_field
        self.ignore_if_field = ignore_if_field
        self.optimizer = optimizer
        self.multiplier_field = multiplier_field
        self.skip_if_zero = skip_if_zero
        super().__init__(
            param_name=param_group_field,
            num_iters=num_iters,
            base_value=base_value,
            final_value=final_value,
            plateau_ratio=plateau_ratio,
            warmup_ratio=warmup_ratio,
            warmup_value=warmup_value,
            freeze_ratio=freeze_ratio,
        )
        self.param_group_field = param_group_field

    @override
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore attributes from ``state_dict`` and reset the schedule array to empty."""
        vars(self).update(state_dict)
        self.scheduled_values = np.array([], dtype=np.float64)

    @override
    def state_dict(self) -> dict[str, Any]:
        """Return instance attributes except the schedule array and the optimizer."""
        excluded = ("scheduled_values", "optimizer")
        snapshot: dict[str, Any] = {}
        for key, val in vars(self).items():
            if key in excluded:
                continue
            snapshot[key] = val
        return snapshot

    @override
    def step(self, it: int) -> None:
        """
        Write the scheduled value for iteration ``it`` into matching param groups.

        Raises:
            ValueError: If a param group lacks ``param_group_field``.

        """
        value = self._get_value(it)
        field = self.param_group_field
        for group in self.optimizer.param_groups:
            # The target field is required on every group, even on groups
            # that the filters below would skip.
            if field not in group:
                raise ValueError(
                    f"Parameter group field '{self.param_group_field}' not found in optimizer parameter groups."
                )

            required_flag = self.apply_if_field
            if required_flag is not None and required_flag not in group:
                continue

            blocking_flag = self.ignore_if_field
            if blocking_flag is not None and blocking_flag in group:
                continue

            if self.skip_if_zero and group[field] == 0:
                continue

            scale_key = self.multiplier_field
            if scale_key is None:
                group[field] = value
                continue

            # Groups without the multiplier key fall back to a factor of 1.0.
            multiplier = group[scale_key] if scale_key in group else 1.0
            group[field] = value * multiplier
243
+
244
+
245
class CosineWithPlateauParamScheduler(_CosineWithPlateauSchedulerCore):
    """
    Plateau-then-cosine scheduler for values that do not live in an optimizer.

    Phases, in order: freeze (0) → warmup → plateau (base_value) → cosine
    annealing to final_value. During the plateau the value stays at
    ``base_value``; annealing starts only afterwards.
    """

    @override
    def step(self, it: int) -> float:
        """
        Look up the scheduled value for an iteration.

        Args:
            it: Iteration index to evaluate the schedule at.

        Returns:
            The scheduled value for ``it``.

        """
        return self._get_value(it)

    @override
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore attributes from ``state_dict`` and reset the schedule array to empty."""
        vars(self).update(state_dict)
        self.scheduled_values = np.array([], dtype=np.float64)

    @override
    def state_dict(self) -> dict[str, Any]:
        """Return instance attributes except the schedule array."""
        snapshot = dict(vars(self))
        snapshot.pop("scheduled_values", None)
        return snapshot
@@ -21,24 +21,23 @@ class _LinearScheduleBase(BaseScheduler):
21
21
  self.start_value = start_value
22
22
  self.final_value = final_value
23
23
 
24
- self.scheduler_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
24
+ self.scheduled_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
25
25
  self.current_value_ = self.start_value
26
26
  return
27
27
 
28
28
  def _create_scheduler(self) -> None:
29
- self.scheduler_values = np.linspace(
29
+ self.scheduled_values = np.linspace(
30
30
  self.start_value, self.final_value, num=self.num_iters, dtype=np.float64
31
31
  )
32
- if len(self.scheduler_values) != self.num_iters:
33
- raise ValueError(
34
- f"Scheduler length ({len(self.scheduler_values)}) does not match total_iters ({self.num_iters})."
35
- )
32
+ self._verify()
36
33
  return
37
34
 
38
35
  @override
39
- def load_state_dict(self, state_dict: dict[str, Any]) -> None:
40
- super().load_state_dict(state_dict)
41
- self.scheduler_values = np.array([], dtype=np.float64)
36
def _verify(self) -> None:
    """Check that the built schedule has exactly ``num_iters`` entries.

    Raises:
        ValueError: If ``len(self.scheduled_values)`` differs from ``self.num_iters``.

    """
    actual_len = len(self.scheduled_values)
    if actual_len == self.num_iters:
        return
    raise ValueError(
        f"Scheduler length ({actual_len}) does not match total_iters ({self.num_iters})."
    )
43
42
 
44
43
  @override
@@ -46,13 +45,13 @@ class _LinearScheduleBase(BaseScheduler):
46
45
  raise NotImplementedError
47
46
 
48
47
  def _get_value(self, it: int) -> float:
49
- if len(self.scheduler_values) == 0:
48
+ if len(self.scheduled_values) == 0:
50
49
  self._create_scheduler()
51
50
 
52
51
  if it >= self.num_iters:
53
52
  value: float = self.final_value
54
53
  else:
55
- value: float = self.scheduler_values[it]
54
+ value: float = self.scheduled_values[it]
56
55
  self.current_value_ = value
57
56
  return value
58
57
 
@@ -105,6 +104,21 @@ class LinearScheduler(_LinearScheduleBase):
105
104
  self.param_group_field = param_group_field
106
105
  return
107
106
 
107
+ @override
108
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Restore attributes from ``state_dict`` and drop the cached schedule.

    Args:
        state_dict: Mapping of attribute names to values; every key is
            written onto the instance.

    """
    vars(self).update(state_dict)
    # Reset to an empty array after the update.
    self.scheduled_values = np.array([], dtype=np.float64)
112
+
113
+ @override
114
def state_dict(self) -> dict[str, Any]:
    """Return the instance attributes, minus the schedule array and optimizer.

    Returns:
        A new dict of every instance attribute except ``scheduled_values``
        and ``optimizer``.

    """
    excluded = ("scheduled_values", "optimizer")
    snapshot: dict[str, Any] = {}
    for key, val in vars(self).items():
        if key in excluded:
            continue
        snapshot[key] = val
    return snapshot
121
+
108
122
  @override
109
123
  def step(self, it: int) -> None:
110
124
  value = self._get_value(it)
@@ -137,6 +151,17 @@ class LinearScheduler(_LinearScheduleBase):
137
151
  class LinearParamScheduler(_LinearScheduleBase):
138
152
  """LinearParamScheduler adjusts a parameter value using a linear scheduler."""
139
153
 
154
+ @override
155
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Restore attributes from ``state_dict`` and drop the cached schedule.

    Args:
        state_dict: Mapping of attribute names to values; every key is
            written onto the instance.

    """
    vars(self).update(state_dict)
    # Reset to an empty array after the update.
    self.scheduled_values = np.array([], dtype=np.float64)
159
+
160
+ @override
161
def state_dict(self) -> dict[str, Any]:
    """Return the instance attributes, minus the schedule array.

    Returns:
        A new dict of every instance attribute except ``scheduled_values``.

    """
    snapshot = dict(vars(self))
    snapshot.pop("scheduled_values", None)
    return snapshot
164
+
140
165
  @override
141
166
  def step(self, it: int) -> float:
142
167
  """