kostyl-toolkit 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,15 +29,11 @@ class Lr(BaseModel):
     @model_validator(mode="after")
     def validate_warmup(self) -> "Lr":
         """Validates the warmup parameters based on use_scheduler."""
-        if (self.warmup_value is None) != (
-            self.warmup_iters_ratio is None
-        ) and self.use_scheduler:
+        if (self.warmup_value is None) != (self.warmup_iters_ratio is None):  # fmt: skip
            raise ValueError(
                "Both warmup_value and warmup_iters_ratio must be provided or neither"
            )
-        elif (
-            (self.warmup_value is not None) or (self.warmup_iters_ratio is not None)
-        ) and (not self.use_scheduler):
+        if ((self.warmup_value is not None) or (self.warmup_iters_ratio is not None)) and not self.use_scheduler:  # fmt: skip
            logger.warning(
                "use_scheduler is False, warmup_value and warmup_iters_ratio will be ignored."
            )
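For context on the tightened validation above: the old code only raised when use_scheduler was true, while the new code always raises if exactly one of warmup_value / warmup_iters_ratio is set, and merely warns when both are given but the scheduler is disabled. A minimal standalone sketch of that logic (a hypothetical stand-in, not the package's Lr class, which has more fields, and using warnings in place of the package's logger):

import warnings

from pydantic import BaseModel, model_validator


class WarmupDemo(BaseModel):
    """Hypothetical stand-in mirroring the validator logic shown in the diff."""

    use_scheduler: bool = True
    warmup_value: float | None = None
    warmup_iters_ratio: float | None = None

    @model_validator(mode="after")
    def validate_warmup(self) -> "WarmupDemo":
        # Raise whenever exactly one of the two warmup fields is set,
        # regardless of use_scheduler (the behavioural change in 0.1.27).
        if (self.warmup_value is None) != (self.warmup_iters_ratio is None):
            raise ValueError(
                "Both warmup_value and warmup_iters_ratio must be provided or neither"
            )
        # Warn (do not raise) when warmup is configured but the scheduler is off.
        if (self.warmup_value is not None) and not self.use_scheduler:
            warnings.warn("use_scheduler is False, warmup settings will be ignored.")
        return self


WarmupDemo(warmup_value=1e-6, warmup_iters_ratio=0.1)                        # ok
WarmupDemo(use_scheduler=False, warmup_value=1e-6, warmup_iters_ratio=0.1)   # warns
# WarmupDemo(warmup_value=1e-6)  # would raise ValueError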
@@ -2,12 +2,16 @@ from datetime import timedelta
 from pathlib import Path
 from shutil import rmtree
 from typing import Literal
+from typing import override
 
+import lightning.pytorch as pl
+import torch.distributed as dist
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.callbacks import ModelCheckpoint
 
 from kostyl.ml.configs import CheckpointConfig
 from kostyl.ml.dist_utils import is_main_process
+from kostyl.ml.lightning import KostylLightningModule
 from kostyl.utils import setup_logger
 
 from .registry_uploader import RegistryUploaderCallback
@@ -16,7 +20,7 @@ from .registry_uploader import RegistryUploaderCallback
 logger = setup_logger("callbacks/checkpoint.py")
 
 
-class CustomModelCheckpoint(ModelCheckpoint):
+class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
     r"""
     Save the model after every epoch by monitoring a quantity. Every logged metrics are passed to the
     :class:`~lightning.pytorch.loggers.logger.Logger` for the version it gets saved in the same directory as the
@@ -226,6 +230,8 @@ class CustomModelCheckpoint(ModelCheckpoint):
 
     def __init__(  # noqa: D107
         self,
+        registry_uploader_callback: RegistryUploaderCallback,
+        uploading_mode: Literal["only-best", "every-checkpoint"] = "only-best",
         dirpath: _PATH | None = None,
         filename: str | None = None,
         monitor: str | None = None,
@@ -241,10 +247,10 @@ class CustomModelCheckpoint(ModelCheckpoint):
         every_n_epochs: int | None = None,
         save_on_train_epoch_end: bool | None = None,
         enable_version_counter: bool = True,
-        registry_uploader_callback: RegistryUploaderCallback | None = None,
     ) -> None:
         self.registry_uploader_callback = registry_uploader_callback
-        self._custom_best_model_path = ""
+        self.process_group: dist.ProcessGroup | None = None
+        self.uploading_mode = uploading_mode
         super().__init__(
             dirpath=dirpath,
             filename=filename,
@@ -264,16 +270,30 @@ class CustomModelCheckpoint(ModelCheckpoint):
264
270
  )
265
271
  return
266
272
 
267
- @property
268
- def best_model_path(self) -> str:
269
- """Best model path."""
270
- return self._custom_best_model_path
273
+ @override
274
+ def setup(
275
+ self,
276
+ trainer: pl.Trainer,
277
+ pl_module: pl.LightningModule | KostylLightningModule,
278
+ stage: str,
279
+ ) -> None:
280
+ super().setup(trainer, pl_module, stage)
281
+ if isinstance(pl_module, KostylLightningModule):
282
+ self.process_group = pl_module.get_process_group()
283
+ return
271
284
 
272
- @best_model_path.setter
273
- def best_model_path(self, value: str) -> None:
274
- self._custom_best_model_path = value
275
- if self.registry_uploader_callback is not None:
276
- self.registry_uploader_callback.best_model_path = value
285
+ @override
286
+ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
287
+ super()._save_checkpoint(trainer, filepath)
288
+ if dist.is_initialized():
289
+ dist.barrier(group=self.process_group)
290
+ if trainer.is_global_zero and self.registry_uploader_callback is not None:
291
+ match self.uploading_mode:
292
+ case "every-checkpoint":
293
+ self.registry_uploader_callback.upload_checkpoint(filepath)
294
+ case "only-best":
295
+ if filepath == self.best_model_path:
296
+ self.registry_uploader_callback.upload_checkpoint(filepath)
277
297
  return
278
298
 
279
299
 
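The new _save_checkpoint override is a barrier-then-rank-zero pattern: every rank finishes writing, the process group synchronizes, and only global rank zero talks to the registry. A standalone sketch of that pattern outside of Lightning, with placeholder save/upload callables (assumed names, not package code):

import torch.distributed as dist


def save_then_upload(save_fn, upload_fn, filepath: str, process_group=None) -> None:
    """Write a checkpoint, then upload it exactly once across all ranks.

    save_fn(filepath) performs the actual write; upload_fn(filepath) pushes the
    file to a registry. Both are placeholders for this sketch.
    """
    save_fn(filepath)
    if dist.is_initialized():
        # Ensure the file is fully written before any rank tries to read or upload it.
        dist.barrier(group=process_group)
    # Only one process should talk to the registry.
    if (not dist.is_initialized()) or dist.get_rank() == 0:
        upload_fn(filepath)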
@@ -282,28 +302,44 @@ def setup_checkpoint_callback(
     ckpt_cfg: CheckpointConfig,
     save_weights_only: bool = True,
     registry_uploader_callback: RegistryUploaderCallback | None = None,
-) -> CustomModelCheckpoint:
+    uploading_mode: Literal["only-best", "every-checkpoint"] | None = None,
+) -> ModelCheckpointWithRegistryUploader | ModelCheckpoint:
     """
-    Sets up a ModelCheckpoint callback for PyTorch Lightning.
+    Create and configure a checkpoint callback for model saving.
 
-    This function prepares a checkpoint directory and configures a ModelCheckpoint
-    callback based on the provided configuration. If the directory already exists,
-    it is removed (only by the main process) to ensure a clean start. Otherwise,
-    the directory is created.
+    Creates the checkpoint directory (removing existing one if present) and returns
+    a configured callback for saving models during training. When registry_uploader_callback
+    is provided, returns an extended version with support for uploading checkpoints to a remote registry.
 
     Args:
-        dirpath (Path): The path to the directory where checkpoints will be saved.
-        ckpt_cfg (CheckpointConfig): Configuration object containing checkpoint
-            settings such as filename, save_top_k, monitor, and mode.
-        save_weights_only (bool, optional): Whether to save only the model weights
-            or the full model. Defaults to True.
-        registry_uploader_callback (RegistryUploaderCallback | None, optional):
-            An optional callback for uploading checkpoints to a registry. Defaults to None.
+        dirpath: Path to the directory for saving checkpoints.
+        ckpt_cfg: Checkpoint configuration (filename, monitor, mode, save_top_k).
+        save_weights_only: If True, only model weights are saved without optimizer and lr-scheduler state.
+            Defaults to True.
+        registry_uploader_callback: Optional callback for uploading checkpoints to a remote registry.
+            Must be specified together with uploading_mode.
+        uploading_mode: Checkpoint upload mode:
+            - "only-best": only the best checkpoint is uploaded
+            - "every-checkpoint": every saved checkpoint is uploaded
+            Must be specified together with registry_uploader_callback.
 
     Returns:
-        ModelCheckpoint: The configured ModelCheckpoint callback instance.
+        ModelCheckpointWithRegistryUploader if registry_uploader_callback is provided,
+        otherwise standard ModelCheckpoint.
+
+    Raises:
+        ValueError: If only one of registry_uploader_callback or uploading_mode is None.
+
+    Note:
+        If the dirpath directory already exists, it will be removed and recreated
+        (only on the main process in distributed training).
 
     """
+    if (registry_uploader_callback is None) != (uploading_mode is None):
+        raise ValueError(
+            "Both registry_uploader_callback and uploading_mode must be provided or neither."
+        )
+
    if dirpath.exists():
        if is_main_process():
            logger.warning(f"Checkpoint directory {dirpath} already exists.")
@@ -313,14 +349,26 @@ def setup_checkpoint_callback(
             logger.info(f"Creating checkpoint directory {dirpath}.")
             dirpath.mkdir(parents=True, exist_ok=True)
 
-    checkpoint_callback = CustomModelCheckpoint(
-        dirpath=dirpath,
-        filename=ckpt_cfg.filename,
-        save_top_k=ckpt_cfg.save_top_k,
-        monitor=ckpt_cfg.monitor,
-        mode=ckpt_cfg.mode,
-        verbose=True,
-        save_weights_only=save_weights_only,
-        registry_uploader_callback=registry_uploader_callback,
-    )
+    if (registry_uploader_callback is not None) and (uploading_mode is not None):
+        checkpoint_callback = ModelCheckpointWithRegistryUploader(
+            dirpath=dirpath,
+            filename=ckpt_cfg.filename,
+            save_top_k=ckpt_cfg.save_top_k,
+            monitor=ckpt_cfg.monitor,
+            mode=ckpt_cfg.mode,
+            verbose=True,
+            save_weights_only=save_weights_only,
+            registry_uploader_callback=registry_uploader_callback,
+            uploading_mode=uploading_mode,
+        )
+    else:
+        checkpoint_callback = ModelCheckpoint(
+            dirpath=dirpath,
+            filename=ckpt_cfg.filename,
+            save_top_k=ckpt_cfg.save_top_k,
+            monitor=ckpt_cfg.monitor,
+            mode=ckpt_cfg.mode,
+            verbose=True,
+            save_weights_only=save_weights_only,
+        )
     return checkpoint_callback
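A hedged usage sketch of the reworked helper follows. It assumes CheckpointConfig accepts filename, save_top_k, monitor, and mode as keyword arguments (only attribute access on ckpt_cfg is visible in this diff), and imports setup_checkpoint_callback from the module path listed in the wheel RECORD; without an uploader it returns a plain ModelCheckpoint, and passing only one of registry_uploader_callback / uploading_mode now raises ValueError.

from pathlib import Path

from kostyl.ml.configs import CheckpointConfig
from kostyl.ml.lightning.callbacks.checkpoint import setup_checkpoint_callback

# Hypothetical field values; the config fields are inferred from ckpt_cfg.* usage above.
ckpt_cfg = CheckpointConfig(
    filename="epoch{epoch:02d}-{val_loss:.3f}",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

# No uploader: a standard ModelCheckpoint is returned.
callback = setup_checkpoint_callback(
    dirpath=Path("checkpoints/run-001"),
    ckpt_cfg=ckpt_cfg,
    save_weights_only=True,
)

# Providing only one of the two uploader arguments is rejected:
# setup_checkpoint_callback(dirpath=Path("checkpoints"), ckpt_cfg=ckpt_cfg, uploading_mode="only-best")
# -> ValueError: Both registry_uploader_callback and uploading_mode must be provided or neither.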
@@ -1,40 +1,26 @@
1
1
  from abc import ABC
2
2
  from abc import abstractmethod
3
3
  from collections.abc import Callable
4
- from functools import partial
5
- from typing import Literal
4
+ from pathlib import Path
6
5
  from typing import override
7
6
 
8
7
  from clearml import OutputModel
9
8
  from clearml import Task
10
- from lightning import Trainer
11
- from lightning.pytorch.callbacks import Callback
12
9
 
13
10
  from kostyl.ml.clearml.logging_utils import find_version_in_tags
14
11
  from kostyl.ml.clearml.logging_utils import increment_version
15
- from kostyl.ml.lightning import KostylLightningModule
16
12
  from kostyl.utils.logging import setup_logger
17
13
 
18
14
 
19
15
  logger = setup_logger()
20
16
 
21
17
 
22
- class RegistryUploaderCallback(Callback, ABC):
18
+ class RegistryUploaderCallback(ABC):
23
19
  """Abstract Lightning callback responsible for tracking and uploading the best-performing model checkpoint."""
24
20
 
25
- @property
26
21
  @abstractmethod
27
- def best_model_path(self) -> str:
28
- """Return the file system path pointing to the best model artifact produced during training."""
29
- raise NotImplementedError
30
-
31
- @best_model_path.setter
32
- @abstractmethod
33
- def best_model_path(self, value: str) -> None:
34
- raise NotImplementedError
35
-
36
- @abstractmethod
37
- def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
22
+ def upload_checkpoint(self, path: str | Path) -> None:
23
+ """Upload the checkpoint located at the given path to the configured registry backend."""
38
24
  raise NotImplementedError
39
25
 
40
26
 
@@ -50,9 +36,6 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
         enable_tag_versioning: bool = True,
         label_enumeration: dict[str, int] | None = None,
         config_dict: dict[str, str] | None = None,
-        uploading_frequency: Literal[
-            "after-every-eval", "on-train-end"
-        ] = "on-train-end",
     ) -> None:
         """
         Initializes the ClearMLRegistryUploaderCallback.
@@ -67,9 +50,6 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
             config_dict: Optional configuration dictionary to associate with the model.
             enable_tag_versioning: Whether to enable versioning in tags. If True,
                 the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".
-            uploading_frequency: When to upload:
-                - "after-every-eval": after each validation phase.
-                - "on-train-end": once at the end of training.
 
         """
         super().__init__()
@@ -82,29 +62,16 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
         self.config_dict = config_dict
         self.label_enumeration = label_enumeration
         self.verbose = verbose
-        self.uploading_frequency = uploading_frequency
         self.enable_tag_versioning = enable_tag_versioning
 
+        self.best_model_path: str = ""
+
         self._output_model: OutputModel | None = None
         self._last_uploaded_model_path: str = ""
-        self._best_model_path: str = ""
         self._upload_callback: Callable | None = None
         return
 
-    @property
-    @override
-    def best_model_path(self) -> str:
-        return self._best_model_path
-
-    @best_model_path.setter
-    @override
-    def best_model_path(self, value: str) -> None:
-        self._best_model_path = value
-        if self._upload_callback is not None:
-            self._upload_callback()
-        return
-
-    def _create_output_model(self, pl_module: "KostylLightningModule") -> OutputModel:
+    def _create_output_model(self) -> OutputModel:
         if self.enable_tag_versioning:
             version = find_version_in_tags(self.output_model_tags)
             if version is None:
@@ -117,13 +84,6 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
         if "LightningCheckpoint" not in self.output_model_tags:
             self.output_model_tags.append("LightningCheckpoint")
 
-        if self.config_dict is None:
-            config = pl_module.model_config
-            if config is not None:
-                config = config.to_dict()
-        else:
-            config = self.config_dict
-
         return OutputModel(
             task=self.task,
             name=self.output_model_name,
@@ -134,60 +94,29 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
         )
 
     @override
-    def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
-        if not self._best_model_path or (
-            self._best_model_path == self._last_uploaded_model_path
-        ):
-            if not self._best_model_path:
-                if self.verbose:
-                    logger.info("No best model found yet to upload")
-            elif self._best_model_path == self._last_uploaded_model_path:
-                if self.verbose:
-                    logger.info("Best model unchanged since last upload")
-            self._upload_callback = partial(self._upload_best_checkpoint, pl_module)
+    def upload_checkpoint(
+        self,
+        path: str | Path,
+    ) -> None:
+        if isinstance(path, Path):
+            path = str(path)
+        if path == self._last_uploaded_model_path:
+            if self.verbose:
+                logger.info("Model unchanged since last upload")
            return
-        self._upload_callback = None
 
        if self._output_model is None:
-            self._output_model = self._create_output_model(pl_module)
+            self._output_model = self._create_output_model()
 
        if self.verbose:
-            logger.info(f"Uploading best model from {self._best_model_path}")
+            logger.info(f"Uploading model from {path}")
 
        self._output_model.update_weights(
-            self._best_model_path,
+            path,
            auto_delete_file=False,
            async_enable=False,
        )
-        if self.config_dict is None:
-            config = pl_module.model_config
-            if config is not None:
-                config = config.to_dict()
-        else:
-            config = self.config_dict
-        self._output_model.update_design(config_dict=config)
-
-        self._last_uploaded_model_path = self._best_model_path
-        return
-
-    @override
-    def on_validation_end(
-        self, trainer: Trainer, pl_module: "KostylLightningModule"
-    ) -> None:
-        if self.uploading_frequency != "after-every-eval":
-            return
-        if not trainer.is_global_zero:
-            return
-
-        self._upload_best_checkpoint(pl_module)
-        return
-
-    @override
-    def on_train_end(
-        self, trainer: Trainer, pl_module: "KostylLightningModule"
-    ) -> None:
-        if not trainer.is_global_zero:
-            return
+        self._output_model.update_design(config_dict=self.config_dict)
 
-        self._upload_best_checkpoint(pl_module)
+        self._last_uploaded_model_path = path
        return
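With the abstract interface reduced to a single upload_checkpoint(path) method, a non-ClearML backend only has to implement that one call. A minimal hypothetical implementation (not part of the package) that "uploads" by copying checkpoints into a local directory:

import shutil
from pathlib import Path

from kostyl.ml.lightning.callbacks.registry_uploader import RegistryUploaderCallback


class LocalDirUploader(RegistryUploaderCallback):
    """Hypothetical backend: copies checkpoints into a local 'registry' folder."""

    def __init__(self, target_dir: str | Path) -> None:
        self.target_dir = Path(target_dir)
        self.target_dir.mkdir(parents=True, exist_ok=True)

    def upload_checkpoint(self, path: str | Path) -> None:
        src = Path(path)
        shutil.copy2(src, self.target_dir / src.name)


uploader = LocalDirUploader("registry/")
# ModelCheckpointWithRegistryUploader calls uploader.upload_checkpoint(filepath)
# from rank zero after each save (or only for the best checkpoint, per uploading_mode).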
@@ -11,20 +11,18 @@ from .base import BaseScheduler
 class _CosineSchedulerCore(BaseScheduler):
     def __init__(
         self,
-        param_group_field: str,
-        total_iters: int,
+        param_name: str,
+        num_iters: int,
         base_value: float,
         final_value: float,
-        warmup_iters_ratio: float | None = None,
+        warmup_ratio: float | None = None,
         warmup_value: float | None = None,
         freeze_ratio: float | None = None,
     ) -> None:
-        if warmup_iters_ratio is not None:
-            if not (0 < warmup_iters_ratio < 1):
-                raise ValueError(
-                    f"Warmup ratio must be in (0, 1), got {warmup_iters_ratio}."
-                )
-        if (warmup_value is None) != (warmup_iters_ratio is None):
+        if warmup_ratio is not None:
+            if not (0 < warmup_ratio < 1):
+                raise ValueError(f"Warmup ratio must be in (0, 1), got {warmup_ratio}.")
+        if (warmup_value is None) != (warmup_ratio is None):
            raise ValueError(
                "Both warmup_ratio and warmup_value must be provided or neither."
            )
@@ -32,12 +30,12 @@ class _CosineSchedulerCore(BaseScheduler):
             if not (0 < freeze_ratio < 1):
                 raise ValueError(f"Freeze ratio must be in (0, 1), got {freeze_ratio}.")
 
-        self.param_group_field = param_group_field
-        self.total_iters = total_iters
+        self.param_name = param_name
+        self.num_iters = num_iters
         self.base_value = base_value
         self.final_value = final_value
 
-        self.warmup_iters_ratio = warmup_iters_ratio
+        self.warmup_ratio = warmup_ratio
         self.warmup_value = warmup_value
 
         self.freeze_ratio = freeze_ratio
@@ -49,15 +47,15 @@ class _CosineSchedulerCore(BaseScheduler):
     def _create_scheduler(self) -> None:
         # Create freeze schedule
         if self.freeze_ratio is not None:
-            freeze_iters = int(self.total_iters * self.freeze_ratio)
+            freeze_iters = int(self.num_iters * self.freeze_ratio)
             freeze_schedule = np.zeros(freeze_iters, dtype=np.float64)
         else:
             freeze_iters = 0
             freeze_schedule = np.array([], dtype=np.float64)
 
         # Create linear warmup schedule
-        if self.warmup_iters_ratio is not None and self.warmup_value is not None:
-            warmup_iters = int(self.total_iters * self.warmup_iters_ratio)
+        if self.warmup_ratio is not None and self.warmup_value is not None:
+            warmup_iters = int(self.num_iters * self.warmup_ratio)
             warmup_schedule = np.linspace(
                 self.warmup_value, self.base_value, warmup_iters, dtype=np.float64
             )
@@ -65,7 +63,7 @@ class _CosineSchedulerCore(BaseScheduler):
             warmup_iters = 0
             warmup_schedule = np.array([], dtype=np.float64)
 
-        cosine_annealing_iters = self.total_iters - warmup_iters - freeze_iters
+        cosine_annealing_iters = self.num_iters - warmup_iters - freeze_iters
         if cosine_annealing_iters <= 0:
             raise ValueError("Cosine annealing iters must be > 0.")
 
@@ -80,9 +78,9 @@ class _CosineSchedulerCore(BaseScheduler):
             (freeze_schedule, warmup_schedule, schedule)
         )
 
-        if len(self.scheduler_values) != self.total_iters:
+        if len(self.scheduler_values) != self.num_iters:
             raise ValueError(
-                f"Scheduler length ({len(self.scheduler_values)}) does not match total_iters ({self.total_iters})."
+                f"Scheduler length ({len(self.scheduler_values)}) does not match num_iters ({self.num_iters})."
             )
         return
 
@@ -100,7 +98,7 @@ class _CosineSchedulerCore(BaseScheduler):
         if len(self.scheduler_values) == 0:
             self._create_scheduler()
 
-        if it >= self.total_iters:
+        if it >= self.num_iters:
             value: float = self.final_value
         else:
             value: float = self.scheduler_values[it]
@@ -109,20 +107,20 @@ class _CosineSchedulerCore(BaseScheduler):
 
     @override
     def current_value(self) -> dict[str, float]:
-        return {self.param_group_field: self.current_value_}
+        return {self.param_name: self.current_value_}
 
 
 class CosineScheduler(_CosineSchedulerCore):
-    """Implements a cosine scheduler for adjusting parameter values in torch.optim.Optimizer."""
+    """Applies a cosine schedule to an optimizer param-group field."""
 
     def __init__(
         self,
         optimizer: torch.optim.Optimizer,
         param_group_field: str,
-        total_iters: int,
+        num_iters: int,
         base_value: float,
         final_value: float,
-        warmup_iters_ratio: float | None = None,
+        warmup_ratio: float | None = None,
         warmup_value: float | None = None,
         freeze_ratio: float | None = None,
         multiplier_field: str | None = None,
@@ -131,21 +129,21 @@ class CosineScheduler(_CosineSchedulerCore):
         ignore_if_field: str | None = None,
     ) -> None:
         """
-        Initialize the scheduler with optimizer and scheduling parameters.
+        Configure cosine scheduling for matching optimizer groups.
 
         Args:
-            optimizer: PyTorch optimizer to schedule parameters for.
-            param_group_field: Name of the parameter group field to modify (e.g., 'lr', 'weight_decay').
-            total_iters: Total number of iterations for the scheduling.
-            base_value: Initial value for the parameter.
-            final_value: Final value for the parameter at the end of scheduling.
-            warmup_iters_ratio: Ratio of total iterations to use for warmup phase. Defaults to None.
-            warmup_value: Value to use during warmup phase. Defaults to None.
-            freeze_ratio: Ratio of total iterations to freeze parameter updates. Defaults to None.
-            multiplier_field: Field name for multiplier values in parameter groups. Defaults to None.
-            skip_if_zero: Whether to skip scheduling if the parameter value is zero. Defaults to False.
-            apply_if_field: Field name that must be present to apply scheduling. Defaults to None.
-            ignore_if_field: Field name that when present causes scheduling to be ignored. Defaults to None.
+            optimizer: Optimizer whose param groups are updated in-place.
+            param_group_field: Name of the field that receives the scheduled value.
+            num_iters: Number of scheduler iterations before clamping at ``final_value``.
+            base_value: Value used on the first cosine step (after warmup/freeze).
+            final_value: Value approached as iterations progress.
+            warmup_ratio: Optional fraction of iterations to linearly ramp from ``warmup_value`` to ``base_value``.
+            warmup_value: Starting value for the warmup ramp.
+            freeze_ratio: Optional fraction of iterations to keep the value frozen at zero at the beginning.
+            multiplier_field: Optional per-group multiplier applied to the scheduled value.
+            skip_if_zero: Leave groups untouched when their target field equals zero.
+            apply_if_field: Require this flag to be present in a param group before updating.
+            ignore_if_field: Skip groups that declare this flag.
 
         """
         self.apply_if_field = apply_if_field
@@ -154,14 +152,15 @@
         self.multiplier_field = multiplier_field
         self.skip_if_zero = skip_if_zero
         super().__init__(
-            param_group_field=param_group_field,
-            total_iters=total_iters,
+            param_name=param_group_field,
+            num_iters=num_iters,
             base_value=base_value,
             final_value=final_value,
-            warmup_iters_ratio=warmup_iters_ratio,
+            warmup_ratio=warmup_ratio,
             warmup_value=warmup_value,
             freeze_ratio=freeze_ratio,
         )
+        self.param_group_field = param_group_field
         return
 
     @override
@@ -194,14 +193,7 @@ class CosineScheduler(_CosineSchedulerCore):
 
 
 class CosineParamScheduler(_CosineSchedulerCore):
-    """
-    CosineParamScheduler adjusts a parameter value using a cosine annealing scheduler.
-
-    This class provides a mechanism to schedule the value of a parameter over a
-    predefined number of iterations. It supports linear warm-up and optional freezing
-    periods before the cosine annealing wave begins. The scheduler can be used to
-    gradually transition a parameter value from a starting value to a final value.
-    """
+    """Standalone cosine scheduler for non-optimizer parameters."""
 
     @override
     def step(self, it: int) -> float:
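Downstream code has to switch to the renamed keyword arguments (num_iters, warmup_ratio). A hedged usage sketch, assuming CosineScheduler is importable from kostyl.ml.schedulers.cosine (the module path listed in the wheel RECORD; the schedulers package may also re-export it):

import torch

from kostyl.ml.schedulers.cosine import CosineScheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

scheduler = CosineScheduler(
    optimizer=optimizer,
    param_group_field="lr",   # the param-group key to drive
    num_iters=1000,           # renamed from total_iters in 0.1.25
    base_value=1e-3,
    final_value=1e-5,
    warmup_ratio=0.1,         # renamed from warmup_iters_ratio
    warmup_value=1e-6,
)

for it in range(1000):
    scheduler.step(it)        # writes the scheduled lr into optimizer.param_groups
    # ... forward/backward and optimizer.step() would go here ...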
@@ -0,0 +1,153 @@
+from typing import Any
+from typing import override
+
+import numpy as np
+import numpy.typing as npt
+import torch
+
+from .base import BaseScheduler
+
+
+class _LinearScheduleBase(BaseScheduler):
+    def __init__(
+        self,
+        param_name: str,
+        num_iters: int,
+        base_value: float,
+        final_value: float,
+    ) -> None:
+        self.param_name = param_name
+        self.num_iters = num_iters
+        self.base_value = base_value
+        self.final_value = final_value
+
+        self.scheduler_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
+        self.current_value_ = self.base_value
+        return
+
+    def _create_scheduler(self) -> None:
+        self.scheduler_values = np.linspace(
+            self.base_value, self.final_value, num=self.num_iters, dtype=np.float64
+        )
+        if len(self.scheduler_values) != self.num_iters:
+            raise ValueError(
+                f"Scheduler length ({len(self.scheduler_values)}) does not match total_iters ({self.num_iters})."
+            )
+        return
+
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        super().load_state_dict(state_dict)
+        self.scheduler_values = np.array([], dtype=np.float64)
+        return
+
+    @override
+    def step(self, it: int) -> None | float:
+        raise NotImplementedError
+
+    def _get_value(self, it: int) -> float:
+        if len(self.scheduler_values) == 0:
+            self._create_scheduler()
+
+        if it >= self.num_iters:
+            value: float = self.final_value
+        else:
+            value: float = self.scheduler_values[it]
+        self.current_value_ = value
+        return value
+
+    @override
+    def current_value(self) -> dict[str, float]:
+        return {self.param_name: self.current_value_}
+
+
+class LinearScheduler(_LinearScheduleBase):
+    """Implements a linear scheduler for adjusting parameter values in torch.optim.Optimizer."""
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        param_group_field: str,
+        num_iters: int,
+        base_value: float,
+        final_value: float,
+        multiplier_field: str | None = None,
+        skip_if_zero: bool = False,
+        apply_if_field: str | None = None,
+        ignore_if_field: str | None = None,
+    ) -> None:
+        """
+        Configure which optimizer groups get a linear value schedule.
+
+        Args:
+            optimizer: Optimizer whose param groups are updated in-place.
+            param_group_field: Name of the field that receives the scheduled value.
+            num_iters: Number of scheduler iterations before clamping at ``final_value``.
+            base_value: Value used on the first iteration.
+            final_value: Value used once ``num_iters`` iterations are consumed.
+            multiplier_field: Optional per-group multiplier applied to the scheduled value.
+            skip_if_zero: Leave groups untouched when their target field equals zero.
+            apply_if_field: Require this flag to be present in a param group before updating.
+            ignore_if_field: Skip groups that declare this flag.
+
+        """
+        self.apply_if_field = apply_if_field
+        self.ignore_if_field = ignore_if_field
+        self.optimizer = optimizer
+        self.multiplier_field = multiplier_field
+        self.skip_if_zero = skip_if_zero
+        super().__init__(
+            param_name=param_group_field,
+            num_iters=num_iters,
+            base_value=base_value,
+            final_value=final_value,
+        )
+        self.param_group_field = param_group_field
+        return
+
+    @override
+    def step(self, it: int) -> None:
+        value = self._get_value(it)
+        for pg in self.optimizer.param_groups:
+            if self.param_group_field not in pg:
+                raise ValueError(
+                    f"Parameter group field '{self.param_group_field}' not found in optimizer parameter groups."
+                )
+
+            if (self.apply_if_field is not None) and (self.apply_if_field not in pg):
+                continue
+
+            if (self.ignore_if_field is not None) and (self.ignore_if_field in pg):
+                continue
+
+            if self.skip_if_zero and pg[self.param_group_field] == 0:
+                continue
+
+            if self.multiplier_field is not None:
+                if self.multiplier_field not in pg:
+                    multiplier = 1.0
+                else:
+                    multiplier = pg[self.multiplier_field]
+                pg[self.param_group_field] = value * multiplier
+            else:
+                pg[self.param_group_field] = value
+        return
+
+
+class LinearParamScheduler(_LinearScheduleBase):
+    """LinearParamScheduler adjusts a parameter value using a linear scheduler."""
+
+    @override
+    def step(self, it: int) -> float:
+        """
+        Computes the value corresponding to the given iteration step.
+
+        Args:
+            it: The current iteration index used for value computation.
+
+        Returns:
+            The computed value for the provided iteration step as a float.
+
+        """
+        value = self._get_value(it)
+        return value
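The new module mirrors the cosine API with the same param_name/num_iters naming. A usage sketch, assuming the module path kostyl/ml/schedulers/linear.py from the RECORD maps to kostyl.ml.schedulers.linear:

import torch

from kostyl.ml.schedulers.linear import LinearParamScheduler, LinearScheduler

# Standalone parameter schedule: step() returns the value directly.
temperature = LinearParamScheduler(
    param_name="temperature", num_iters=10, base_value=1.0, final_value=0.1
)
values = [temperature.step(it) for it in range(12)]  # clamps at final_value after 10 iters

# Optimizer-bound schedule: step() writes the value into matching param groups.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.0, weight_decay=1e-2)
wd_scheduler = LinearScheduler(
    optimizer=optimizer,
    param_group_field="weight_decay",
    num_iters=100,
    base_value=1e-2,
    final_value=0.0,
)
wd_scheduler.step(0)  # param group "weight_decay" is now set to base_value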
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: kostyl-toolkit
-Version: 0.1.25
+Version: 0.1.27
 Summary: Kickass Orchestration System for Training, Yielding & Logging
 Requires-Dist: case-converter>=1.2.0
 Requires-Dist: loguru>=0.7.3
@@ -6,14 +6,14 @@ kostyl/ml/clearml/logging_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWa
 kostyl/ml/clearml/pulling_utils.py,sha256=cNa_-_5LHjNVYi9btXBrfl5sPvI6BAAlIFidtpKu310,4078
 kostyl/ml/configs/__init__.py,sha256=IetcivbqYGutowLqxdKp7QR4tkXKBr4m8t4Zkk9jHZU,911
 kostyl/ml/configs/base_model.py,sha256=Eofn14J9RsjpVx_J4rp6C19pDDCANU4hr3JtX-d0FpQ,4820
-kostyl/ml/configs/hyperparams.py,sha256=OqN7mEj3zc5MTqBPCZL3Lcd2VCTDLo_K0yvhRWGfhCs,2924
+kostyl/ml/configs/hyperparams.py,sha256=2S_VEZ07RWquNFSWjHBb3OUpBlTznbUpFSchzMpSBOc,2879
 kostyl/ml/configs/training_settings.py,sha256=Sq2tiRuwkbmi9zKDG2JghZLXo5DDt_eQqN_KYJSdcTY,2509
 kostyl/ml/dist_utils.py,sha256=Onf0KHVLA8oeUgZTcTdmR9qiM22f2uYLoNwgLbMGJWk,3495
 kostyl/ml/lightning/__init__.py,sha256=-F3JAyq8KU1d-nACWryGu8d1CbvWbQ1rXFdeRwfE2X8,175
 kostyl/ml/lightning/callbacks/__init__.py,sha256=enexQt3octktsTiEYHltSF_24CM-NeFEVFimXiavGiY,296
-kostyl/ml/lightning/callbacks/checkpoint.py,sha256=1gk5-NjsMXe5cZP0OgNcoc9KUTzRDTHIokVEDr74sjI,16740
+kostyl/ml/lightning/callbacks/checkpoint.py,sha256=KNwNVB2TFh2dcn133NbeTo5ul0jgiPYCeA-8NQ7U_mw,18951
 kostyl/ml/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
-kostyl/ml/lightning/callbacks/registry_uploader.py,sha256=ksoh02dzIde4E_GaZykfiOgfSjZti-IJt_i61enem3s,6779
+kostyl/ml/lightning/callbacks/registry_uploader.py,sha256=pIZHzHVANO_VsxPIbYhS8SwgZFHL341mP2HJnQ4iMFs,4216
 kostyl/ml/lightning/extenstions/__init__.py,sha256=OY6QGv1agYgqqKf1xJBrxgp_i8FunVfPzYezfaRrGXU,182
 kostyl/ml/lightning/extenstions/custom_module.py,sha256=nB5jW7cqRD1tyh-q5LD2EtiFQwFkLXpnS9Yu6c5xMRg,5987
 kostyl/ml/lightning/extenstions/pretrained_model.py,sha256=QJGr2UvYJcU2Gy2w8z_cEvTodjv7hGdd2PPPfdOI-Mw,4017
@@ -25,11 +25,12 @@ kostyl/ml/params_groups.py,sha256=nUyw5d06Pvy9QPiYtZzLYR87xwXqJLxbHthgQH8oSCM,35
 kostyl/ml/schedulers/__init__.py,sha256=bxXbsU_WYnVbhvNNnuI7cOAh2Axz7D25TaleBTZhYfc,197
 kostyl/ml/schedulers/base.py,sha256=9M2iOoOVSRojR_liPX1qo3Nn4iMXSM5ZJuAFWZTulUk,1327
 kostyl/ml/schedulers/composite.py,sha256=ee4xlMDMMtjKPkbTF2ue9GTr9DuGCGjZWf11mHbi6aE,2387
-kostyl/ml/schedulers/cosine.py,sha256=jufULVHn_L_ZZEc3ZTG3QCY_pc0jlAMH5Aw496T31jo,8203
+kostyl/ml/schedulers/cosine.py,sha256=t74_ByT22L5NQKpnBVU9UGzBVx1ZM2GTylb9ct3_PVg,7627
+kostyl/ml/schedulers/linear.py,sha256=62mYEfd_2cQjOWrd0Vl5_sFeEokBKYmx496szhY04aU,5159
 kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
 kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
 kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
 kostyl/utils/logging.py,sha256=Vye0u4-yeOSUc-f03gpQbxSktTbFiilTWLEVr00ZHvc,5796
-kostyl_toolkit-0.1.25.dist-info/WHEEL,sha256=z-mOpxbJHqy3cq6SvUThBZdaLGFZzdZPtgWLcP2NKjQ,79
-kostyl_toolkit-0.1.25.dist-info/METADATA,sha256=Go9dF8W4vQ4HzQpzRCLa-i9NsYq5o0J34NtvGWZnEvA,4269
-kostyl_toolkit-0.1.25.dist-info/RECORD,,
+kostyl_toolkit-0.1.27.dist-info/WHEEL,sha256=ZyFSCYkV2BrxH6-HRVRg3R9Fo7MALzer9KiPYqNxSbo,79
+kostyl_toolkit-0.1.27.dist-info/METADATA,sha256=kg7Y2CJqhAI-3--rIKsPlarm1Ukk6jQLJpW2ZBvysI8,4269
+kostyl_toolkit-0.1.27.dist-info/RECORD,,
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: uv 0.9.15
+Generator: uv 0.9.18
 Root-Is-Purelib: true
 Tag: py3-none-any