kostyl-toolkit 0.1.24__tar.gz → 0.1.26__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/PKG-INFO +1 -1
  2. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/configs/hyperparams.py +4 -8
  3. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/lightning/callbacks/checkpoint.py +85 -38
  4. kostyl_toolkit-0.1.26/kostyl/ml/lightning/callbacks/registry_uploader.py +122 -0
  5. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/schedulers/cosine.py +46 -54
  6. kostyl_toolkit-0.1.26/kostyl/ml/schedulers/linear.py +153 -0
  7. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/pyproject.toml +1 -1
  8. kostyl_toolkit-0.1.24/kostyl/ml/lightning/callbacks/registry_uploader.py +0 -193
  9. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/README.md +0 -0
  10. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/__init__.py +0 -0
  11. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/__init__.py +0 -0
  12. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/clearml/__init__.py +0 -0
  13. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/clearml/dataset_utils.py +0 -0
  14. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/clearml/logging_utils.py +0 -0
  15. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/clearml/pulling_utils.py +0 -0
  16. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/configs/__init__.py +0 -0
  17. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/configs/base_model.py +0 -0
  18. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/configs/training_settings.py +0 -0
  19. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/dist_utils.py +0 -0
  20. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/lightning/__init__.py +0 -0
  21. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/lightning/callbacks/__init__.py +0 -0
  22. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
  23. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/lightning/extenstions/__init__.py +0 -0
  24. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/lightning/extenstions/custom_module.py +0 -0
  25. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/lightning/extenstions/pretrained_model.py +0 -0
  26. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/lightning/loggers/__init__.py +0 -0
  27. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/lightning/loggers/tb_logger.py +0 -0
  28. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/lightning/steps_estimation.py +0 -0
  29. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/metrics_formatting.py +0 -0
  30. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/params_groups.py +0 -0
  31. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/schedulers/__init__.py +0 -0
  32. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/schedulers/base.py +0 -0
  33. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/ml/schedulers/composite.py +0 -0
  34. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/utils/__init__.py +0 -0
  35. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/utils/dict_manipulations.py +0 -0
  36. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/utils/fs.py +0 -0
  37. {kostyl_toolkit-0.1.24 → kostyl_toolkit-0.1.26}/kostyl/utils/logging.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: kostyl-toolkit
- Version: 0.1.24
+ Version: 0.1.26
  Summary: Kickass Orchestration System for Training, Yielding & Logging
  Requires-Dist: case-converter>=1.2.0
  Requires-Dist: loguru>=0.7.3
@@ -23,21 +23,17 @@ class Lr(BaseModel):
  default=None, gt=0, lt=1, validate_default=False
  )
  warmup_value: float | None = Field(default=None, gt=0, validate_default=False)
- base_value: float
+ start_value: float
  final_value: float | None = Field(default=None, gt=0, validate_default=False)

  @model_validator(mode="after")
  def validate_warmup(self) -> "Lr":
  """Validates the warmup parameters based on use_scheduler."""
- if (self.warmup_value is None) != (
- self.warmup_iters_ratio is None
- ) and self.use_scheduler:
+ if (self.warmup_value is None) != (self.warmup_iters_ratio is None): # fmt: skip
  raise ValueError(
  "Both warmup_value and warmup_iters_ratio must be provided or neither"
  )
- elif (
- (self.warmup_value is not None) or (self.warmup_iters_ratio is not None)
- ) and (not self.use_scheduler):
+ if ((self.warmup_value is not None) or (self.warmup_iters_ratio is not None)) and not self.use_scheduler: # fmt: skip
  logger.warning(
  "use_scheduler is False, warmup_value and warmup_iters_ratio will be ignored."
  )
@@ -60,7 +56,7 @@ class WeightDecay(BaseModel):
  """Weight decay hyperparameters configuration."""

  use_scheduler: bool = False
- base_value: float
+ start_value: float
  final_value: float | None = None

  @model_validator(mode="after")
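The behavioral change above: the paired warmup check now runs regardless of use_scheduler, and base_value is renamed to start_value. A minimal usage sketch (assuming Lr is importable from kostyl.ml.configs.hyperparams and that use_scheduler is a field on the model, as the validator implies):

    from kostyl.ml.configs.hyperparams import Lr  # assumed import path

    # Valid: the warmup pair is supplied together.
    lr_cfg = Lr(
        use_scheduler=True,
        start_value=1e-3,        # renamed from base_value in 0.1.24
        final_value=1e-5,
        warmup_value=1e-6,
        warmup_iters_ratio=0.1,
    )

    # Fails validation in 0.1.26 even with use_scheduler=False, because only half of the
    # warmup pair is given; 0.1.24 only emitted a warning in this case.
    Lr(use_scheduler=False, start_value=1e-3, warmup_value=1e-6)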
@@ -2,13 +2,16 @@ from datetime import timedelta
  from pathlib import Path
  from shutil import rmtree
  from typing import Literal
- from typing import cast
+ from typing import override

+ import lightning.pytorch as pl
+ import torch.distributed as dist
  from lightning.fabric.utilities.types import _PATH
  from lightning.pytorch.callbacks import ModelCheckpoint

  from kostyl.ml.configs import CheckpointConfig
  from kostyl.ml.dist_utils import is_main_process
+ from kostyl.ml.lightning import KostylLightningModule
  from kostyl.utils import setup_logger

  from .registry_uploader import RegistryUploaderCallback
@@ -17,7 +20,7 @@ from .registry_uploader import RegistryUploaderCallback
  logger = setup_logger("callbacks/checkpoint.py")


- class CustomModelCheckpoint(ModelCheckpoint):
+ class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
  r"""
  Save the model after every epoch by monitoring a quantity. Every logged metrics are passed to the
  :class:`~lightning.pytorch.loggers.logger.Logger` for the version it gets saved in the same directory as the
@@ -227,6 +230,8 @@ class CustomModelCheckpoint(ModelCheckpoint):

  def __init__( # noqa: D107
  self,
+ registry_uploader_callback: RegistryUploaderCallback,
+ uploading_mode: Literal["only-best", "every-checkpoint"] = "only-best",
  dirpath: _PATH | None = None,
  filename: str | None = None,
  monitor: str | None = None,
@@ -242,8 +247,10 @@ class CustomModelCheckpoint(ModelCheckpoint):
  every_n_epochs: int | None = None,
  save_on_train_epoch_end: bool | None = None,
  enable_version_counter: bool = True,
- registry_uploader_callback: RegistryUploaderCallback | None = None,
  ) -> None:
+ self.registry_uploader_callback = registry_uploader_callback
+ self.process_group: dist.ProcessGroup | None = None
+ self.uploading_mode = uploading_mode
  super().__init__(
  dirpath=dirpath,
  filename=filename,
@@ -261,20 +268,32 @@ class CustomModelCheckpoint(ModelCheckpoint):
  save_on_train_epoch_end=save_on_train_epoch_end,
  enable_version_counter=enable_version_counter,
  )
- self.registry_uploader_callback = registry_uploader_callback
- self._custom_best_model_path = cast(str, self.best_model_path)
  return

- @property
- def best_model_path(self) -> str:
- """Best model path."""
- return self._custom_best_model_path
+ @override
+ def setup(
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule | KostylLightningModule,
+ stage: str,
+ ) -> None:
+ super().setup(trainer, pl_module, stage)
+ if isinstance(pl_module, KostylLightningModule):
+ self.process_group = pl_module.get_process_group()
+ return

- @best_model_path.setter
- def best_model_path(self, value: str) -> None:
- self._custom_best_model_path = value
- if self.registry_uploader_callback is not None:
- self.registry_uploader_callback.best_model_path = value
+ @override
+ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
+ super()._save_checkpoint(trainer, filepath)
+ if dist.is_initialized():
+ dist.barrier(group=self.process_group)
+ if trainer.is_global_zero and self.registry_uploader_callback is not None:
+ match self.uploading_mode:
+ case "every-checkpoint":
+ self.registry_uploader_callback.upload_checkpoint(filepath)
+ case "only-best":
+ if filepath == self.best_model_path:
+ self.registry_uploader_callback.upload_checkpoint(filepath)
  return


@@ -283,28 +302,44 @@ def setup_checkpoint_callback(
  ckpt_cfg: CheckpointConfig,
  save_weights_only: bool = True,
  registry_uploader_callback: RegistryUploaderCallback | None = None,
- ) -> CustomModelCheckpoint:
+ uploading_mode: Literal["only-best", "every-checkpoint"] | None = None,
+ ) -> ModelCheckpointWithRegistryUploader | ModelCheckpoint:
  """
- Sets up a ModelCheckpoint callback for PyTorch Lightning.
+ Create and configure a checkpoint callback for model saving.

- This function prepares a checkpoint directory and configures a ModelCheckpoint
- callback based on the provided configuration. If the directory already exists,
- it is removed (only by the main process) to ensure a clean start. Otherwise,
- the directory is created.
+ Creates the checkpoint directory (removing existing one if present) and returns
+ a configured callback for saving models during training. When registry_uploader_callback
+ is provided, returns an extended version with support for uploading checkpoints to a remote registry.

  Args:
- dirpath (Path): The path to the directory where checkpoints will be saved.
- ckpt_cfg (CheckpointConfig): Configuration object containing checkpoint
- settings such as filename, save_top_k, monitor, and mode.
- save_weights_only (bool, optional): Whether to save only the model weights
- or the full model. Defaults to True.
- registry_uploader_callback (RegistryUploaderCallback | None, optional):
- An optional callback for uploading checkpoints to a registry. Defaults to None.
+ dirpath: Path to the directory for saving checkpoints.
+ ckpt_cfg: Checkpoint configuration (filename, monitor, mode, save_top_k).
+ save_weights_only: If True, only model weights are saved without optimizer and lr-scheduler state.
+ Defaults to True.
+ registry_uploader_callback: Optional callback for uploading checkpoints to a remote registry.
+ Must be specified together with uploading_mode.
+ uploading_mode: Checkpoint upload mode:
+ - "only-best": only the best checkpoint is uploaded
+ - "every-checkpoint": every saved checkpoint is uploaded
+ Must be specified together with registry_uploader_callback.

  Returns:
- ModelCheckpoint: The configured ModelCheckpoint callback instance.
+ ModelCheckpointWithRegistryUploader if registry_uploader_callback is provided,
+ otherwise standard ModelCheckpoint.
+
+ Raises:
+ ValueError: If only one of registry_uploader_callback or uploading_mode is None.
+
+ Note:
+ If the dirpath directory already exists, it will be removed and recreated
+ (only on the main process in distributed training).

  """
+ if (registry_uploader_callback is None) != (uploading_mode is None):
+ raise ValueError(
+ "Both registry_uploader_callback and uploading_mode must be provided or neither."
+ )
+
  if dirpath.exists():
  if is_main_process():
  logger.warning(f"Checkpoint directory {dirpath} already exists.")
@@ -314,14 +349,26 @@ def setup_checkpoint_callback(
  logger.info(f"Creating checkpoint directory {dirpath}.")
  dirpath.mkdir(parents=True, exist_ok=True)

- checkpoint_callback = CustomModelCheckpoint(
- dirpath=dirpath,
- filename=ckpt_cfg.filename,
- save_top_k=ckpt_cfg.save_top_k,
- monitor=ckpt_cfg.monitor,
- mode=ckpt_cfg.mode,
- verbose=True,
- save_weights_only=save_weights_only,
- registry_uploader_callback=registry_uploader_callback,
- )
+ if (registry_uploader_callback is not None) and (uploading_mode is not None):
+ checkpoint_callback = ModelCheckpointWithRegistryUploader(
+ dirpath=dirpath,
+ filename=ckpt_cfg.filename,
+ save_top_k=ckpt_cfg.save_top_k,
+ monitor=ckpt_cfg.monitor,
+ mode=ckpt_cfg.mode,
+ verbose=True,
+ save_weights_only=save_weights_only,
+ registry_uploader_callback=registry_uploader_callback,
+ uploading_mode=uploading_mode,
+ )
+ else:
+ checkpoint_callback = ModelCheckpoint(
+ dirpath=dirpath,
+ filename=ckpt_cfg.filename,
+ save_top_k=ckpt_cfg.save_top_k,
+ monitor=ckpt_cfg.monitor,
+ mode=ckpt_cfg.mode,
+ verbose=True,
+ save_weights_only=save_weights_only,
+ )
  return checkpoint_callback
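A hedged wiring sketch of the new pairing rule in setup_checkpoint_callback. Import paths follow the file layout above; the ClearML task setup and the CheckpointConfig keyword arguments are illustrative assumptions (only the attribute names filename, save_top_k, monitor, and mode are visible in the diff):

    from pathlib import Path

    from clearml import Task

    from kostyl.ml.configs import CheckpointConfig
    from kostyl.ml.lightning.callbacks.checkpoint import setup_checkpoint_callback
    from kostyl.ml.lightning.callbacks.registry_uploader import ClearMLRegistryUploaderCallback

    task = Task.init(project_name="demo", task_name="train")  # illustrative task
    ckpt_cfg = CheckpointConfig(  # assumed constructor kwargs
        filename="epoch{epoch:02d}",
        save_top_k=1,
        monitor="val/loss",
        mode="min",
    )
    uploader = ClearMLRegistryUploaderCallback(task=task, output_model_name="demo-model")

    # Both kwargs or neither: passing only one of them raises ValueError.
    ckpt_cb = setup_checkpoint_callback(
        dirpath=Path("checkpoints/run-001"),
        ckpt_cfg=ckpt_cfg,
        registry_uploader_callback=uploader,
        uploading_mode="only-best",  # or "every-checkpoint"
    )
    # Omitting the uploader pair returns a plain lightning ModelCheckpoint instead.

With the uploader attached, the upload now happens inside _save_checkpoint: when torch.distributed is initialized every rank first passes the barrier, then only the global-zero rank uploads, either every checkpoint or only when the saved file matches best_model_path.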
@@ -0,0 +1,122 @@
+ from abc import ABC
+ from abc import abstractmethod
+ from collections.abc import Callable
+ from pathlib import Path
+ from typing import override
+
+ from clearml import OutputModel
+ from clearml import Task
+
+ from kostyl.ml.clearml.logging_utils import find_version_in_tags
+ from kostyl.ml.clearml.logging_utils import increment_version
+ from kostyl.utils.logging import setup_logger
+
+
+ logger = setup_logger()
+
+
+ class RegistryUploaderCallback(ABC):
+ """Abstract Lightning callback responsible for tracking and uploading the best-performing model checkpoint."""
+
+ @abstractmethod
+ def upload_checkpoint(self, path: str | Path) -> None:
+ """Upload the checkpoint located at the given path to the configured registry backend."""
+ raise NotImplementedError
+
+
+ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
+ """PyTorch Lightning callback to upload the best model checkpoint to ClearML."""
+
+ def __init__(
+ self,
+ task: Task,
+ output_model_name: str,
+ output_model_tags: list[str] | None = None,
+ verbose: bool = True,
+ enable_tag_versioning: bool = True,
+ label_enumeration: dict[str, int] | None = None,
+ config_dict: dict[str, str] | None = None,
+ ) -> None:
+ """
+ Initializes the ClearMLRegistryUploaderCallback.
+
+ Args:
+ task: ClearML task.
+ ckpt_callback: ModelCheckpoint instance used by Trainer.
+ output_model_name: Name for the ClearML output model.
+ output_model_tags: Tags for the output model.
+ verbose: Whether to log messages.
+ label_enumeration: Optional mapping of label names to integer IDs.
+ config_dict: Optional configuration dictionary to associate with the model.
+ enable_tag_versioning: Whether to enable versioning in tags. If True,
+ the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".
+
+ """
+ super().__init__()
+ if output_model_tags is None:
+ output_model_tags = []
+
+ self.task = task
+ self.output_model_name = output_model_name
+ self.output_model_tags = output_model_tags
+ self.config_dict = config_dict
+ self.label_enumeration = label_enumeration
+ self.verbose = verbose
+ self.enable_tag_versioning = enable_tag_versioning
+
+ self.best_model_path: str = ""
+
+ self._output_model: OutputModel | None = None
+ self._last_uploaded_model_path: str = ""
+ self._upload_callback: Callable | None = None
+ return
+
+ def _create_output_model(self) -> OutputModel:
+ if self.enable_tag_versioning:
+ version = find_version_in_tags(self.output_model_tags)
+ if version is None:
+ self.output_model_tags.append("v1.0")
+ else:
+ new_version = increment_version(version)
+ self.output_model_tags.remove(version)
+ self.output_model_tags.append(new_version)
+
+ if "LightningCheckpoint" not in self.output_model_tags:
+ self.output_model_tags.append("LightningCheckpoint")
+
+ return OutputModel(
+ task=self.task,
+ name=self.output_model_name,
+ framework="PyTorch",
+ tags=self.output_model_tags,
+ config_dict=None,
+ label_enumeration=self.label_enumeration,
+ )
+
+ @override
+ def upload_checkpoint(
+ self,
+ path: str | Path,
+ ) -> None:
+ if isinstance(path, Path):
+ path = str(path)
+ if path == self._last_uploaded_model_path:
+ if self.verbose:
+ logger.info("Model unchanged since last upload")
+ return
+
+ if self._output_model is None:
+ self._output_model = self._create_output_model()
+
+ if self.verbose:
+ logger.info(f"Uploading model from {path}")
+
+ self._output_model.update_weights(
+ path,
+ auto_delete_file=False,
+ async_enable=False,
+ )
+ self._output_model.update_design(config_dict=self.config_dict)
+
+ self._last_uploaded_model_path = path
+ return
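Compared with the removed 0.1.24 implementation further down, uploads are now explicit: the checkpoint callback calls upload_checkpoint(path) instead of driving a best_model_path property setter. Continuing the sketch above (the path is illustrative):

    # Manual upload outside the Lightning loop.
    uploader.upload_checkpoint("checkpoints/run-001/epoch03.ckpt")

    # Repeating the same path is a no-op: the uploader remembers the last uploaded path.
    uploader.upload_checkpoint("checkpoints/run-001/epoch03.ckpt")

With enable_tag_versioning=True (the default), the first upload creates a ClearML OutputModel tagged "LightningCheckpoint" plus a version tag: "v1.0" if none was passed, otherwise the supplied version tag is replaced by an incremented one.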
@@ -11,20 +11,18 @@ from .base import BaseScheduler
  class _CosineSchedulerCore(BaseScheduler):
  def __init__(
  self,
- param_group_field: str,
- total_iters: int,
- base_value: float,
+ param_name: str,
+ num_iters: int,
+ start_value: float,
  final_value: float,
- warmup_iters_ratio: float | None = None,
+ warmup_ratio: float | None = None,
  warmup_value: float | None = None,
  freeze_ratio: float | None = None,
  ) -> None:
- if warmup_iters_ratio is not None:
- if not (0 < warmup_iters_ratio < 1):
- raise ValueError(
- f"Warmup ratio must be in (0, 1), got {warmup_iters_ratio}."
- )
- if (warmup_value is None) != (warmup_iters_ratio is None):
+ if warmup_ratio is not None:
+ if not (0 < warmup_ratio < 1):
+ raise ValueError(f"Warmup ratio must be in (0, 1), got {warmup_ratio}.")
+ if (warmup_value is None) != (warmup_ratio is None):
  raise ValueError(
  "Both warmup_ratio and warmup_value must be provided or neither."
  )
@@ -32,46 +30,46 @@ class _CosineSchedulerCore(BaseScheduler):
  if not (0 < freeze_ratio < 1):
  raise ValueError(f"Freeze ratio must be in (0, 1), got {freeze_ratio}.")

- self.param_group_field = param_group_field
- self.total_iters = total_iters
- self.base_value = base_value
+ self.param_name = param_name
+ self.num_iters = num_iters
+ self.start_value = start_value
  self.final_value = final_value

- self.warmup_iters_ratio = warmup_iters_ratio
+ self.warmup_ratio = warmup_ratio
  self.warmup_value = warmup_value

  self.freeze_ratio = freeze_ratio

  self.scheduler_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
- self.current_value_ = self.base_value
+ self.current_value_ = self.start_value
  return

  def _create_scheduler(self) -> None:
  # Create freeze schedule
  if self.freeze_ratio is not None:
- freeze_iters = int(self.total_iters * self.freeze_ratio)
+ freeze_iters = int(self.num_iters * self.freeze_ratio)
  freeze_schedule = np.zeros(freeze_iters, dtype=np.float64)
  else:
  freeze_iters = 0
  freeze_schedule = np.array([], dtype=np.float64)

  # Create linear warmup schedule
- if self.warmup_iters_ratio is not None and self.warmup_value is not None:
- warmup_iters = int(self.total_iters * self.warmup_iters_ratio)
+ if self.warmup_ratio is not None and self.warmup_value is not None:
+ warmup_iters = int(self.num_iters * self.warmup_ratio)
  warmup_schedule = np.linspace(
- self.warmup_value, self.base_value, warmup_iters, dtype=np.float64
+ self.warmup_value, self.start_value, warmup_iters, dtype=np.float64
  )
  else:
  warmup_iters = 0
  warmup_schedule = np.array([], dtype=np.float64)

- cosine_annealing_iters = self.total_iters - warmup_iters - freeze_iters
+ cosine_annealing_iters = self.num_iters - warmup_iters - freeze_iters
  if cosine_annealing_iters <= 0:
  raise ValueError("Cosine annealing iters must be > 0.")

  # Create cosine schedule
  iters = np.arange(cosine_annealing_iters)
- schedule = self.final_value + 0.5 * (self.base_value - self.final_value) * (
+ schedule = self.final_value + 0.5 * (self.start_value - self.final_value) * (
  1 + np.cos(np.pi * iters / len(iters))
  )

@@ -80,9 +78,9 @@ class _CosineSchedulerCore(BaseScheduler):
  (freeze_schedule, warmup_schedule, schedule)
  )

- if len(self.scheduler_values) != self.total_iters:
+ if len(self.scheduler_values) != self.num_iters:
  raise ValueError(
- f"Scheduler length ({len(self.scheduler_values)}) does not match total_iters ({self.total_iters})."
+ f"Scheduler length ({len(self.scheduler_values)}) does not match num_iters ({self.num_iters})."
  )
  return

@@ -100,7 +98,7 @@ class _CosineSchedulerCore(BaseScheduler):
  if len(self.scheduler_values) == 0:
  self._create_scheduler()

- if it >= self.total_iters:
+ if it >= self.num_iters:
  value: float = self.final_value
  else:
  value: float = self.scheduler_values[it]
@@ -109,20 +107,20 @@ class _CosineSchedulerCore(BaseScheduler):

  @override
  def current_value(self) -> dict[str, float]:
- return {self.param_group_field: self.current_value_}
+ return {self.param_name: self.current_value_}


  class CosineScheduler(_CosineSchedulerCore):
- """Implements a cosine scheduler for adjusting parameter values in torch.optim.Optimizer."""
+ """Applies a cosine schedule to an optimizer param-group field."""

  def __init__(
  self,
  optimizer: torch.optim.Optimizer,
  param_group_field: str,
- total_iters: int,
- base_value: float,
+ num_iters: int,
+ start_value: float,
  final_value: float,
- warmup_iters_ratio: float | None = None,
+ warmup_ratio: float | None = None,
  warmup_value: float | None = None,
  freeze_ratio: float | None = None,
  multiplier_field: str | None = None,
@@ -131,21 +129,21 @@ class CosineScheduler(_CosineSchedulerCore):
  ignore_if_field: str | None = None,
  ) -> None:
  """
- Initialize the scheduler with optimizer and scheduling parameters.
+ Configure cosine scheduling for matching optimizer groups.

  Args:
- optimizer: PyTorch optimizer to schedule parameters for.
- param_group_field: Name of the parameter group field to modify (e.g., 'lr', 'weight_decay').
- total_iters: Total number of iterations for the scheduling.
- base_value: Initial value for the parameter.
- final_value: Final value for the parameter at the end of scheduling.
- warmup_iters_ratio: Ratio of total iterations to use for warmup phase. Defaults to None.
- warmup_value: Value to use during warmup phase. Defaults to None.
- freeze_ratio: Ratio of total iterations to freeze parameter updates. Defaults to None.
- multiplier_field: Field name for multiplier values in parameter groups. Defaults to None.
- skip_if_zero: Whether to skip scheduling if the parameter value is zero. Defaults to False.
- apply_if_field: Field name that must be present to apply scheduling. Defaults to None.
- ignore_if_field: Field name that when present causes scheduling to be ignored. Defaults to None.
+ optimizer: Optimizer whose param groups are updated in-place.
+ param_group_field: Name of the field that receives the scheduled value.
+ num_iters: Number of scheduler iterations before clamping at ``final_value``.
+ start_value: Value used on the first cosine step (after warmup/freeze).
+ final_value: Value approached as iterations progress.
+ warmup_ratio: Optional fraction of iterations to linearly ramp from ``warmup_value`` to ``start_value``.
+ warmup_value: Starting value for the warmup ramp.
+ freeze_ratio: Optional fraction of iterations to keep the value frozen at zero at the beginning.
+ multiplier_field: Optional per-group multiplier applied to the scheduled value.
+ skip_if_zero: Leave groups untouched when their target field equals zero.
+ apply_if_field: Require this flag to be present in a param group before updating.
+ ignore_if_field: Skip groups that declare this flag.

  """
  self.apply_if_field = apply_if_field
@@ -154,14 +152,15 @@ class CosineScheduler(_CosineSchedulerCore):
  self.multiplier_field = multiplier_field
  self.skip_if_zero = skip_if_zero
  super().__init__(
- param_group_field=param_group_field,
- total_iters=total_iters,
- base_value=base_value,
+ param_name=param_group_field,
+ num_iters=num_iters,
+ start_value=start_value,
  final_value=final_value,
- warmup_iters_ratio=warmup_iters_ratio,
+ warmup_ratio=warmup_ratio,
  warmup_value=warmup_value,
  freeze_ratio=freeze_ratio,
  )
+ self.param_group_field = param_group_field
  return

  @override
@@ -194,14 +193,7 @@ class CosineScheduler(_CosineSchedulerCore):


  class CosineParamScheduler(_CosineSchedulerCore):
- """
- CosineParamScheduler adjusts a parameter value using a cosine annealing scheduler.
-
- This class provides a mechanism to schedule the value of a parameter over a
- predefined number of iterations. It supports linear warm-up and optional freezing
- periods before the cosine annealing wave begins. The scheduler can be used to
- gradually transition a parameter value from a starting value to a final value.
- """
+ """Standalone cosine scheduler for non-optimizer parameters."""

  @override
  def step(self, it: int) -> float:
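The renames (total_iters → num_iters, base_value → start_value, warmup_iters_ratio → warmup_ratio) change the public constructor signatures. A hedged usage sketch, assuming the classes are importable from kostyl.ml.schedulers.cosine and that step() is called once per training iteration:

    import torch

    from kostyl.ml.schedulers.cosine import CosineParamScheduler
    from kostyl.ml.schedulers.cosine import CosineScheduler

    model = torch.nn.Linear(8, 2)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)

    lr_sched = CosineScheduler(
        optimizer=opt,
        param_group_field="lr",
        num_iters=1_000,      # was total_iters
        start_value=1e-3,     # was base_value
        final_value=1e-5,
        warmup_ratio=0.1,     # was warmup_iters_ratio
        warmup_value=1e-6,
    )

    # Standalone variant for a value that lives outside the optimizer, e.g. an EMA momentum.
    ema_sched = CosineParamScheduler(
        param_name="ema_momentum", num_iters=1_000, start_value=0.996, final_value=1.0
    )

    for it in range(1_000):
        lr_sched.step(it)              # writes the scheduled value into opt.param_groups
        momentum = ema_sched.step(it)  # returns the scheduled float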
@@ -0,0 +1,153 @@
+ from typing import Any
+ from typing import override
+
+ import numpy as np
+ import numpy.typing as npt
+ import torch
+
+ from .base import BaseScheduler
+
+
+ class _LinearScheduleBase(BaseScheduler):
+ def __init__(
+ self,
+ param_name: str,
+ num_iters: int,
+ start_value: float,
+ final_value: float,
+ ) -> None:
+ self.param_name = param_name
+ self.num_iters = num_iters
+ self.start_value = start_value
+ self.final_value = final_value
+
+ self.scheduler_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
+ self.current_value_ = self.start_value
+ return
+
+ def _create_scheduler(self) -> None:
+ self.scheduler_values = np.linspace(
+ self.start_value, self.final_value, num=self.num_iters, dtype=np.float64
+ )
+ if len(self.scheduler_values) != self.num_iters:
+ raise ValueError(
+ f"Scheduler length ({len(self.scheduler_values)}) does not match total_iters ({self.num_iters})."
+ )
+ return
+
+ @override
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+ super().load_state_dict(state_dict)
+ self.scheduler_values = np.array([], dtype=np.float64)
+ return
+
+ @override
+ def step(self, it: int) -> None | float:
+ raise NotImplementedError
+
+ def _get_value(self, it: int) -> float:
+ if len(self.scheduler_values) == 0:
+ self._create_scheduler()
+
+ if it >= self.num_iters:
+ value: float = self.final_value
+ else:
+ value: float = self.scheduler_values[it]
+ self.current_value_ = value
+ return value
+
+ @override
+ def current_value(self) -> dict[str, float]:
+ return {self.param_name: self.current_value_}
+
+
+ class LinearScheduler(_LinearScheduleBase):
+ """Implements a linear scheduler for adjusting parameter values in torch.optim.Optimizer."""
+
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ param_group_field: str,
+ num_iters: int,
+ start_value: float,
+ final_value: float,
+ multiplier_field: str | None = None,
+ skip_if_zero: bool = False,
+ apply_if_field: str | None = None,
+ ignore_if_field: str | None = None,
+ ) -> None:
+ """
+ Configure which optimizer groups get a linear value schedule.
+
+ Args:
+ optimizer: Optimizer whose param groups are updated in-place.
+ param_group_field: Name of the field that receives the scheduled value.
+ num_iters: Number of scheduler iterations before clamping at ``final_value``.
+ start_value: Value used on the first iteration.
+ final_value: Value used once ``num_iters`` iterations are consumed.
+ multiplier_field: Optional per-group multiplier applied to the scheduled value.
+ skip_if_zero: Leave groups untouched when their target field equals zero.
+ apply_if_field: Require this flag to be present in a param group before updating.
+ ignore_if_field: Skip groups that declare this flag.
+
+ """
+ self.apply_if_field = apply_if_field
+ self.ignore_if_field = ignore_if_field
+ self.optimizer = optimizer
+ self.multiplier_field = multiplier_field
+ self.skip_if_zero = skip_if_zero
+ super().__init__(
+ param_name=param_group_field,
+ num_iters=num_iters,
+ start_value=start_value,
+ final_value=final_value,
+ )
+ self.param_group_field = param_group_field
+ return
+
+ @override
+ def step(self, it: int) -> None:
+ value = self._get_value(it)
+ for pg in self.optimizer.param_groups:
+ if self.param_group_field not in pg:
+ raise ValueError(
+ f"Parameter group field '{self.param_group_field}' not found in optimizer parameter groups."
+ )
+
+ if (self.apply_if_field is not None) and (self.apply_if_field not in pg):
+ continue
+
+ if (self.ignore_if_field is not None) and (self.ignore_if_field in pg):
+ continue
+
+ if self.skip_if_zero and pg[self.param_group_field] == 0:
+ continue
+
+ if self.multiplier_field is not None:
+ if self.multiplier_field not in pg:
+ multiplier = 1.0
+ else:
+ multiplier = pg[self.multiplier_field]
+ pg[self.param_group_field] = value * multiplier
+ else:
+ pg[self.param_group_field] = value
+ return
+
+
+ class LinearParamScheduler(_LinearScheduleBase):
+ """LinearParamScheduler adjusts a parameter value using a linear scheduler."""
+
+ @override
+ def step(self, it: int) -> float:
+ """
+ Computes the value corresponding to the given iteration step.
+
+ Args:
+ it: The current iteration index used for value computation.
+
+ Returns:
+ The computed value for the provided iteration step as a float.
+
+ """
+ value = self._get_value(it)
+ return value
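The new module mirrors the cosine API with a plain linear ramp: LinearScheduler writes into optimizer param groups (honoring multiplier_field, skip_if_zero, apply_if_field, ignore_if_field), while LinearParamScheduler just returns the value. A hedged sketch under the same import-path assumption:

    import torch

    from kostyl.ml.schedulers.linear import LinearParamScheduler
    from kostyl.ml.schedulers.linear import LinearScheduler

    model = torch.nn.Linear(8, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=0.05)

    wd_sched = LinearScheduler(
        optimizer=opt,
        param_group_field="weight_decay",
        num_iters=500,
        start_value=0.05,
        final_value=0.5,
        skip_if_zero=True,  # param groups whose weight_decay is 0 stay untouched
    )

    temp_sched = LinearParamScheduler(
        param_name="temperature", num_iters=500, start_value=1.0, final_value=0.1
    )

    for it in range(500):
        wd_sched.step(it)                 # updates opt.param_groups in-place
        temperature = temp_sched.step(it)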
@@ -1,6 +1,6 @@
  [project]
  name = "kostyl-toolkit"
- version = "0.1.24"
+ version = "0.1.26"
  description = "Kickass Orchestration System for Training, Yielding & Logging "
  readme = "README.md"
  requires-python = ">=3.12"
@@ -1,193 +0,0 @@
- from abc import ABC
- from abc import abstractmethod
- from collections.abc import Callable
- from functools import partial
- from typing import Literal
- from typing import override
-
- from clearml import OutputModel
- from clearml import Task
- from lightning import Trainer
- from lightning.pytorch.callbacks import Callback
-
- from kostyl.ml.clearml.logging_utils import find_version_in_tags
- from kostyl.ml.clearml.logging_utils import increment_version
- from kostyl.ml.lightning import KostylLightningModule
- from kostyl.utils.logging import setup_logger
-
-
- logger = setup_logger()
-
-
- class RegistryUploaderCallback(Callback, ABC):
- """Abstract Lightning callback responsible for tracking and uploading the best-performing model checkpoint."""
-
- @property
- @abstractmethod
- def best_model_path(self) -> str:
- """Return the file system path pointing to the best model artifact produced during training."""
- raise NotImplementedError
-
- @best_model_path.setter
- @abstractmethod
- def best_model_path(self, value: str) -> None:
- raise NotImplementedError
-
- @abstractmethod
- def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
- raise NotImplementedError
-
-
- class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
- """PyTorch Lightning callback to upload the best model checkpoint to ClearML."""
-
- def __init__(
- self,
- task: Task,
- output_model_name: str,
- output_model_tags: list[str] | None = None,
- verbose: bool = True,
- enable_tag_versioning: bool = True,
- label_enumeration: dict[str, int] | None = None,
- config_dict: dict[str, str] | None = None,
- uploading_frequency: Literal[
- "after-every-eval", "on-train-end"
- ] = "on-train-end",
- ) -> None:
- """
- Initializes the ClearMLRegistryUploaderCallback.
-
- Args:
- task: ClearML task.
- ckpt_callback: ModelCheckpoint instance used by Trainer.
- output_model_name: Name for the ClearML output model.
- output_model_tags: Tags for the output model.
- verbose: Whether to log messages.
- label_enumeration: Optional mapping of label names to integer IDs.
- config_dict: Optional configuration dictionary to associate with the model.
- enable_tag_versioning: Whether to enable versioning in tags. If True,
- the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".
- uploading_frequency: When to upload:
- - "after-every-eval": after each validation phase.
- - "on-train-end": once at the end of training.
-
- """
- super().__init__()
- if output_model_tags is None:
- output_model_tags = []
-
- self.task = task
- self.output_model_name = output_model_name
- self.output_model_tags = output_model_tags
- self.config_dict = config_dict
- self.label_enumeration = label_enumeration
- self.verbose = verbose
- self.uploading_frequency = uploading_frequency
- self.enable_tag_versioning = enable_tag_versioning
-
- self._output_model: OutputModel | None = None
- self._last_uploaded_model_path: str = ""
- self._best_model_path: str = ""
- self._upload_callback: Callable | None = None
- return
-
- @property
- @override
- def best_model_path(self) -> str:
- return self._best_model_path
-
- @best_model_path.setter
- @override
- def best_model_path(self, value: str) -> None:
- self._best_model_path = value
- if self._upload_callback is not None:
- self._upload_callback()
- return
-
- def _create_output_model(self, pl_module: "KostylLightningModule") -> OutputModel:
- if self.enable_tag_versioning:
- version = find_version_in_tags(self.output_model_tags)
- if version is None:
- self.output_model_tags.append("v1.0")
- else:
- new_version = increment_version(version)
- self.output_model_tags.remove(version)
- self.output_model_tags.append(new_version)
-
- if "LightningCheckpoint" not in self.output_model_tags:
- self.output_model_tags.append("LightningCheckpoint")
-
- if self.config_dict is None:
- config = pl_module.model_config
- if config is not None:
- config = config.to_dict()
- else:
- config = self.config_dict
-
- return OutputModel(
- task=self.task,
- name=self.output_model_name,
- framework="PyTorch",
- tags=self.output_model_tags,
- config_dict=None,
- label_enumeration=self.label_enumeration,
- )
-
- @override
- def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
- if not self._best_model_path or (
- self._best_model_path == self._last_uploaded_model_path
- ):
- if not self._best_model_path:
- if self.verbose:
- logger.info("No best model found yet to upload")
- elif self._best_model_path == self._last_uploaded_model_path:
- if self.verbose:
- logger.info("Best model unchanged since last upload")
- self._upload_callback = partial(self._upload_best_checkpoint, pl_module)
- return
- self._upload_callback = None
-
- if self._output_model is None:
- self._output_model = self._create_output_model(pl_module)
-
- if self.verbose:
- logger.info(f"Uploading best model from {self._best_model_path}")
-
- self._output_model.update_weights(
- self._best_model_path,
- auto_delete_file=False,
- async_enable=False,
- )
- if self.config_dict is None:
- config = pl_module.model_config
- if config is not None:
- config = config.to_dict()
- else:
- config = self.config_dict
- self._output_model.update_design(config_dict=config)
-
- self._last_uploaded_model_path = self._best_model_path
- return
-
- @override
- def on_validation_end(
- self, trainer: Trainer, pl_module: "KostylLightningModule"
- ) -> None:
- if self.uploading_frequency != "after-every-eval":
- return
- if not trainer.is_global_zero:
- return
-
- self._upload_best_checkpoint(pl_module)
- return
-
- @override
- def on_train_end(
- self, trainer: Trainer, pl_module: "KostylLightningModule"
- ) -> None:
- if not trainer.is_global_zero:
- return
-
- self._upload_best_checkpoint(pl_module)
- return