kostyl-toolkit 0.1.35__py3-none-any.whl → 0.1.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,5 @@
+ from typing import Literal
+
  from pydantic import BaseModel
  from pydantic import Field
  from pydantic import model_validator
@@ -8,11 +10,25 @@ from kostyl.utils.logging import setup_logger
  logger = setup_logger(fmt="only_message")


- class Optimizer(BaseModel):
-     """Optimizer hyperparameters configuration."""
+ class AdamConfig(BaseModel):
+     """AdamW optimizer hyperparameters configuration."""
+
+     type: Literal["AdamW"] = "AdamW"
+     betas: tuple[float, float] = (0.9, 0.999)
+     is_adamw: bool = True
+
+
+ class AdamWithPrecisionConfig(BaseModel):
+     """Adam optimizer with low-precision hyperparameters configuration."""
+
+     type: Literal["Adam8bit", "Adam4bit", "AdamFp8"]
+     betas: tuple[float, float] = (0.9, 0.999)
+     block_size: int
+     bf16_stochastic_round: bool = False
+     is_adamw: bool = True
+

-     adamw_beta1: float = 0.9
-     adamw_beta2: float = 0.999
+ Optimizer = AdamConfig | AdamWithPrecisionConfig


  class Lr(BaseModel):
@@ -73,6 +89,6 @@ class HyperparamsConfig(BaseModel):
      """Model training hyperparameters configuration."""

      grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
-     optimizer: Optimizer = Optimizer()
+     optimizer: Optimizer
      lr: Lr
      weight_decay: WeightDecay
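
In 0.1.37 the single Optimizer model becomes a union of AdamConfig and AdamWithPrecisionConfig selected by the `type` literal, and HyperparamsConfig no longer supplies a default optimizer. A minimal validation sketch with pydantic v2 follows; the import location of the config classes is an assumption and the input dict is illustrative.

    # Sketch only: assumes AdamConfig / AdamWithPrecisionConfig are importable from kostyl.ml.configs.
    from pydantic import TypeAdapter

    from kostyl.ml.configs import AdamConfig, AdamWithPrecisionConfig

    Optimizer = AdamConfig | AdamWithPrecisionConfig
    adapter = TypeAdapter(Optimizer)

    # The `type` literal picks the matching config class during validation.
    opt = adapter.validate_python(
        {"type": "Adam8bit", "betas": (0.9, 0.95), "block_size": 256}
    )
    assert isinstance(opt, AdamWithPrecisionConfig)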
@@ -25,21 +25,31 @@ PRECISION = Literal[
      "16",
      "bf16",
  ]
+ DTYPE = Literal["float32", "float16", "bfloat16", "float64"]
+
+
+ class SingleDeviceStrategyConfig(BaseModel):
+     """Single device strategy configuration."""
+
+     type: Literal["single_device"]


  class FSDP1StrategyConfig(BaseModel):
      """Fully Sharded Data Parallel (FSDP) strategy configuration."""

      type: Literal["fsdp1"]
-     param_dtype: Literal["float32", "float16", "bfloat16"]
-     reduce_dtype: Literal["float32", "float16", "bfloat16"]
-     buffer_dtype: Literal["float32", "float16", "bfloat16"]
+     param_dtype: DTYPE | None
+     reduce_dtype: DTYPE | None
+     buffer_dtype: DTYPE | None


- class SingleDeviceStrategyConfig(BaseModel):
-     """Single device strategy configuration."""
+ class FSDP2StrategyConfig(BaseModel):
+     """Fully Sharded Data Parallel (FSDP) strategy configuration."""

-     type: Literal["single_device"]
+     type: Literal["fsdp2"]
+     param_dtype: DTYPE | None
+     reduce_dtype: DTYPE | None
+     buffer_dtype: DTYPE | None


  class DDPStrategyConfig(BaseModel):
@@ -82,6 +92,7 @@ class CheckpointConfig(BaseModel):
      monitor: str = "val_loss"
      mode: str = "min"
      filename: str = "{epoch:02d}-{val_loss:.2f}"
+     save_weights_only: bool = True


  class DataConfig(BaseModel):
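
FSDP1StrategyConfig and the new FSDP2StrategyConfig now share the DTYPE alias and accept None for each dtype field. A self-contained sketch of dispatching on the `type` tag with a pydantic discriminated union is shown below; the StrategyConfig union name and the discriminator wiring are assumptions rather than package API.

    # Illustrative only; field names mirror the diff above.
    from typing import Annotated, Literal

    from pydantic import BaseModel, Field, TypeAdapter

    DTYPE = Literal["float32", "float16", "bfloat16", "float64"]


    class SingleDeviceStrategyConfig(BaseModel):
        type: Literal["single_device"]


    class FSDP2StrategyConfig(BaseModel):
        type: Literal["fsdp2"]
        param_dtype: DTYPE | None
        reduce_dtype: DTYPE | None
        buffer_dtype: DTYPE | None


    # Assumed union name; the package may expose a different one.
    StrategyConfig = Annotated[
        SingleDeviceStrategyConfig | FSDP2StrategyConfig,
        Field(discriminator="type"),
    ]

    cfg = TypeAdapter(StrategyConfig).validate_python(
        {
            "type": "fsdp2",
            "param_dtype": "bfloat16",
            "reduce_dtype": "float32",
            "buffer_dtype": None,
        }
    )
    assert isinstance(cfg, FSDP2StrategyConfig)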
kostyl/ml/dist_utils.py CHANGED
@@ -4,38 +4,61 @@ from typing import Literal

  import torch.distributed as dist

+ from kostyl.utils.logging import KostylLogger
  from kostyl.utils.logging import setup_logger


- logger = setup_logger(add_rank=True)
+ module_logger = setup_logger()


- def log_dist(msg: str, how: Literal["only-zero-rank", "world"]) -> None:
+ def log_dist(
+     msg: str,
+     logger: KostylLogger | None = None,
+     level: Literal["info", "warning", "error", "warning_once", "debug"] = "info",
+     log_scope: Literal["only-zero-rank", "world"] = "world",
+     group: dist.ProcessGroup | None = None,
+ ) -> None:
      """
      Log a message in a distributed environment based on the specified verbosity level.

      Args:
          msg (str): The message to log.
-         how (Literal["only-zero-rank", "world"]): The verbosity level for logging.
+         log_scope (Literal["only-zero-rank", "world"]): The verbosity level for logging.
              - "only-zero-rank": Log only from the main process (rank 0).
              - "world": Log from all processes in the distributed environment.
+         logger (KostylLogger | None): The logger instance to use. If None, the module logger is used.
+         level (Literal["info", "warning", "error", "warning_once", "debug"]): The logging level.
+         group (dist.ProcessGroup | None): Optional process group used to determine ranks. Defaults to the global process group.

      """
-     match how:
-         case _ if not dist.is_initialized():
-             logger.warning_once(
-                 "Distributed logging requested but torch.distributed is not initialized."
-             )
-             logger.info(msg)
+     if logger is None:
+         logger = module_logger
+
+     log_attr = getattr(logger, level, None)
+     if log_attr is None:
+         raise ValueError(f"Invalid logging level: {level}")
+
+     if not dist.is_initialized():
+         module_logger.warning_once(
+             "Distributed process group is not initialized; logging from all ranks."
+         )
+         log_attr(msg)
+         return
+
+     match log_scope:
          case "only-zero-rank":
-             if is_main_process():
-                 logger.info(msg)
+             if group is None:
+                 module_logger.debug(
+                     "No process group provided; assuming global group for rank check."
+                 )
+                 group = dist.group.WORLD
+             group_rank = dist.get_rank(group=group)
+             if dist.get_global_rank(group=group, group_rank=group_rank) == 0:  # pyright: ignore[reportArgumentType]
+                 log_attr(msg)
          case "world":
-             logger.info(msg)
+             log_attr(msg)
          case _:
-             logger.warning_once(
-                 f"Invalid logging verbosity level requested: {how}. Message not logged."
-             )
+             raise ValueError(f"Invalid logging verbosity level: {log_scope}")
      return

@@ -44,7 +67,7 @@ def scale_lrs_by_world_size(
      group: dist.ProcessGroup | None = None,
      config_name: str = "",
      inv_scale: bool = False,
-     verbose: Literal["only-zero-rank", "world"] | None = None,
+     verbose_level: Literal["only-zero-rank", "world"] | None = None,
  ) -> dict[str, float]:
      """
      Scale learning-rate configuration values to match the active distributed world size.
@@ -58,7 +81,7 @@ def scale_lrs_by_world_size(
              the target world size. Defaults to the global process group.
          config_name (str): Human-readable identifier included in log messages.
          inv_scale (bool): If True, use the inverse square-root scale factor.
-         verbose (Literal["only-zero-rank", "world"] | None): Verbosity level for logging scaled values.
+         verbose_level (Literal["only-zero-rank", "world"] | None): Verbosity level for logging scaled values.
              - "only-zero-rank": Log only from the main process (rank 0).
              - "world": Log from all processes in the distributed environment.
              - None: No logging.
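
Callers now pass verbose_level instead of verbose. A sketch of a call under the renamed keyword; the first positional argument is assumed to be the mapping of learning-rate names to values, and the numbers are illustrative:

    from kostyl.ml.dist_utils import scale_lrs_by_world_size

    scaled = scale_lrs_by_world_size(
        {"encoder": 3e-4, "head": 1e-3},  # assumed shape of the lr mapping
        config_name="pretrain",
        verbose_level="only-zero-rank",
    )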
@@ -77,31 +100,30 @@ def scale_lrs_by_world_size(
      for name, value in lrs.items():
          old_value = value
          new_value = value * scale
-         if verbose is not None:
+         if verbose_level is not None:
              log_dist(
                  f"New {config_name} lr {name.upper()}: {new_value}; OLD: {old_value}",
-                 verbose,
+                 log_scope=verbose_level,
+                 group=group,
              )
          lrs[name] = new_value
      return lrs


- def get_rank() -> int:
-     """Gets the rank of the current process in a distributed setting."""
-     if dist.is_initialized():
-         return dist.get_rank()
-     if "RANK" in os.environ:
-         return int(os.environ["RANK"])
-     if "SLURM_PROCID" in os.environ:
-         return int(os.environ["SLURM_PROCID"])
+ def get_local_rank(group: dist.ProcessGroup | None = None) -> int:
+     """Gets the local rank of the current process in a distributed setting."""
+     if dist.is_initialized() and group is not None:
+         return dist.get_rank(group=group)
+     if "SLURM_LOCALID" in os.environ:
+         return int(os.environ["SLURM_LOCALID"])
      if "LOCAL_RANK" in os.environ:
          return int(os.environ["LOCAL_RANK"])
      return 0


- def is_main_process() -> bool:
-     """Checks if the current process is the main process (rank 0) in a distributed setting."""
-     rank = get_rank()
+ def is_local_zero_rank() -> bool:
+     """Checks if the current process is the main process (rank 0) for the local node in a distributed setting."""
+     rank = get_local_rank()
      if rank != 0:
          return False
      return True
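
get_rank/is_main_process become get_local_rank/is_local_zero_rank, which answer a per-node question rather than a global one. A sketch of the typical guard for node-local work (the cache path is illustrative):

    from pathlib import Path

    from kostyl.ml.dist_utils import get_local_rank, is_local_zero_rank

    cache_dir = Path("/tmp/dataset-cache")  # hypothetical node-local directory
    if is_local_zero_rank():
        # Exactly one process per node prepares node-local resources.
        cache_dir.mkdir(parents=True, exist_ok=True)
    print(f"local rank: {get_local_rank()}")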
@@ -10,7 +10,7 @@ from lightning.fabric.utilities.types import _PATH
  from lightning.pytorch.callbacks import ModelCheckpoint

  from kostyl.ml.configs import CheckpointConfig
- from kostyl.ml.dist_utils import is_main_process
+ from kostyl.ml.dist_utils import is_local_zero_rank
  from kostyl.ml.lightning import KostylLightningModule
  from kostyl.ml.registry_uploader import RegistryUploaderCallback
  from kostyl.utils import setup_logger
@@ -299,9 +299,9 @@ class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
  def setup_checkpoint_callback(
      dirpath: Path,
      ckpt_cfg: CheckpointConfig,
-     save_weights_only: bool = True,
      registry_uploader_callback: RegistryUploaderCallback | None = None,
      uploading_strategy: Literal["only-best", "every-checkpoint"] | None = None,
+     remove_folder_if_exists: bool = True,
  ) -> ModelCheckpointWithRegistryUploader | ModelCheckpoint:
      """
      Create and configure a checkpoint callback for model saving.
@@ -313,14 +313,13 @@ def setup_checkpoint_callback(
      Args:
          dirpath: Path to the directory for saving checkpoints.
          ckpt_cfg: Checkpoint configuration (filename, monitor, mode, save_top_k).
-         save_weights_only: If True, only model weights are saved without optimizer and lr-scheduler state.
-             Defaults to True.
          registry_uploader_callback: Optional callback for uploading checkpoints to a remote registry.
              Must be specified together with uploading_strategy.
          uploading_strategy: Checkpoint upload mode:
              - "only-best": only the best checkpoint is uploaded
              - "every-checkpoint": every saved checkpoint is uploaded
              Must be specified together with registry_uploader_callback.
+         remove_folder_if_exists: If True, removes existing checkpoint directory before creating a new one.

      Returns:
          ModelCheckpointWithRegistryUploader if registry_uploader_callback is provided,
@@ -331,7 +330,7 @@ def setup_checkpoint_callback(

      Note:
          If the dirpath directory already exists, it will be removed and recreated
-         (only on the main process in distributed training).
+         (only on the main process in distributed training) if remove_folder_if_exists is True.

      """
      if (registry_uploader_callback is None) != (uploading_strategy is None):
@@ -340,10 +339,11 @@
          )

      if dirpath.exists():
-         if is_main_process():
+         if is_local_zero_rank():
              logger.warning(f"Checkpoint directory {dirpath} already exists.")
-             rmtree(dirpath)
-             logger.warning(f"Removed existing checkpoint directory {dirpath}.")
+             if remove_folder_if_exists:
+                 rmtree(dirpath)
+                 logger.warning(f"Removed existing checkpoint directory {dirpath}.")
      else:
          logger.info(f"Creating checkpoint directory {dirpath}.")
          dirpath.mkdir(parents=True, exist_ok=True)
@@ -356,7 +356,7 @@ def setup_checkpoint_callback(
              monitor=ckpt_cfg.monitor,
              mode=ckpt_cfg.mode,
              verbose=True,
-             save_weights_only=save_weights_only,
+             save_weights_only=ckpt_cfg.save_weights_only,
              registry_uploader_callback=registry_uploader_callback,
              uploading_mode=uploading_strategy,
          )
@@ -368,6 +368,6 @@ def setup_checkpoint_callback(
              monitor=ckpt_cfg.monitor,
              mode=ckpt_cfg.mode,
              verbose=True,
-             save_weights_only=save_weights_only,
+             save_weights_only=ckpt_cfg.save_weights_only,
          )
      return checkpoint_callback
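
save_weights_only is now read from CheckpointConfig, and the new remove_folder_if_exists flag decides whether an existing directory is wiped. A hedged call sketch; the import path of setup_checkpoint_callback is an assumption, and CheckpointConfig may require fields (e.g. save_top_k) not visible in this diff:

    from pathlib import Path

    from kostyl.ml.configs import CheckpointConfig
    from kostyl.ml.checkpointing import setup_checkpoint_callback  # assumed module path

    ckpt_cfg = CheckpointConfig(save_weights_only=False)  # keep optimizer/scheduler state too
    callback = setup_checkpoint_callback(
        dirpath=Path("checkpoints/run-01"),
        ckpt_cfg=ckpt_cfg,
        remove_folder_if_exists=False,  # reuse an existing directory instead of deleting it
    )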
@@ -26,11 +26,6 @@ module_logger = setup_logger(fmt="only_message")
  class KostylLightningModule(L.LightningModule):
      """Custom PyTorch Lightning Module with logging, checkpointing, and distributed training utilities."""

-     @property
-     def process_group(self) -> ProcessGroup | None:
-         """Returns the data parallel process group for distributed training."""
-         return self.get_process_group()
-
      def get_process_group(self) -> ProcessGroup | None:
          """
          Retrieves the data parallel process group for distributed training.
@@ -12,12 +12,12 @@ from kostyl.utils.logging import setup_logger
  logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")


- class LightningCheckpointLoaderMixin(PreTrainedModel):
+ class LightningCheckpointLoaderMixin:
      """A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""

      @classmethod
-     def from_lightning_checkpoint[TModelInstance: LightningCheckpointLoaderMixin](  # noqa: C901
-         cls: type[TModelInstance],
+     def from_lightning_checkpoint[TModelInstance: PreTrainedModel](  # noqa: C901
+         cls: type[TModelInstance],  # pyright: ignore[reportGeneralTypeIssues]
          checkpoint_path: str | Path,
          config_key: str = "config",
          weights_prefix: str | None = "model.",
@@ -78,7 +78,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
              mmap=True,
          )

-         # 1. Восстанавливаем конфиг
+         # Load config
          config_cls = cast(type[PretrainedConfig], cls.config_class)
          config_dict = checkpoint_dict[config_key]
          config_dict.update(kwargs)
@@ -91,6 +91,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):

          raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]

+         # Handle weights prefix
          if weights_prefix:
              if not weights_prefix.endswith("."):
                  weights_prefix = weights_prefix + "."
@@ -117,6 +118,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
          else:
              state_dict = raw_state_dict

+         # Instantiate model and load state dict
          model = cls.from_pretrained(
              pretrained_model_name_or_path=None,
              config=config,
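
Since LightningCheckpointLoaderMixin no longer inherits from PreTrainedModel, it is now meant to be mixed into a concrete Hugging Face model class. A minimal sketch, assuming the mixin's import path, with a toy config/model and an illustrative checkpoint path:

    import torch
    from torch import nn
    from transformers import PretrainedConfig, PreTrainedModel

    from kostyl.ml.hf_utils import LightningCheckpointLoaderMixin  # assumed import path


    class TinyConfig(PretrainedConfig):
        model_type = "tiny"

        def __init__(self, hidden_size: int = 8, **kwargs):
            super().__init__(**kwargs)
            self.hidden_size = hidden_size


    class TinyModel(LightningCheckpointLoaderMixin, PreTrainedModel):
        config_class = TinyConfig

        def __init__(self, config: TinyConfig):
            super().__init__(config)
            self.proj = nn.Linear(config.hidden_size, config.hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)


    model = TinyModel.from_lightning_checkpoint(
        "checkpoints/epoch=02-val_loss=0.41.ckpt",  # illustrative checkpoint path
        weights_prefix="model.",
    )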
@@ -3,7 +3,7 @@ from shutil import rmtree

  from lightning.pytorch.loggers import TensorBoardLogger

- from kostyl.ml.dist_utils import is_main_process
+ from kostyl.ml.dist_utils import is_local_zero_rank
  from kostyl.utils.logging import setup_logger


@@ -15,7 +15,7 @@ def setup_tb_logger(
  ) -> TensorBoardLogger:
      """Sets up a TensorBoardLogger for PyTorch Lightning."""
      if runs_dir.exists():
-         if is_main_process():
+         if is_local_zero_rank():
              logger.warning(f"TensorBoard log directory {runs_dir} already exists.")
              rmtree(runs_dir)
              logger.warning(f"Removed existing TensorBoard log directory {runs_dir}.")
@@ -0,0 +1,58 @@
+ from typing import cast
+
+ import lightning as L
+ import torch.distributed as dist
+ from torch.distributed import ProcessGroup
+
+ from kostyl.ml.configs import DDPStrategyConfig
+ from kostyl.ml.configs import FSDP1StrategyConfig
+ from kostyl.ml.configs import SingleDeviceStrategyConfig
+ from kostyl.utils.logging import setup_logger
+
+
+ TRAINING_STRATEGIES = (
+     FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
+ )
+
+ logger = setup_logger()
+
+
+ def estimate_total_steps(
+     trainer: L.Trainer, dp_process_group: ProcessGroup | None = None
+ ) -> int:
+     """
+     Estimates the total number of training steps with respect to data parallelism and gradient accumulation.
+
+     Args:
+         trainer: The PyTorch Lightning Trainer instance.
+         dp_process_group: The data parallel process group. If None, the world process group will be used.
+
+     """
+     if dist.is_initialized():
+         world_size = dist.get_world_size(dp_process_group)
+     else:
+         world_size = 1
+
+     datamodule = trainer.datamodule  # type: ignore
+     if datamodule is None:
+         raise ValueError("Trainer must have a datamodule to estimate total steps.")
+     datamodule = cast(L.LightningDataModule, datamodule)
+
+     logger.info("Loading `train_dataloader` to estimate number of stepping batches.")
+     datamodule.setup("fit")
+
+     dataloader_len = len(datamodule.train_dataloader())
+     steps_per_epoch = dataloader_len // trainer.accumulate_grad_batches // world_size
+
+     if trainer.max_epochs is None:
+         raise ValueError("Trainer must have `max_epochs` set to estimate total steps.")
+     total_steps = steps_per_epoch * trainer.max_epochs
+
+     logger.info(
+         f"Total steps: {total_steps} (per-epoch: {steps_per_epoch}) "
+         f"-> Dataloader len: {dataloader_len} "
+         f"-> Accumulate grad batches: {trainer.accumulate_grad_batches} "
+         f"-> Epochs: {trainer.max_epochs} "
+         f"-> DataParallel size: {world_size}"
+     )
+     return total_steps
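
estimate_total_steps is designed to be called once a datamodule is attached to the trainer, e.g. from configure_optimizers to size a step-based LR schedule. A hedged sketch; the module path of estimate_total_steps and the optimizer/scheduler wiring are assumptions:

    import lightning as L
    import torch

    from kostyl.ml.trainer_utils import estimate_total_steps  # assumed import path


    class MyModule(L.LightningModule):
        def __init__(self) -> None:
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).mean()

        def configure_optimizers(self):
            optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4)
            # Uses the datamodule attached to the trainer at fit time.
            total_steps = estimate_total_steps(self.trainer)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
            }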
@@ -1,13 +1,12 @@
  from abc import ABC
  from abc import abstractmethod
  from collections.abc import Callable
+ from functools import partial
  from pathlib import Path
  from typing import override

  from clearml import OutputModel

- from kostyl.ml.clearml.logging_utils import find_version_in_tags
- from kostyl.ml.clearml.logging_utils import increment_version
  from kostyl.utils.logging import setup_logger


@@ -28,51 +27,79 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):

      def __init__(
          self,
-         output_model: OutputModel,
+         model_name: str,
          config_dict: dict[str, str] | None = None,
+         label_enumeration: dict[str, int] | None = None,
+         tags: list[str] | None = None,
+         comment: str | None = None,
+         framework: str | None = None,
+         base_model_id: str | None = None,
+         new_model_per_upload: bool = True,
          verbose: bool = True,
-         enable_tag_versioning: bool = False,
      ) -> None:
          """
          Initializes the ClearMLRegistryUploaderCallback.

          Args:
-             output_model: ClearML OutputModel instance representing the model to upload.
-             verbose: Whether to log messages during upload.
+             model_name: The name for the newly created model.
+             label_enumeration: The label enumeration dictionary of string (label) to integer (value) pairs.
              config_dict: Optional configuration dictionary to associate with the model.
-             enable_tag_versioning: Whether to enable versioning in tags. If True,
-                 the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".
+             tags: A list of strings which are tags for the model.
+             comment: A comment / description for the model.
+             framework: The framework of the model (e.g., "PyTorch", "TensorFlow").
+             base_model_id: Optional ClearML model ID to use as a base for the new model
+             new_model_per_upload: Whether to create a new ClearML model
+                 for every upload or update weights of the same model. When updating weights,
+                 the last uploaded checkpoint will be replaced (and deleted).
+             verbose: Whether to log messages during upload.

          """
          super().__init__()
-         self.output_model = output_model
-         self.config_dict = config_dict
-         self.verbose = verbose
-         self.enable_tag_versioning = enable_tag_versioning
+         if base_model_id is not None and new_model_per_upload:
+             raise ValueError(
+                 "Cannot set base_model_id when new_model_per_upload is True."
+             )

+         self.verbose = verbose
+         self.new_model_per_upload = new_model_per_upload
          self.best_model_path: str = ""
-
+         self.config_dict = config_dict
+         self._output_model: OutputModel | None = None
          self._last_uploaded_model_path: str = ""
          self._upload_callback: Callable | None = None

-         self._validate_tags()
+         self._validate_tags(tags)
+         self.model_fabric = partial(
+             OutputModel,
+             name=model_name,
+             label_enumeration=label_enumeration,
+             tags=tags,
+             comment=comment,
+             framework=framework,
+             base_model_id=base_model_id,
+         )
          return

-     def _validate_tags(self) -> None:
-         output_model_tags = self.output_model.tags or []
-         if self.enable_tag_versioning:
-             version = find_version_in_tags(output_model_tags)
-             if version is None:
-                 output_model_tags.append("v1.0")
-             else:
-                 new_version = increment_version(version)
-                 output_model_tags.remove(version)
-                 output_model_tags.append(new_version)
-         if "LightningCheckpoint" not in output_model_tags:
-             output_model_tags.append("LightningCheckpoint")
-         self.output_model.tags = output_model_tags
+     @staticmethod
+     def _validate_tags(tags: list[str] | None) -> None:
+         if tags is None:
+             return
+         if "LightningCheckpoint" not in tags:
+             tags.append("LightningCheckpoint")
          return None

+     @property
+     def output_model_(self) -> OutputModel:
+         """Returns the OutputModel instance based on `new_model_per_upload` setting."""
+         if self.new_model_per_upload:
+             model = self.model_fabric()
+             self._output_model = self.model_fabric()
+         else:
+             if self._output_model is None:
+                 self._output_model = self.model_fabric()
+             model = self._output_model
+         return model
+
      @override
      def upload_checkpoint(
          self,
@@ -88,12 +115,12 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
          if self.verbose:
              logger.info(f"Uploading model from {path}")

-         self.output_model.update_weights(
+         self.output_model_.update_weights(
              path,
              auto_delete_file=False,
              async_enable=False,
          )
-         self.output_model.update_design(config_dict=self.config_dict)
+         self.output_model_.update_design(config_dict=self.config_dict)

          self._last_uploaded_model_path = path
          return
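
Instead of receiving a pre-built OutputModel, the callback now constructs OutputModel instances itself from these arguments. A hedged construction sketch; the import path and all values are illustrative:

    from kostyl.ml.clearml import ClearMLRegistryUploaderCallback  # assumed import path

    uploader = ClearMLRegistryUploaderCallback(
        model_name="my-encoder",
        tags=["baseline"],           # "LightningCheckpoint" is appended automatically
        framework="PyTorch",
        new_model_per_upload=False,  # keep updating the weights of one registry entry
        config_dict={"train_config": "configs/train.yaml"},
    )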
@@ -1,6 +1,18 @@
  from .composite import CompositeScheduler
  from .cosine import CosineParamScheduler
  from .cosine import CosineScheduler
+ from .cosine_with_plateu import CosineWithPlateauParamScheduler
+ from .cosine_with_plateu import CosineWithPlateuScheduler
+ from .linear import LinearParamScheduler
+ from .linear import LinearScheduler


- __all__ = ["CompositeScheduler", "CosineParamScheduler", "CosineScheduler"]
+ __all__ = [
+     "CompositeScheduler",
+     "CosineParamScheduler",
+     "CosineScheduler",
+     "CosineWithPlateauParamScheduler",
+     "CosineWithPlateuScheduler",
+     "LinearParamScheduler",
+     "LinearScheduler",
+ ]
@@ -6,18 +6,20 @@ from typing import Any
  class BaseScheduler(ABC):
      """Base class for learning rate schedulers."""

+     @abstractmethod
      def state_dict(self) -> dict[str, Any]:
          """Get the state as a state dictionary."""
-         return {
-             key: value
-             for key, value in self.__dict__.items()
-             if key not in ["optimizer", "scheduler_values"]
-         }
+         raise NotImplementedError

+     @abstractmethod
      def load_state_dict(self, state_dict: dict[str, Any]) -> None:
          """Load the state from a state dictionary."""
-         self.__dict__.update(state_dict)
-         return
+         raise NotImplementedError
+
+     @abstractmethod
+     def _verify(self) -> None:
+         """Verify the scheduler configuration."""
+         raise NotImplementedError

      def __getstate__(self) -> dict[str, Any]:
          """Get the state for pickling."""