kostyl-toolkit 0.1.35__tar.gz → 0.1.36__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/PKG-INFO +1 -1
  2. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/configs/hyperparams.py +21 -5
  3. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/configs/training_settings.py +17 -6
  4. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/lightning/callbacks/checkpoint.py +8 -8
  5. kostyl_toolkit-0.1.36/kostyl/ml/lightning/utils.py +58 -0
  6. kostyl_toolkit-0.1.36/kostyl/ml/registry_uploader.py +126 -0
  7. kostyl_toolkit-0.1.36/kostyl/ml/schedulers/__init__.py +18 -0
  8. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/schedulers/base.py +9 -7
  9. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/schedulers/cosine.py +53 -24
  10. kostyl_toolkit-0.1.36/kostyl/ml/schedulers/cosine_with_plateu.py +277 -0
  11. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/schedulers/linear.py +36 -11
  12. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/utils/logging.py +1 -1
  13. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/pyproject.toml +2 -2
  14. kostyl_toolkit-0.1.35/kostyl/ml/lightning/training_utils.py +0 -241
  15. kostyl_toolkit-0.1.35/kostyl/ml/registry_uploader.py +0 -99
  16. kostyl_toolkit-0.1.35/kostyl/ml/schedulers/__init__.py +0 -6
  17. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/README.md +0 -0
  18. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/__init__.py +0 -0
  19. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/__init__.py +0 -0
  20. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/clearml/__init__.py +0 -0
  21. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/clearml/dataset_utils.py +0 -0
  22. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/clearml/logging_utils.py +0 -0
  23. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/clearml/pulling_utils.py +0 -0
  24. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/configs/__init__.py +0 -0
  25. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/configs/base_model.py +0 -0
  26. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/data_processing_utils.py +0 -0
  27. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/dist_utils.py +0 -0
  28. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/lightning/__init__.py +0 -0
  29. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/lightning/callbacks/__init__.py +0 -0
  30. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
  31. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/lightning/extensions/__init__.py +0 -0
  32. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/lightning/extensions/custom_module.py +0 -0
  33. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/lightning/extensions/pretrained_model.py +0 -0
  34. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/lightning/loggers/__init__.py +0 -0
  35. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/lightning/loggers/tb_logger.py +0 -0
  36. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/metrics_formatting.py +0 -0
  37. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/params_groups.py +0 -0
  38. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/ml/schedulers/composite.py +0 -0
  39. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/utils/__init__.py +0 -0
  40. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/utils/dict_manipulations.py +0 -0
  41. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.36}/kostyl/utils/fs.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: kostyl-toolkit
- Version: 0.1.35
+ Version: 0.1.36
  Summary: Kickass Orchestration System for Training, Yielding & Logging
  Requires-Dist: case-converter>=1.2.0
  Requires-Dist: loguru>=0.7.3

kostyl/ml/configs/hyperparams.py
@@ -1,3 +1,5 @@
+ from typing import Literal
+
  from pydantic import BaseModel
  from pydantic import Field
  from pydantic import model_validator
@@ -8,11 +10,25 @@ from kostyl.utils.logging import setup_logger
  logger = setup_logger(fmt="only_message")


- class Optimizer(BaseModel):
-     """Optimizer hyperparameters configuration."""
+ class AdamConfig(BaseModel):
+     """AdamW optimizer hyperparameters configuration."""
+
+     type: Literal["AdamW"] = "AdamW"
+     betas: tuple[float, float] = (0.9, 0.999)
+     is_adamw: bool = True
+
+
+ class AdamWithPrecisionConfig(BaseModel):
+     """Adam optimizer with low-precision hyperparameters configuration."""
+
+     type: Literal["Adam8bit", "Adam4bit", "AdamFp8"]
+     betas: tuple[float, float] = (0.9, 0.999)
+     block_size: int
+     bf16_stochastic_round: bool = False
+     is_adamw: bool = True
+

-     adamw_beta1: float = 0.9
-     adamw_beta2: float = 0.999
+ Optimizer = AdamConfig | AdamWithPrecisionConfig


  class Lr(BaseModel):
@@ -73,6 +89,6 @@ class HyperparamsConfig(BaseModel):
      """Model training hyperparameters configuration."""

      grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
-     optimizer: Optimizer = Optimizer()
+     optimizer: Optimizer
      lr: Lr
      weight_decay: WeightDecay
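
Note: `Optimizer` is now a union of the two config models above, discriminated by the `type` literal, and `HyperparamsConfig` no longer provides a default for it. A minimal illustrative sketch, not part of the package diff (the import path and the `block_size` value are assumptions):

from kostyl.ml.configs.hyperparams import AdamConfig, AdamWithPrecisionConfig

plain_adamw = AdamConfig()  # type defaults to "AdamW", betas to (0.9, 0.999)
adam_8bit = AdamWithPrecisionConfig(
    type="Adam8bit",             # one of "Adam8bit", "Adam4bit", "AdamFp8"
    block_size=256,              # required field; 256 is an example value
    bf16_stochastic_round=True,
)
# HyperparamsConfig(optimizer=adam_8bit, lr=..., weight_decay=..., ...) must now
# receive the optimizer explicitly instead of falling back to a default.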

kostyl/ml/configs/training_settings.py
@@ -25,21 +25,31 @@ PRECISION = Literal[
      "16",
      "bf16",
  ]
+ DTYPE = Literal["float32", "float16", "bfloat16", "float64"]
+
+
+ class SingleDeviceStrategyConfig(BaseModel):
+     """Single device strategy configuration."""
+
+     type: Literal["single_device"]


  class FSDP1StrategyConfig(BaseModel):
      """Fully Sharded Data Parallel (FSDP) strategy configuration."""

      type: Literal["fsdp1"]
-     param_dtype: Literal["float32", "float16", "bfloat16"]
-     reduce_dtype: Literal["float32", "float16", "bfloat16"]
-     buffer_dtype: Literal["float32", "float16", "bfloat16"]
+     param_dtype: DTYPE | None
+     reduce_dtype: DTYPE | None
+     buffer_dtype: DTYPE | None


- class SingleDeviceStrategyConfig(BaseModel):
-     """Single device strategy configuration."""
+ class FSDP2StrategyConfig(BaseModel):
+     """Fully Sharded Data Parallel (FSDP) strategy configuration."""

-     type: Literal["single_device"]
+     type: Literal["fsdp2"]
+     param_dtype: DTYPE | None
+     reduce_dtype: DTYPE | None
+     buffer_dtype: DTYPE | None


  class DDPStrategyConfig(BaseModel):
@@ -82,6 +92,7 @@ class CheckpointConfig(BaseModel):
      monitor: str = "val_loss"
      mode: str = "min"
      filename: str = "{epoch:02d}-{val_loss:.2f}"
+     save_weights_only: bool = True


  class DataConfig(BaseModel):
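
Note: the strategy configs form a tagged union keyed on their `type` literal, and the FSDP mixed-precision dtypes are now optional and shared via the `DTYPE` alias. An illustrative sketch of the new `FSDP2StrategyConfig`, not from the package (the import path and dtype values are assumptions):

from kostyl.ml.configs.training_settings import FSDP2StrategyConfig

strategy = FSDP2StrategyConfig(
    type="fsdp2",
    param_dtype="bfloat16",  # DTYPE literal; "float64" is newly allowed
    reduce_dtype="float32",
    buffer_dtype=None,       # each dtype may now be left unset
)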

kostyl/ml/lightning/callbacks/checkpoint.py
@@ -299,9 +299,9 @@ class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
  def setup_checkpoint_callback(
      dirpath: Path,
      ckpt_cfg: CheckpointConfig,
-     save_weights_only: bool = True,
      registry_uploader_callback: RegistryUploaderCallback | None = None,
      uploading_strategy: Literal["only-best", "every-checkpoint"] | None = None,
+     remove_folder_if_exists: bool = True,
  ) -> ModelCheckpointWithRegistryUploader | ModelCheckpoint:
      """
      Create and configure a checkpoint callback for model saving.
@@ -313,14 +313,13 @@ def setup_checkpoint_callback(
      Args:
          dirpath: Path to the directory for saving checkpoints.
          ckpt_cfg: Checkpoint configuration (filename, monitor, mode, save_top_k).
-         save_weights_only: If True, only model weights are saved without optimizer and lr-scheduler state.
-             Defaults to True.
          registry_uploader_callback: Optional callback for uploading checkpoints to a remote registry.
              Must be specified together with uploading_strategy.
          uploading_strategy: Checkpoint upload mode:
              - "only-best": only the best checkpoint is uploaded
              - "every-checkpoint": every saved checkpoint is uploaded
              Must be specified together with registry_uploader_callback.
+         remove_folder_if_exists: If True, removes existing checkpoint directory before creating a new one.

      Returns:
          ModelCheckpointWithRegistryUploader if registry_uploader_callback is provided,
@@ -331,7 +330,7 @@ def setup_checkpoint_callback(

      Note:
          If the dirpath directory already exists, it will be removed and recreated
-         (only on the main process in distributed training).
+         (only on the main process in distributed training) if remove_folder_if_exists is True.

      """
      if (registry_uploader_callback is None) != (uploading_strategy is None):
@@ -342,8 +341,9 @@ def setup_checkpoint_callback(
      if dirpath.exists():
          if is_main_process():
              logger.warning(f"Checkpoint directory {dirpath} already exists.")
-             rmtree(dirpath)
-             logger.warning(f"Removed existing checkpoint directory {dirpath}.")
+             if remove_folder_if_exists:
+                 rmtree(dirpath)
+                 logger.warning(f"Removed existing checkpoint directory {dirpath}.")
      else:
          logger.info(f"Creating checkpoint directory {dirpath}.")
      dirpath.mkdir(parents=True, exist_ok=True)
@@ -356,7 +356,7 @@ def setup_checkpoint_callback(
              monitor=ckpt_cfg.monitor,
              mode=ckpt_cfg.mode,
              verbose=True,
-             save_weights_only=save_weights_only,
+             save_weights_only=ckpt_cfg.save_weights_only,
              registry_uploader_callback=registry_uploader_callback,
              uploading_mode=uploading_strategy,
          )
@@ -368,6 +368,6 @@ def setup_checkpoint_callback(
              monitor=ckpt_cfg.monitor,
              mode=ckpt_cfg.mode,
              verbose=True,
-             save_weights_only=save_weights_only,
+             save_weights_only=ckpt_cfg.save_weights_only,
          )
      return checkpoint_callback
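
Note: `save_weights_only` now travels with `CheckpointConfig` instead of the function signature, and wiping an existing checkpoint directory is opt-out via `remove_folder_if_exists`. An illustrative calling sketch, not from the package (import paths are assumptions; the remaining `CheckpointConfig` fields are assumed to keep their defaults):

from pathlib import Path

from kostyl.ml.configs.training_settings import CheckpointConfig
from kostyl.ml.lightning.callbacks.checkpoint import setup_checkpoint_callback

ckpt_cfg = CheckpointConfig(save_weights_only=False)  # configured on the config model now
callback = setup_checkpoint_callback(
    dirpath=Path("checkpoints/run-01"),
    ckpt_cfg=ckpt_cfg,
    remove_folder_if_exists=False,  # keep an existing directory instead of removing it
)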

kostyl/ml/lightning/utils.py (new file)
@@ -0,0 +1,58 @@
+ from typing import cast
+
+ import lightning as L
+ import torch.distributed as dist
+ from torch.distributed import ProcessGroup
+
+ from kostyl.ml.configs import DDPStrategyConfig
+ from kostyl.ml.configs import FSDP1StrategyConfig
+ from kostyl.ml.configs import SingleDeviceStrategyConfig
+ from kostyl.utils.logging import setup_logger
+
+
+ TRAINING_STRATEGIES = (
+     FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
+ )
+
+ logger = setup_logger(add_rank=True)
+
+
+ def estimate_total_steps(
+     trainer: L.Trainer, dp_process_group: ProcessGroup | None = None
+ ) -> int:
+     """
+     Estimates the total number of training steps with respect to data parallelism and gradient accumulation.
+
+     Args:
+         trainer: The PyTorch Lightning Trainer instance.
+         dp_process_group: The data parallel process group. If None, the world process group will be used.
+
+     """
+     if dist.is_initialized():
+         world_size = dist.get_world_size(dp_process_group)
+     else:
+         world_size = 1
+
+     datamodule = trainer.datamodule  # type: ignore
+     if datamodule is None:
+         raise ValueError("Trainer must have a datamodule to estimate total steps.")
+     datamodule = cast(L.LightningDataModule, datamodule)
+
+     logger.info("Loading `train_dataloader` to estimate number of stepping batches.")
+     datamodule.setup("fit")
+
+     dataloader_len = len(datamodule.train_dataloader())
+     steps_per_epoch = dataloader_len // trainer.accumulate_grad_batches // world_size
+
+     if trainer.max_epochs is None:
+         raise ValueError("Trainer must have `max_epochs` set to estimate total steps.")
+     total_steps = steps_per_epoch * trainer.max_epochs
+
+     logger.info(
+         f"Total steps: {total_steps} (per-epoch: {steps_per_epoch}) "
+         f"-> Dataloader len: {dataloader_len} "
+         f"-> Accumulate grad batches: {trainer.accumulate_grad_batches} "
+         f"-> Epochs: {trainer.max_epochs} "
+         f"-> DataParallel size: {world_size}"
+     )
+     return total_steps
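
Note: an illustrative sketch of how `estimate_total_steps` might be used to size a learning-rate schedule from inside a LightningModule; only the helper itself comes from the diff, the module, optimizer, and scheduler choices are assumptions:

import lightning as L
import torch

from kostyl.ml.lightning.utils import estimate_total_steps


class MyModule(L.LightningModule):  # hypothetical module
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4)
        # At this point the Trainer already has a datamodule attached, so the helper
        # can call datamodule.setup("fit") and measure len(train_dataloader()).
        total_steps = estimate_total_steps(self.trainer)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]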

kostyl/ml/registry_uploader.py (new file)
@@ -0,0 +1,126 @@
+ from abc import ABC
+ from abc import abstractmethod
+ from collections.abc import Callable
+ from functools import partial
+ from pathlib import Path
+ from typing import override
+
+ from clearml import OutputModel
+
+ from kostyl.utils.logging import setup_logger
+
+
+ logger = setup_logger()
+
+
+ class RegistryUploaderCallback(ABC):
+     """Abstract Lightning callback responsible for tracking and uploading the best-performing model checkpoint."""
+
+     @abstractmethod
+     def upload_checkpoint(self, path: str | Path) -> None:
+         """Upload the checkpoint located at the given path to the configured registry backend."""
+         raise NotImplementedError
+
+
+ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
+     """PyTorch Lightning callback to upload the best model checkpoint to ClearML."""
+
+     def __init__(
+         self,
+         model_name: str,
+         config_dict: dict[str, str] | None = None,
+         label_enumeration: dict[str, int] | None = None,
+         tags: list[str] | None = None,
+         comment: str | None = None,
+         framework: str | None = None,
+         base_model_id: str | None = None,
+         new_model_per_upload: bool = True,
+         verbose: bool = True,
+     ) -> None:
+         """
+         Initializes the ClearMLRegistryUploaderCallback.
+
+         Args:
+             model_name: The name for the newly created model.
+             label_enumeration: The label enumeration dictionary of string (label) to integer (value) pairs.
+             config_dict: Optional configuration dictionary to associate with the model.
+             tags: A list of strings which are tags for the model.
+             comment: A comment / description for the model.
+             framework: The framework of the model (e.g., "PyTorch", "TensorFlow").
+             base_model_id: Optional ClearML model ID to use as a base for the new model
+             new_model_per_upload: Whether to create a new ClearML model
+                 for every upload or update weights of the same model. When updating weights,
+                 the last uploaded checkpoint will be replaced (and deleted).
+             verbose: Whether to log messages during upload.
+
+         """
+         super().__init__()
+         if base_model_id is not None and new_model_per_upload:
+             raise ValueError(
+                 "Cannot set base_model_id when new_model_per_upload is True."
+             )
+
+         self.verbose = verbose
+         self.new_model_per_upload = new_model_per_upload
+         self.best_model_path: str = ""
+         self.config_dict = config_dict
+         self._output_model: OutputModel | None = None
+         self._last_uploaded_model_path: str = ""
+         self._upload_callback: Callable | None = None
+
+         self._validate_tags(tags)
+         self.model_fabric = partial(
+             OutputModel,
+             name=model_name,
+             label_enumeration=label_enumeration,
+             tags=tags,
+             comment=comment,
+             framework=framework,
+             base_model_id=base_model_id,
+         )
+         return
+
+     @staticmethod
+     def _validate_tags(tags: list[str] | None) -> None:
+         if tags is None:
+             return
+         if "LightningCheckpoint" not in tags:
+             tags.append("LightningCheckpoint")
+         return None
+
+     @property
+     def output_model_(self) -> OutputModel:
+         """Returns the OutputModel instance based on `new_model_per_upload` setting."""
+         if self.new_model_per_upload:
+             model = self.model_fabric()
+             self._output_model = self.model_fabric()
+         else:
+             if self._output_model is None:
+                 self._output_model = self.model_fabric()
+             model = self._output_model
+         return model
+
+     @override
+     def upload_checkpoint(
+         self,
+         path: str | Path,
+     ) -> None:
+         if isinstance(path, Path):
+             path = str(path)
+         if path == self._last_uploaded_model_path:
+             if self.verbose:
+                 logger.info("Model unchanged since last upload")
+             return
+
+         if self.verbose:
+             logger.info(f"Uploading model from {path}")
+
+         self.output_model_.update_weights(
+             path,
+             auto_delete_file=False,
+             async_enable=False,
+         )
+         self.output_model_.update_design(config_dict=self.config_dict)
+
+         self._last_uploaded_model_path = path
+         return
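
Note: an illustrative sketch wiring the uploader into `setup_checkpoint_callback` from checkpoint.py; names and values are examples, and `ckpt_cfg` is a `CheckpointConfig` as in the earlier sketch:

from pathlib import Path

from kostyl.ml.lightning.callbacks.checkpoint import setup_checkpoint_callback
from kostyl.ml.registry_uploader import ClearMLRegistryUploaderCallback

uploader = ClearMLRegistryUploaderCallback(
    model_name="my-model",           # example name
    tags=["baseline"],               # "LightningCheckpoint" is appended automatically
    new_model_per_upload=False,      # keep updating the weights of a single registry entry
)
callback = setup_checkpoint_callback(
    dirpath=Path("checkpoints/run-01"),
    ckpt_cfg=ckpt_cfg,
    registry_uploader_callback=uploader,
    uploading_strategy="only-best",  # must be set together with the uploader
)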

kostyl/ml/schedulers/__init__.py (new file)
@@ -0,0 +1,18 @@
+ from .composite import CompositeScheduler
+ from .cosine import CosineParamScheduler
+ from .cosine import CosineScheduler
+ from .cosine_with_plateu import CosineWithPlateauParamScheduler
+ from .cosine_with_plateu import CosineWithPlateuScheduler
+ from .linear import LinearParamScheduler
+ from .linear import LinearScheduler
+
+
+ __all__ = [
+     "CompositeScheduler",
+     "CosineParamScheduler",
+     "CosineScheduler",
+     "CosineWithPlateauParamScheduler",
+     "CosineWithPlateuScheduler",
+     "LinearParamScheduler",
+     "LinearScheduler",
+ ]

kostyl/ml/schedulers/base.py
@@ -6,18 +6,20 @@ from typing import Any
  class BaseScheduler(ABC):
      """Base class for learning rate schedulers."""

+     @abstractmethod
      def state_dict(self) -> dict[str, Any]:
          """Get the state as a state dictionary."""
-         return {
-             key: value
-             for key, value in self.__dict__.items()
-             if key not in ["optimizer", "scheduler_values"]
-         }
+         raise NotImplementedError

+     @abstractmethod
      def load_state_dict(self, state_dict: dict[str, Any]) -> None:
          """Load the state from a state dictionary."""
-         self.__dict__.update(state_dict)
-         return
+         raise NotImplementedError
+
+     @abstractmethod
+     def _verify(self) -> None:
+         """Verify the scheduler configuration."""
+         raise NotImplementedError


      def __getstate__(self) -> dict[str, Any]:

kostyl/ml/schedulers/cosine.py
@@ -2,7 +2,6 @@ from typing import Any
  from typing import override

  import numpy as np
- import numpy.typing as npt
  import torch

  from .base import BaseScheduler
@@ -29,18 +28,24 @@ class _CosineSchedulerCore(BaseScheduler):
          if freeze_ratio is not None:
              if not (0 < freeze_ratio < 1):
                  raise ValueError(f"Freeze ratio must be in (0, 1), got {freeze_ratio}.")
+         pre_annealing_ratio = (warmup_ratio if warmup_ratio is not None else 0) + (
+             freeze_ratio if freeze_ratio is not None else 0
+         )
+         if pre_annealing_ratio > 1:
+             raise ValueError(
+                 "The sum of warmup_ratio and freeze_ratio must <= 1, got "
+                 f"{pre_annealing_ratio}."
+             )

          self.param_name = param_name
          self.num_iters = num_iters
          self.base_value = base_value
          self.final_value = final_value
-
          self.warmup_ratio = warmup_ratio
          self.warmup_value = warmup_value
-
          self.freeze_ratio = freeze_ratio

-         self.scheduler_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
+         self.scheduled_values: np.ndarray = np.array([], dtype=np.float64)
          self.current_value_ = self.base_value
          return

@@ -63,31 +68,29 @@ class _CosineSchedulerCore(BaseScheduler):
              warmup_iters = 0
              warmup_schedule = np.array([], dtype=np.float64)

+         # Create cosine annealing schedule
          cosine_annealing_iters = self.num_iters - warmup_iters - freeze_iters
-         if cosine_annealing_iters <= 0:
-             raise ValueError("Cosine annealing iters must be > 0.")
-
-         # Create cosine schedule
-         iters = np.arange(cosine_annealing_iters)
-         schedule = self.final_value + 0.5 * (self.base_value - self.final_value) * (
-             1 + np.cos(np.pi * iters / len(iters))
-         )
+         if cosine_annealing_iters > 0:
+             iters = np.arange(cosine_annealing_iters)
+             cosine_annealing_schedule = self.final_value + 0.5 * (
+                 self.base_value - self.final_value
+             ) * (1 + np.cos(np.pi * iters / len(iters)))
+         else:
+             cosine_annealing_schedule = np.array([], dtype=np.float64)

          # Concatenate all parts of the schedule
-         self.scheduler_values = np.concatenate(
-             (freeze_schedule, warmup_schedule, schedule)
+         self.scheduled_values = np.concatenate(
+             (freeze_schedule, warmup_schedule, cosine_annealing_schedule)
          )
-
-         if len(self.scheduler_values) != self.num_iters:
-             raise ValueError(
-                 f"Scheduler length ({len(self.scheduler_values)}) does not match num_iters ({self.num_iters})."
-             )
+         self._verify()
          return

      @override
-     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
-         super().load_state_dict(state_dict)
-         self.scheduler_values = np.array([], dtype=np.float64)
+     def _verify(self) -> None:
+         if len(self.scheduled_values) != self.num_iters:
+             raise ValueError(
+                 f"Scheduler length ({len(self.scheduled_values)}) does not match num_iters ({self.num_iters})."
+             )
          return

      @override
@@ -95,13 +98,13 @@ class _CosineSchedulerCore(BaseScheduler):
          raise NotImplementedError

      def _get_value(self, it: int) -> float:
-         if len(self.scheduler_values) == 0:
+         if len(self.scheduled_values) == 0:
              self._create_scheduler()

          if it >= self.num_iters:
              value: float = self.final_value
          else:
-             value: float = self.scheduler_values[it]
+             value: float = self.scheduled_values[it]
          self.current_value_ = value
          return value

@@ -163,6 +166,21 @@ class CosineScheduler(_CosineSchedulerCore):
          self.param_group_field = param_group_field
          return

+     @override
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         self.__dict__.update(state_dict)
+         self.scheduled_values = np.array([], dtype=np.float64)
+         return
+
+     @override
+     def state_dict(self) -> dict[str, Any]:
+         state = {
+             k: v
+             for k, v in self.__dict__.items()
+             if k not in ["scheduled_values", "optimizer"]
+         }
+         return state
+
      @override
      def step(self, it: int) -> None:
          value = self._get_value(it)
@@ -209,3 +227,14 @@ class CosineParamScheduler(_CosineSchedulerCore):
          """
          value = self._get_value(it)
          return value
+
+     @override
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         self.__dict__.update(state_dict)
+         self.scheduled_values = np.array([], dtype=np.float64)
+         return
+
+     @override
+     def state_dict(self) -> dict[str, Any]:
+         state = {k: v for k, v in self.__dict__.items() if k != "scheduled_values"}
+         return state
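
Note: the concrete schedulers now own serialization: `state_dict` drops the `scheduled_values` array (and the optimizer, where present), and `load_state_dict` resets it so `_get_value` rebuilds the schedule lazily. An illustrative round-trip sketch (the constructor is not shown in this diff, so `make_scheduler()` stands in as a hypothetical factory):

sched = make_scheduler()         # hypothetical: builds a CosineScheduler bound to an optimizer
state = sched.state_dict()       # excludes "scheduled_values" and "optimizer"

restored = make_scheduler()
restored.load_state_dict(state)  # scheduled_values reset to an empty array
restored.step(it=100)            # _get_value() recreates the schedule on first use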