kostyl-toolkit 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kostyl/ml/configs/hyperparams.py +2 -6
- kostyl/ml/lightning/callbacks/checkpoint.py +84 -36
- kostyl/ml/lightning/callbacks/registry_uploader.py +21 -92
- kostyl/ml/schedulers/cosine.py +39 -47
- kostyl/ml/schedulers/linear.py +153 -0
- {kostyl_toolkit-0.1.25.dist-info → kostyl_toolkit-0.1.27.dist-info}/METADATA +1 -1
- {kostyl_toolkit-0.1.25.dist-info → kostyl_toolkit-0.1.27.dist-info}/RECORD +8 -7
- {kostyl_toolkit-0.1.25.dist-info → kostyl_toolkit-0.1.27.dist-info}/WHEEL +1 -1
kostyl/ml/configs/hyperparams.py
CHANGED
@@ -29,15 +29,11 @@ class Lr(BaseModel):
     @model_validator(mode="after")
     def validate_warmup(self) -> "Lr":
         """Validates the warmup parameters based on use_scheduler."""
-        if (self.warmup_value is None) != (
-            self.warmup_iters_ratio is None
-        ) and self.use_scheduler:
+        if (self.warmup_value is None) != (self.warmup_iters_ratio is None):  # fmt: skip
             raise ValueError(
                 "Both warmup_value and warmup_iters_ratio must be provided or neither"
             )
-        if (
-            (self.warmup_value is not None) or (self.warmup_iters_ratio is not None)
-        ) and (not self.use_scheduler):
+        if ((self.warmup_value is not None) or (self.warmup_iters_ratio is not None)) and not self.use_scheduler:  # fmt: skip
             logger.warning(
                 "use_scheduler is False, warmup_value and warmup_iters_ratio will be ignored."
             )
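The validator now fails fast on a half-specified warmup pair even when the scheduler is off. A minimal standalone sketch of the same pattern (field names taken from the diff; the surrounding model and defaults are assumed):

from pydantic import BaseModel, model_validator

class LrSketch(BaseModel):
    use_scheduler: bool = True                  # assumed default
    warmup_value: float | None = None
    warmup_iters_ratio: float | None = None

    @model_validator(mode="after")
    def validate_warmup(self) -> "LrSketch":
        # XOR check: setting exactly one of the pair is an error
        if (self.warmup_value is None) != (self.warmup_iters_ratio is None):
            raise ValueError(
                "Both warmup_value and warmup_iters_ratio must be provided or neither"
            )
        return self

LrSketch(warmup_value=1e-6)  # raises a validation error: warmup_iters_ratio missing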
kostyl/ml/lightning/callbacks/checkpoint.py
CHANGED

@@ -2,12 +2,16 @@ from datetime import timedelta
 from pathlib import Path
 from shutil import rmtree
 from typing import Literal
+from typing import override
 
+import lightning.pytorch as pl
+import torch.distributed as dist
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.callbacks import ModelCheckpoint
 
 from kostyl.ml.configs import CheckpointConfig
 from kostyl.ml.dist_utils import is_main_process
+from kostyl.ml.lightning import KostylLightningModule
 from kostyl.utils import setup_logger
 
 from .registry_uploader import RegistryUploaderCallback
@@ -16,7 +20,7 @@ from .registry_uploader import RegistryUploaderCallback
 logger = setup_logger("callbacks/checkpoint.py")
 
 
-class CustomModelCheckpoint(ModelCheckpoint):
+class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
     r"""
     Save the model after every epoch by monitoring a quantity. Every logged metrics are passed to the
     :class:`~lightning.pytorch.loggers.logger.Logger` for the version it gets saved in the same directory as the
@@ -226,6 +230,8 @@ class CustomModelCheckpoint(ModelCheckpoint):
 
     def __init__(  # noqa: D107
         self,
+        registry_uploader_callback: RegistryUploaderCallback,
+        uploading_mode: Literal["only-best", "every-checkpoint"] = "only-best",
         dirpath: _PATH | None = None,
         filename: str | None = None,
         monitor: str | None = None,
@@ -241,10 +247,10 @@ class CustomModelCheckpoint(ModelCheckpoint):
         every_n_epochs: int | None = None,
         save_on_train_epoch_end: bool | None = None,
         enable_version_counter: bool = True,
-        registry_uploader_callback: RegistryUploaderCallback | None = None,
     ) -> None:
         self.registry_uploader_callback = registry_uploader_callback
-        self.
+        self.process_group: dist.ProcessGroup | None = None
+        self.uploading_mode = uploading_mode
         super().__init__(
             dirpath=dirpath,
             filename=filename,
@@ -264,16 +270,30 @@ class CustomModelCheckpoint(ModelCheckpoint):
             )
             return
 
-    @
-    def
-
-
+    @override
+    def setup(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule | KostylLightningModule,
+        stage: str,
+    ) -> None:
+        super().setup(trainer, pl_module, stage)
+        if isinstance(pl_module, KostylLightningModule):
+            self.process_group = pl_module.get_process_group()
+        return
 
-    @
-    def
-
-        if
-            self.
+    @override
+    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
+        super()._save_checkpoint(trainer, filepath)
+        if dist.is_initialized():
+            dist.barrier(group=self.process_group)
+        if trainer.is_global_zero and self.registry_uploader_callback is not None:
+            match self.uploading_mode:
+                case "every-checkpoint":
+                    self.registry_uploader_callback.upload_checkpoint(filepath)
+                case "only-best":
+                    if filepath == self.best_model_path:
+                        self.registry_uploader_callback.upload_checkpoint(filepath)
         return
 
 
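A plausible wiring of the renamed callback into a Lightning run (not from the diff; `my_uploader` stands in for a configured RegistryUploaderCallback instance):

import lightning.pytorch as pl

ckpt_cb = ModelCheckpointWithRegistryUploader(
    registry_uploader_callback=my_uploader,  # hypothetical uploader instance
    uploading_mode="only-best",              # upload only when a new best is saved
    dirpath="checkpoints/",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)
trainer = pl.Trainer(callbacks=[ckpt_cb])

Because `_save_checkpoint` issues `dist.barrier(...)` before the rank-zero upload, every rank finishes writing before the upload begins.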
@@ -282,28 +302,44 @@ def setup_checkpoint_callback(
     ckpt_cfg: CheckpointConfig,
     save_weights_only: bool = True,
     registry_uploader_callback: RegistryUploaderCallback | None = None,
-
+    uploading_mode: Literal["only-best", "every-checkpoint"] | None = None,
+) -> ModelCheckpointWithRegistryUploader | ModelCheckpoint:
     """
-
+    Create and configure a checkpoint callback for model saving.
 
-
-    callback
-
-    the directory is created.
+    Creates the checkpoint directory (removing existing one if present) and returns
+    a configured callback for saving models during training. When registry_uploader_callback
+    is provided, returns an extended version with support for uploading checkpoints to a remote registry.
 
     Args:
-        dirpath
-        ckpt_cfg (
-
-
-
-
-
+        dirpath: Path to the directory for saving checkpoints.
+        ckpt_cfg: Checkpoint configuration (filename, monitor, mode, save_top_k).
+        save_weights_only: If True, only model weights are saved without optimizer and lr-scheduler state.
+            Defaults to True.
+        registry_uploader_callback: Optional callback for uploading checkpoints to a remote registry.
+            Must be specified together with uploading_mode.
+        uploading_mode: Checkpoint upload mode:
+            - "only-best": only the best checkpoint is uploaded
+            - "every-checkpoint": every saved checkpoint is uploaded
+            Must be specified together with registry_uploader_callback.
 
     Returns:
-
+        ModelCheckpointWithRegistryUploader if registry_uploader_callback is provided,
+        otherwise standard ModelCheckpoint.
+
+    Raises:
+        ValueError: If only one of registry_uploader_callback or uploading_mode is None.
+
+    Note:
+        If the dirpath directory already exists, it will be removed and recreated
+        (only on the main process in distributed training).
 
     """
+    if (registry_uploader_callback is None) != (uploading_mode is None):
+        raise ValueError(
+            "Both registry_uploader_callback and uploading_mode must be provided or neither."
+        )
+
     if dirpath.exists():
         if is_main_process():
             logger.warning(f"Checkpoint directory {dirpath} already exists.")
@@ -313,14 +349,26 @@ def setup_checkpoint_callback(
     logger.info(f"Creating checkpoint directory {dirpath}.")
     dirpath.mkdir(parents=True, exist_ok=True)
 
-
-
-
-
-
-
-
-
-
-
+    if (registry_uploader_callback is not None) and (uploading_mode is not None):
+        checkpoint_callback = ModelCheckpointWithRegistryUploader(
+            dirpath=dirpath,
+            filename=ckpt_cfg.filename,
+            save_top_k=ckpt_cfg.save_top_k,
+            monitor=ckpt_cfg.monitor,
+            mode=ckpt_cfg.mode,
+            verbose=True,
+            save_weights_only=save_weights_only,
+            registry_uploader_callback=registry_uploader_callback,
+            uploading_mode=uploading_mode,
+        )
+    else:
+        checkpoint_callback = ModelCheckpoint(
+            dirpath=dirpath,
+            filename=ckpt_cfg.filename,
+            save_top_k=ckpt_cfg.save_top_k,
+            monitor=ckpt_cfg.monitor,
+            mode=ckpt_cfg.mode,
+            verbose=True,
+            save_weights_only=save_weights_only,
+        )
     return checkpoint_callback
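A hedged usage sketch for the reworked factory; the CheckpointConfig field values are invented for illustration:

from pathlib import Path

ckpt_cfg = CheckpointConfig(   # fields per the docstring: filename, monitor, mode, save_top_k
    filename="step{step}",     # assumed values
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)
callback = setup_checkpoint_callback(
    dirpath=Path("checkpoints/run-001"),
    ckpt_cfg=ckpt_cfg,
    save_weights_only=True,
    registry_uploader_callback=my_uploader,  # both-or-neither with uploading_mode
    uploading_mode="only-best",
)

Passing only one of the last two arguments now raises the documented ValueError.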
kostyl/ml/lightning/callbacks/registry_uploader.py
CHANGED

@@ -1,40 +1,26 @@
 from abc import ABC
 from abc import abstractmethod
 from collections.abc import Callable
-from functools import partial
-from typing import Literal
+from pathlib import Path
 from typing import override
 
 from clearml import OutputModel
 from clearml import Task
-from lightning import Trainer
-from lightning.pytorch.callbacks import Callback
 
 from kostyl.ml.clearml.logging_utils import find_version_in_tags
 from kostyl.ml.clearml.logging_utils import increment_version
-from kostyl.ml.lightning import KostylLightningModule
 from kostyl.utils.logging import setup_logger
 
 
 logger = setup_logger()
 
 
-class RegistryUploaderCallback(Callback, ABC):
+class RegistryUploaderCallback(ABC):
     """Abstract Lightning callback responsible for tracking and uploading the best-performing model checkpoint."""
 
-    @property
     @abstractmethod
-    def best_model_path(self) -> str:
-        """
-        raise NotImplementedError
-
-    @best_model_path.setter
-    @abstractmethod
-    def best_model_path(self, value: str) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
+    def upload_checkpoint(self, path: str | Path) -> None:
+        """Upload the checkpoint located at the given path to the configured registry backend."""
         raise NotImplementedError
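With the upload trigger moved into the checkpoint callback, a backend now only has to implement `upload_checkpoint`. A sketch of a hypothetical non-ClearML backend (the S3 details are invented):

from pathlib import Path
from typing import override

class S3RegistryUploader(RegistryUploaderCallback):  # hypothetical example backend
    def __init__(self, bucket: str) -> None:
        self.bucket = bucket

    @override
    def upload_checkpoint(self, path: str | Path) -> None:
        key = Path(path).name
        # e.g. boto3.client("s3").upload_file(str(path), self.bucket, key)
        ...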
@@ -50,9 +36,6 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
         enable_tag_versioning: bool = True,
         label_enumeration: dict[str, int] | None = None,
         config_dict: dict[str, str] | None = None,
-        uploading_frequency: Literal[
-            "after-every-eval", "on-train-end"
-        ] = "on-train-end",
     ) -> None:
         """
         Initializes the ClearMLRegistryUploaderCallback.
@@ -67,9 +50,6 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
             config_dict: Optional configuration dictionary to associate with the model.
             enable_tag_versioning: Whether to enable versioning in tags. If True,
                 the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".
-            uploading_frequency: When to upload:
-                - "after-every-eval": after each validation phase.
-                - "on-train-end": once at the end of training.
 
         """
         super().__init__()
@@ -82,29 +62,16 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
         self.config_dict = config_dict
         self.label_enumeration = label_enumeration
         self.verbose = verbose
-        self.uploading_frequency = uploading_frequency
         self.enable_tag_versioning = enable_tag_versioning
 
+        self.best_model_path: str = ""
+
         self._output_model: OutputModel | None = None
         self._last_uploaded_model_path: str = ""
-        self._best_model_path: str = ""
         self._upload_callback: Callable | None = None
         return
 
-    @property
-    @override
-    def best_model_path(self) -> str:
-        return self._best_model_path
-
-    @best_model_path.setter
-    @override
-    def best_model_path(self, value: str) -> None:
-        self._best_model_path = value
-        if self._upload_callback is not None:
-            self._upload_callback()
-        return
-
-    def _create_output_model(self, pl_module: "KostylLightningModule") -> OutputModel:
+    def _create_output_model(self) -> OutputModel:
         if self.enable_tag_versioning:
             version = find_version_in_tags(self.output_model_tags)
             if version is None:
@@ -117,13 +84,6 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
         if "LightningCheckpoint" not in self.output_model_tags:
             self.output_model_tags.append("LightningCheckpoint")
 
-        if self.config_dict is None:
-            config = pl_module.model_config
-            if config is not None:
-                config = config.to_dict()
-        else:
-            config = self.config_dict
-
         return OutputModel(
             task=self.task,
             name=self.output_model_name,
@@ -134,60 +94,29 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
         )
 
     @override
-    def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
-
-
-
-
-
-
-
-
-            logger.info("Best model unchanged since last upload")
-            self._upload_callback = partial(self._upload_best_checkpoint, pl_module)
+    def upload_checkpoint(
+        self,
+        path: str | Path,
+    ) -> None:
+        if isinstance(path, Path):
+            path = str(path)
+        if path == self._last_uploaded_model_path:
+            if self.verbose:
+                logger.info("Model unchanged since last upload")
             return
-        self._upload_callback = None
 
         if self._output_model is None:
-            self._output_model = self._create_output_model(pl_module)
+            self._output_model = self._create_output_model()
 
         if self.verbose:
-            logger.info(f"Uploading
+            logger.info(f"Uploading model from {path}")
 
         self._output_model.update_weights(
-
+            path,
             auto_delete_file=False,
             async_enable=False,
         )
-
-        config = pl_module.model_config
-        if config is not None:
-            config = config.to_dict()
-        else:
-            config = self.config_dict
-        self._output_model.update_design(config_dict=config)
-
-        self._last_uploaded_model_path = self._best_model_path
-        return
-
-    @override
-    def on_validation_end(
-        self, trainer: Trainer, pl_module: "KostylLightningModule"
-    ) -> None:
-        if self.uploading_frequency != "after-every-eval":
-            return
-        if not trainer.is_global_zero:
-            return
-
-        self._upload_best_checkpoint(pl_module)
-        return
-
-    @override
-    def on_train_end(
-        self, trainer: Trainer, pl_module: "KostylLightningModule"
-    ) -> None:
-        if not trainer.is_global_zero:
-            return
+        self._output_model.update_design(config_dict=self.config_dict)
 
-        self._upload_best_checkpoint(pl_module)
+        self._last_uploaded_model_path = path
         return
kostyl/ml/schedulers/cosine.py
CHANGED
@@ -11,20 +11,18 @@ from .base import BaseScheduler
 class _CosineSchedulerCore(BaseScheduler):
     def __init__(
         self,
-
-
+        param_name: str,
+        num_iters: int,
         base_value: float,
         final_value: float,
-
+        warmup_ratio: float | None = None,
         warmup_value: float | None = None,
         freeze_ratio: float | None = None,
     ) -> None:
-        if
-        if not (0 <
-            raise ValueError(
-
-            )
-        if (warmup_value is None) != (warmup_iters_ratio is None):
+        if warmup_ratio is not None:
+            if not (0 < warmup_ratio < 1):
+                raise ValueError(f"Warmup ratio must be in (0, 1), got {warmup_ratio}.")
+        if (warmup_value is None) != (warmup_ratio is None):
             raise ValueError(
                 "Both warmup_ratio and warmup_value must be provided or neither."
             )
@@ -32,12 +30,12 @@ class _CosineSchedulerCore(BaseScheduler):
         if not (0 < freeze_ratio < 1):
             raise ValueError(f"Freeze ratio must be in (0, 1), got {freeze_ratio}.")
 
-        self.
-        self.
+        self.param_name = param_name
+        self.num_iters = num_iters
         self.base_value = base_value
         self.final_value = final_value
 
-        self.
+        self.warmup_ratio = warmup_ratio
         self.warmup_value = warmup_value
 
         self.freeze_ratio = freeze_ratio
@@ -49,15 +47,15 @@ class _CosineSchedulerCore(BaseScheduler):
     def _create_scheduler(self) -> None:
         # Create freeze schedule
         if self.freeze_ratio is not None:
-            freeze_iters = int(self.
+            freeze_iters = int(self.num_iters * self.freeze_ratio)
             freeze_schedule = np.zeros(freeze_iters, dtype=np.float64)
         else:
             freeze_iters = 0
             freeze_schedule = np.array([], dtype=np.float64)
 
         # Create linear warmup schedule
-        if self.
-            warmup_iters = int(self.
+        if self.warmup_ratio is not None and self.warmup_value is not None:
+            warmup_iters = int(self.num_iters * self.warmup_ratio)
             warmup_schedule = np.linspace(
                 self.warmup_value, self.base_value, warmup_iters, dtype=np.float64
             )
@@ -65,7 +63,7 @@ class _CosineSchedulerCore(BaseScheduler):
             warmup_iters = 0
             warmup_schedule = np.array([], dtype=np.float64)
 
-        cosine_annealing_iters = self.
+        cosine_annealing_iters = self.num_iters - warmup_iters - freeze_iters
         if cosine_annealing_iters <= 0:
             raise ValueError("Cosine annealing iters must be > 0.")
@@ -80,9 +78,9 @@ class _CosineSchedulerCore(BaseScheduler):
             (freeze_schedule, warmup_schedule, schedule)
         )
 
-        if len(self.scheduler_values) != self.
+        if len(self.scheduler_values) != self.num_iters:
             raise ValueError(
-                f"Scheduler length ({len(self.scheduler_values)}) does not match
+                f"Scheduler length ({len(self.scheduler_values)}) does not match num_iters ({self.num_iters})."
             )
         return
 
@@ -100,7 +98,7 @@ class _CosineSchedulerCore(BaseScheduler):
         if len(self.scheduler_values) == 0:
             self._create_scheduler()
 
-        if it >= self.
+        if it >= self.num_iters:
             value: float = self.final_value
         else:
             value: float = self.scheduler_values[it]
@@ -109,20 +107,20 @@ class _CosineSchedulerCore(BaseScheduler):
 
     @override
     def current_value(self) -> dict[str, float]:
-        return {self.
+        return {self.param_name: self.current_value_}
 
 
 class CosineScheduler(_CosineSchedulerCore):
-    """
+    """Applies a cosine schedule to an optimizer param-group field."""
 
     def __init__(
         self,
         optimizer: torch.optim.Optimizer,
         param_group_field: str,
-
+        num_iters: int,
         base_value: float,
         final_value: float,
-
+        warmup_ratio: float | None = None,
         warmup_value: float | None = None,
         freeze_ratio: float | None = None,
         multiplier_field: str | None = None,
@@ -131,21 +129,21 @@ class CosineScheduler(_CosineSchedulerCore):
         ignore_if_field: str | None = None,
     ) -> None:
         """
-
+        Configure cosine scheduling for matching optimizer groups.
 
         Args:
-            optimizer:
-            param_group_field: Name of the
-
-            base_value:
-            final_value:
-
-            warmup_value:
-            freeze_ratio:
-            multiplier_field:
-            skip_if_zero:
-            apply_if_field:
-            ignore_if_field:
+            optimizer: Optimizer whose param groups are updated in-place.
+            param_group_field: Name of the field that receives the scheduled value.
+            num_iters: Number of scheduler iterations before clamping at ``final_value``.
+            base_value: Value used on the first cosine step (after warmup/freeze).
+            final_value: Value approached as iterations progress.
+            warmup_ratio: Optional fraction of iterations to linearly ramp from ``warmup_value`` to ``base_value``.
+            warmup_value: Starting value for the warmup ramp.
+            freeze_ratio: Optional fraction of iterations to keep the value frozen at zero at the beginning.
+            multiplier_field: Optional per-group multiplier applied to the scheduled value.
+            skip_if_zero: Leave groups untouched when their target field equals zero.
+            apply_if_field: Require this flag to be present in a param group before updating.
+            ignore_if_field: Skip groups that declare this flag.
 
         """
         self.apply_if_field = apply_if_field
@@ -154,14 +152,15 @@ class CosineScheduler(_CosineSchedulerCore):
         self.multiplier_field = multiplier_field
         self.skip_if_zero = skip_if_zero
         super().__init__(
-
-
+            param_name=param_group_field,
+            num_iters=num_iters,
             base_value=base_value,
             final_value=final_value,
-
+            warmup_ratio=warmup_ratio,
             warmup_value=warmup_value,
             freeze_ratio=freeze_ratio,
         )
+        self.param_group_field = param_group_field
         return
 
     @override
@@ -194,14 +193,7 @@ class CosineScheduler(_CosineSchedulerCore):
 
 
 class CosineParamScheduler(_CosineSchedulerCore):
-    """
-    CosineParamScheduler adjusts a parameter value using a cosine annealing scheduler.
-
-    This class provides a mechanism to schedule the value of a parameter over a
-    predefined number of iterations. It supports linear warm-up and optional freezing
-    periods before the cosine annealing wave begins. The scheduler can be used to
-    gradually transition a parameter value from a starting value to a final value.
-    """
+    """Standalone cosine scheduler for non-optimizer parameters."""
 
     @override
     def step(self, it: int) -> float:
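Usage sketch for the renamed keyword interface (values are illustrative; the training-loop calls are assumed):

import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.AdamW(model.parameters(), lr=0.0)

sched = CosineScheduler(
    optimizer=opt,
    param_group_field="lr",
    num_iters=1_000,
    base_value=3e-4,
    final_value=1e-5,
    warmup_ratio=0.1,    # first 100 iters ramp linearly from warmup_value
    warmup_value=1e-6,
)
for it in range(1_000):
    sched.step(it)       # writes the scheduled value into matching param groups
    # ... forward/backward, opt.step(), opt.zero_grad() ...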
kostyl/ml/schedulers/linear.py
ADDED

@@ -0,0 +1,153 @@
+from typing import Any
+from typing import override
+
+import numpy as np
+import numpy.typing as npt
+import torch
+
+from .base import BaseScheduler
+
+
+class _LinearScheduleBase(BaseScheduler):
+    def __init__(
+        self,
+        param_name: str,
+        num_iters: int,
+        base_value: float,
+        final_value: float,
+    ) -> None:
+        self.param_name = param_name
+        self.num_iters = num_iters
+        self.base_value = base_value
+        self.final_value = final_value
+
+        self.scheduler_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
+        self.current_value_ = self.base_value
+        return
+
+    def _create_scheduler(self) -> None:
+        self.scheduler_values = np.linspace(
+            self.base_value, self.final_value, num=self.num_iters, dtype=np.float64
+        )
+        if len(self.scheduler_values) != self.num_iters:
+            raise ValueError(
+                f"Scheduler length ({len(self.scheduler_values)}) does not match total_iters ({self.num_iters})."
+            )
+        return
+
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        super().load_state_dict(state_dict)
+        self.scheduler_values = np.array([], dtype=np.float64)
+        return
+
+    @override
+    def step(self, it: int) -> None | float:
+        raise NotImplementedError
+
+    def _get_value(self, it: int) -> float:
+        if len(self.scheduler_values) == 0:
+            self._create_scheduler()
+
+        if it >= self.num_iters:
+            value: float = self.final_value
+        else:
+            value: float = self.scheduler_values[it]
+        self.current_value_ = value
+        return value
+
+    @override
+    def current_value(self) -> dict[str, float]:
+        return {self.param_name: self.current_value_}
+
+
+class LinearScheduler(_LinearScheduleBase):
+    """Implements a linear scheduler for adjusting parameter values in torch.optim.Optimizer."""
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        param_group_field: str,
+        num_iters: int,
+        base_value: float,
+        final_value: float,
+        multiplier_field: str | None = None,
+        skip_if_zero: bool = False,
+        apply_if_field: str | None = None,
+        ignore_if_field: str | None = None,
+    ) -> None:
+        """
+        Configure which optimizer groups get a linear value schedule.
+
+        Args:
+            optimizer: Optimizer whose param groups are updated in-place.
+            param_group_field: Name of the field that receives the scheduled value.
+            num_iters: Number of scheduler iterations before clamping at ``final_value``.
+            base_value: Value used on the first iteration.
+            final_value: Value used once ``num_iters`` iterations are consumed.
+            multiplier_field: Optional per-group multiplier applied to the scheduled value.
+            skip_if_zero: Leave groups untouched when their target field equals zero.
+            apply_if_field: Require this flag to be present in a param group before updating.
+            ignore_if_field: Skip groups that declare this flag.
+
+        """
+        self.apply_if_field = apply_if_field
+        self.ignore_if_field = ignore_if_field
+        self.optimizer = optimizer
+        self.multiplier_field = multiplier_field
+        self.skip_if_zero = skip_if_zero
+        super().__init__(
+            param_name=param_group_field,
+            num_iters=num_iters,
+            base_value=base_value,
+            final_value=final_value,
+        )
+        self.param_group_field = param_group_field
+        return
+
+    @override
+    def step(self, it: int) -> None:
+        value = self._get_value(it)
+        for pg in self.optimizer.param_groups:
+            if self.param_group_field not in pg:
+                raise ValueError(
+                    f"Parameter group field '{self.param_group_field}' not found in optimizer parameter groups."
+                )
+
+            if (self.apply_if_field is not None) and (self.apply_if_field not in pg):
+                continue
+
+            if (self.ignore_if_field is not None) and (self.ignore_if_field in pg):
+                continue
+
+            if self.skip_if_zero and pg[self.param_group_field] == 0:
+                continue
+
+            if self.multiplier_field is not None:
+                if self.multiplier_field not in pg:
+                    multiplier = 1.0
+                else:
+                    multiplier = pg[self.multiplier_field]
+                pg[self.param_group_field] = value * multiplier
+            else:
+                pg[self.param_group_field] = value
+        return
+
+
+class LinearParamScheduler(_LinearScheduleBase):
+    """LinearParamScheduler adjusts a parameter value using a linear scheduler."""
+
+    @override
+    def step(self, it: int) -> float:
+        """
+        Computes the value corresponding to the given iteration step.
+
+        Args:
+            it: The current iteration index used for value computation.
+
+        Returns:
+            The computed value for the provided iteration step as a float.
+
+        """
+        value = self._get_value(it)
+        return value
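A hedged usage sketch for both new classes; the "lr" field, "weight_decay" name, and iteration counts are illustrative:

import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

lr_sched = LinearScheduler(
    optimizer=opt,
    param_group_field="lr",
    num_iters=500,
    base_value=0.1,
    final_value=0.0,
)
lr_sched.step(0)    # writes 0.1 into every param group's "lr"
lr_sched.step(500)  # at/after num_iters the value clamps to final_value (0.0)

wd_sched = LinearParamScheduler(
    param_name="weight_decay",  # standalone: returns values instead of mutating groups
    num_iters=500,
    base_value=0.0,
    final_value=0.05,
)
current = wd_sched.step(250)    # ~0.025 halfway through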
{kostyl_toolkit-0.1.25.dist-info → kostyl_toolkit-0.1.27.dist-info}/RECORD
CHANGED

@@ -6,14 +6,14 @@ kostyl/ml/clearml/logging_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWa
 kostyl/ml/clearml/pulling_utils.py,sha256=cNa_-_5LHjNVYi9btXBrfl5sPvI6BAAlIFidtpKu310,4078
 kostyl/ml/configs/__init__.py,sha256=IetcivbqYGutowLqxdKp7QR4tkXKBr4m8t4Zkk9jHZU,911
 kostyl/ml/configs/base_model.py,sha256=Eofn14J9RsjpVx_J4rp6C19pDDCANU4hr3JtX-d0FpQ,4820
-kostyl/ml/configs/hyperparams.py,sha256=
+kostyl/ml/configs/hyperparams.py,sha256=2S_VEZ07RWquNFSWjHBb3OUpBlTznbUpFSchzMpSBOc,2879
 kostyl/ml/configs/training_settings.py,sha256=Sq2tiRuwkbmi9zKDG2JghZLXo5DDt_eQqN_KYJSdcTY,2509
 kostyl/ml/dist_utils.py,sha256=Onf0KHVLA8oeUgZTcTdmR9qiM22f2uYLoNwgLbMGJWk,3495
 kostyl/ml/lightning/__init__.py,sha256=-F3JAyq8KU1d-nACWryGu8d1CbvWbQ1rXFdeRwfE2X8,175
 kostyl/ml/lightning/callbacks/__init__.py,sha256=enexQt3octktsTiEYHltSF_24CM-NeFEVFimXiavGiY,296
-kostyl/ml/lightning/callbacks/checkpoint.py,sha256=
+kostyl/ml/lightning/callbacks/checkpoint.py,sha256=KNwNVB2TFh2dcn133NbeTo5ul0jgiPYCeA-8NQ7U_mw,18951
 kostyl/ml/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
-kostyl/ml/lightning/callbacks/registry_uploader.py,sha256=
+kostyl/ml/lightning/callbacks/registry_uploader.py,sha256=pIZHzHVANO_VsxPIbYhS8SwgZFHL341mP2HJnQ4iMFs,4216
 kostyl/ml/lightning/extenstions/__init__.py,sha256=OY6QGv1agYgqqKf1xJBrxgp_i8FunVfPzYezfaRrGXU,182
 kostyl/ml/lightning/extenstions/custom_module.py,sha256=nB5jW7cqRD1tyh-q5LD2EtiFQwFkLXpnS9Yu6c5xMRg,5987
 kostyl/ml/lightning/extenstions/pretrained_model.py,sha256=QJGr2UvYJcU2Gy2w8z_cEvTodjv7hGdd2PPPfdOI-Mw,4017
@@ -25,11 +25,12 @@ kostyl/ml/params_groups.py,sha256=nUyw5d06Pvy9QPiYtZzLYR87xwXqJLxbHthgQH8oSCM,35
 kostyl/ml/schedulers/__init__.py,sha256=bxXbsU_WYnVbhvNNnuI7cOAh2Axz7D25TaleBTZhYfc,197
 kostyl/ml/schedulers/base.py,sha256=9M2iOoOVSRojR_liPX1qo3Nn4iMXSM5ZJuAFWZTulUk,1327
 kostyl/ml/schedulers/composite.py,sha256=ee4xlMDMMtjKPkbTF2ue9GTr9DuGCGjZWf11mHbi6aE,2387
-kostyl/ml/schedulers/cosine.py,sha256=
+kostyl/ml/schedulers/cosine.py,sha256=t74_ByT22L5NQKpnBVU9UGzBVx1ZM2GTylb9ct3_PVg,7627
+kostyl/ml/schedulers/linear.py,sha256=62mYEfd_2cQjOWrd0Vl5_sFeEokBKYmx496szhY04aU,5159
 kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
 kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
 kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
 kostyl/utils/logging.py,sha256=Vye0u4-yeOSUc-f03gpQbxSktTbFiilTWLEVr00ZHvc,5796
-kostyl_toolkit-0.1.25.dist-info/WHEEL,sha256=
-kostyl_toolkit-0.1.25.dist-info/METADATA,sha256=
-kostyl_toolkit-0.1.25.dist-info/RECORD,,
+kostyl_toolkit-0.1.27.dist-info/WHEEL,sha256=ZyFSCYkV2BrxH6-HRVRg3R9Fo7MALzer9KiPYqNxSbo,79
+kostyl_toolkit-0.1.27.dist-info/METADATA,sha256=kg7Y2CJqhAI-3--rIKsPlarm1Ukk6jQLJpW2ZBvysI8,4269
+kostyl_toolkit-0.1.27.dist-info/RECORD,,