kostyl-toolkit 0.1.35__tar.gz → 0.1.37__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/PKG-INFO +1 -1
  2. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/configs/hyperparams.py +21 -5
  3. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/configs/training_settings.py +17 -6
  4. kostyl_toolkit-0.1.37/kostyl/ml/dist_utils.py +129 -0
  5. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/callbacks/checkpoint.py +10 -10
  6. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/extensions/custom_module.py +0 -5
  7. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/extensions/pretrained_model.py +6 -4
  8. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/loggers/tb_logger.py +2 -2
  9. kostyl_toolkit-0.1.37/kostyl/ml/lightning/utils.py +58 -0
  10. kostyl_toolkit-0.1.37/kostyl/ml/registry_uploader.py +126 -0
  11. kostyl_toolkit-0.1.37/kostyl/ml/schedulers/__init__.py +18 -0
  12. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/schedulers/base.py +9 -7
  13. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/schedulers/cosine.py +53 -24
  14. kostyl_toolkit-0.1.37/kostyl/ml/schedulers/cosine_with_plateu.py +277 -0
  15. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/schedulers/linear.py +36 -11
  16. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/utils/logging.py +68 -53
  17. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/pyproject.toml +2 -2
  18. kostyl_toolkit-0.1.35/kostyl/ml/dist_utils.py +0 -107
  19. kostyl_toolkit-0.1.35/kostyl/ml/lightning/training_utils.py +0 -241
  20. kostyl_toolkit-0.1.35/kostyl/ml/registry_uploader.py +0 -99
  21. kostyl_toolkit-0.1.35/kostyl/ml/schedulers/__init__.py +0 -6
  22. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/README.md +0 -0
  23. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/__init__.py +0 -0
  24. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/__init__.py +0 -0
  25. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/clearml/__init__.py +0 -0
  26. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/clearml/dataset_utils.py +0 -0
  27. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/clearml/logging_utils.py +0 -0
  28. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/clearml/pulling_utils.py +0 -0
  29. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/configs/__init__.py +0 -0
  30. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/configs/base_model.py +0 -0
  31. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/data_processing_utils.py +0 -0
  32. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/__init__.py +0 -0
  33. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/callbacks/__init__.py +0 -0
  34. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
  35. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/extensions/__init__.py +0 -0
  36. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/loggers/__init__.py +0 -0
  37. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/metrics_formatting.py +0 -0
  38. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/params_groups.py +0 -0
  39. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/ml/schedulers/composite.py +0 -0
  40. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/utils/__init__.py +0 -0
  41. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/utils/dict_manipulations.py +0 -0
  42. {kostyl_toolkit-0.1.35 → kostyl_toolkit-0.1.37}/kostyl/utils/fs.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: kostyl-toolkit
- Version: 0.1.35
+ Version: 0.1.37
  Summary: Kickass Orchestration System for Training, Yielding & Logging
  Requires-Dist: case-converter>=1.2.0
  Requires-Dist: loguru>=0.7.3
@@ -1,3 +1,5 @@
+ from typing import Literal
+
  from pydantic import BaseModel
  from pydantic import Field
  from pydantic import model_validator
@@ -8,11 +10,25 @@ from kostyl.utils.logging import setup_logger
  logger = setup_logger(fmt="only_message")


- class Optimizer(BaseModel):
-     """Optimizer hyperparameters configuration."""
+ class AdamConfig(BaseModel):
+     """AdamW optimizer hyperparameters configuration."""
+
+     type: Literal["AdamW"] = "AdamW"
+     betas: tuple[float, float] = (0.9, 0.999)
+     is_adamw: bool = True
+
+
+ class AdamWithPrecisionConfig(BaseModel):
+     """Adam optimizer with low-precision hyperparameters configuration."""
+
+     type: Literal["Adam8bit", "Adam4bit", "AdamFp8"]
+     betas: tuple[float, float] = (0.9, 0.999)
+     block_size: int
+     bf16_stochastic_round: bool = False
+     is_adamw: bool = True
+

-     adamw_beta1: float = 0.9
-     adamw_beta2: float = 0.999
+ Optimizer = AdamConfig | AdamWithPrecisionConfig


  class Lr(BaseModel):
@@ -73,6 +89,6 @@ class HyperparamsConfig(BaseModel):
      """Model training hyperparameters configuration."""

      grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
-     optimizer: Optimizer = Optimizer()
+     optimizer: Optimizer
      lr: Lr
      weight_decay: WeightDecay
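
Usage note: the single Optimizer model is replaced by a union of AdamConfig and AdamWithPrecisionConfig, each carrying a type literal that disambiguates the variant. A minimal sketch of parsing that union with pydantic v2 (which the package's use of model_validator suggests); the classes are condensed from the hunk above and the input dict is illustrative:

    from typing import Literal

    from pydantic import BaseModel
    from pydantic import TypeAdapter


    class AdamConfig(BaseModel):
        type: Literal["AdamW"] = "AdamW"
        betas: tuple[float, float] = (0.9, 0.999)
        is_adamw: bool = True


    class AdamWithPrecisionConfig(BaseModel):
        type: Literal["Adam8bit", "Adam4bit", "AdamFp8"]
        betas: tuple[float, float] = (0.9, 0.999)
        block_size: int
        bf16_stochastic_round: bool = False
        is_adamw: bool = True


    Optimizer = AdamConfig | AdamWithPrecisionConfig

    # The type literal steers validation to the matching variant.
    opt = TypeAdapter(Optimizer).validate_python({"type": "Adam8bit", "block_size": 256})
    assert isinstance(opt, AdamWithPrecisionConfig)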
@@ -25,21 +25,31 @@ PRECISION = Literal[
      "16",
      "bf16",
  ]
+ DTYPE = Literal["float32", "float16", "bfloat16", "float64"]
+
+
+ class SingleDeviceStrategyConfig(BaseModel):
+     """Single device strategy configuration."""
+
+     type: Literal["single_device"]


  class FSDP1StrategyConfig(BaseModel):
      """Fully Sharded Data Parallel (FSDP) strategy configuration."""

      type: Literal["fsdp1"]
-     param_dtype: Literal["float32", "float16", "bfloat16"]
-     reduce_dtype: Literal["float32", "float16", "bfloat16"]
-     buffer_dtype: Literal["float32", "float16", "bfloat16"]
+     param_dtype: DTYPE | None
+     reduce_dtype: DTYPE | None
+     buffer_dtype: DTYPE | None


- class SingleDeviceStrategyConfig(BaseModel):
-     """Single device strategy configuration."""
+ class FSDP2StrategyConfig(BaseModel):
+     """Fully Sharded Data Parallel (FSDP) strategy configuration."""

-     type: Literal["single_device"]
+     type: Literal["fsdp2"]
+     param_dtype: DTYPE | None
+     reduce_dtype: DTYPE | None
+     buffer_dtype: DTYPE | None


  class DDPStrategyConfig(BaseModel):
@@ -82,6 +92,7 @@ class CheckpointConfig(BaseModel):
      monitor: str = "val_loss"
      mode: str = "min"
      filename: str = "{epoch:02d}-{val_loss:.2f}"
+     save_weights_only: bool = True


  class DataConfig(BaseModel):
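
Usage note: the strategy dtype fields now share the DTYPE alias and accept None, and an FSDP2 variant was added. A small sketch, assuming the module path matches the file listed above (kostyl/ml/configs/training_settings.py); the values are illustrative:

    from kostyl.ml.configs.training_settings import FSDP2StrategyConfig

    # None leaves the corresponding mixed-precision dtype unset.
    strategy = FSDP2StrategyConfig(
        type="fsdp2",
        param_dtype="bfloat16",
        reduce_dtype="float32",
        buffer_dtype=None,
    )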
@@ -0,0 +1,129 @@
+ import math
+ import os
+ from typing import Literal
+
+ import torch.distributed as dist
+
+ from kostyl.utils.logging import KostylLogger
+ from kostyl.utils.logging import setup_logger
+
+
+ module_logger = setup_logger()
+
+
+ def log_dist(
+     msg: str,
+     logger: KostylLogger | None = None,
+     level: Literal["info", "warning", "error", "warning_once", "debug"] = "info",
+     log_scope: Literal["only-zero-rank", "world"] = "world",
+     group: dist.ProcessGroup | None = None,
+ ) -> None:
+     """
+     Log a message in a distributed environment based on the specified verbosity level.
+
+     Args:
+         msg (str): The message to log.
+         log_scope (Literal["only-zero-rank", "world"]): The verbosity level for logging.
+             - "only-zero-rank": Log only from the main process (rank 0).
+             - "world": Log from all processes in the distributed environment.
+         logger (KostylLogger | None): The logger instance to use. If None, the module logger is used.
+         level (Literal["info", "warning", "error", "warning_once", "debug"]): The logging level.
+         group (dist.ProcessGroup | None): Optional process group used to determine ranks. Defaults to the global process group.
+
+     """
+     if logger is None:
+         logger = module_logger
+
+     log_attr = getattr(logger, level, None)
+     if log_attr is None:
+         raise ValueError(f"Invalid logging level: {level}")
+
+     if not dist.is_initialized():
+         module_logger.warning_once(
+             "Distributed process group is not initialized; logging from all ranks."
+         )
+         log_attr(msg)
+         return
+
+     match log_scope:
+         case "only-zero-rank":
+             if group is None:
+                 module_logger.debug(
+                     "No process group provided; assuming global group for rank check."
+                 )
+                 group = dist.group.WORLD
+             group_rank = dist.get_rank(group=group)
+             if dist.get_global_rank(group=group, group_rank=group_rank) == 0: # pyright: ignore[reportArgumentType]
+                 log_attr(msg)
+         case "world":
+             log_attr(msg)
+         case _:
+             raise ValueError(f"Invalid logging verbosity level: {log_scope}")
+     return
+
+
+ def scale_lrs_by_world_size(
+     lrs: dict[str, float],
+     group: dist.ProcessGroup | None = None,
+     config_name: str = "",
+     inv_scale: bool = False,
+     verbose_level: Literal["only-zero-rank", "world"] | None = None,
+ ) -> dict[str, float]:
+     """
+     Scale learning-rate configuration values to match the active distributed world size.
+
+     Note:
+         The value in the `lrs` will be modified in place.
+
+     Args:
+         lrs (dict[str, float]): A dictionary of learning rate names and their corresponding values to be scaled.
+         group (dist.ProcessGroup | None): Optional process group used to determine
+             the target world size. Defaults to the global process group.
+         config_name (str): Human-readable identifier included in log messages.
+         inv_scale (bool): If True, use the inverse square-root scale factor.
+         verbose_level (Literal["only-zero-rank", "world"] | None): Verbosity level for logging scaled values.
+             - "only-zero-rank": Log only from the main process (rank 0).
+             - "world": Log from all processes in the distributed environment.
+             - None: No logging.
+
+     Returns:
+         dict[str, float]: The learning-rate configuration with scaled values.
+
+     """
+     world_size = dist.get_world_size(group=group)
+
+     if inv_scale:
+         scale = 1 / math.sqrt(world_size)
+     else:
+         scale = math.sqrt(world_size)
+
+     for name, value in lrs.items():
+         old_value = value
+         new_value = value * scale
+         if verbose_level is not None:
+             log_dist(
+                 f"New {config_name} lr {name.upper()}: {new_value}; OLD: {old_value}",
+                 log_scope=verbose_level,
+                 group=group,
+             )
+         lrs[name] = new_value
+     return lrs
+
+
+ def get_local_rank(group: dist.ProcessGroup | None = None) -> int:
+     """Gets the local rank of the current process in a distributed setting."""
+     if dist.is_initialized() and group is not None:
+         return dist.get_rank(group=group)
+     if "SLURM_LOCALID" in os.environ:
+         return int(os.environ["SLURM_LOCALID"])
+     if "LOCAL_RANK" in os.environ:
+         return int(os.environ["LOCAL_RANK"])
+     return 0
+
+
+ def is_local_zero_rank() -> bool:
+     """Checks if the current process is the main process (rank 0) for the local node in a distributed setting."""
+     rank = get_local_rank()
+     if rank != 0:
+         return False
+     return True
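
Usage note: a short sketch of the new dist_utils helpers. It assumes torch.distributed has already been initialized elsewhere (for example by Lightning), and the learning-rate names and values are illustrative:

    import torch.distributed as dist

    from kostyl.ml.dist_utils import is_local_zero_rank
    from kostyl.ml.dist_utils import log_dist
    from kostyl.ml.dist_utils import scale_lrs_by_world_size

    if dist.is_initialized():
        # sqrt(world_size) scaling; pass inv_scale=True for the inverse factor
        lrs = scale_lrs_by_world_size(
            {"encoder": 1e-4, "head": 3e-4},
            config_name="backbone",
            verbose_level="only-zero-rank",
        )
        log_dist(f"Scaled lrs: {lrs}", log_scope="only-zero-rank")

    if is_local_zero_rank():
        print("running on local rank 0")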
@@ -10,7 +10,7 @@ from lightning.fabric.utilities.types import _PATH
  from lightning.pytorch.callbacks import ModelCheckpoint

  from kostyl.ml.configs import CheckpointConfig
- from kostyl.ml.dist_utils import is_main_process
+ from kostyl.ml.dist_utils import is_local_zero_rank
  from kostyl.ml.lightning import KostylLightningModule
  from kostyl.ml.registry_uploader import RegistryUploaderCallback
  from kostyl.utils import setup_logger
@@ -299,9 +299,9 @@ class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
  def setup_checkpoint_callback(
      dirpath: Path,
      ckpt_cfg: CheckpointConfig,
-     save_weights_only: bool = True,
      registry_uploader_callback: RegistryUploaderCallback | None = None,
      uploading_strategy: Literal["only-best", "every-checkpoint"] | None = None,
+     remove_folder_if_exists: bool = True,
  ) -> ModelCheckpointWithRegistryUploader | ModelCheckpoint:
      """
      Create and configure a checkpoint callback for model saving.
@@ -313,14 +313,13 @@ def setup_checkpoint_callback(
      Args:
          dirpath: Path to the directory for saving checkpoints.
          ckpt_cfg: Checkpoint configuration (filename, monitor, mode, save_top_k).
-         save_weights_only: If True, only model weights are saved without optimizer and lr-scheduler state.
-             Defaults to True.
          registry_uploader_callback: Optional callback for uploading checkpoints to a remote registry.
              Must be specified together with uploading_strategy.
          uploading_strategy: Checkpoint upload mode:
              - "only-best": only the best checkpoint is uploaded
              - "every-checkpoint": every saved checkpoint is uploaded
              Must be specified together with registry_uploader_callback.
+         remove_folder_if_exists: If True, removes existing checkpoint directory before creating a new one.

      Returns:
          ModelCheckpointWithRegistryUploader if registry_uploader_callback is provided,
@@ -331,7 +330,7 @@

      Note:
          If the dirpath directory already exists, it will be removed and recreated
-         (only on the main process in distributed training).
+         (only on the main process in distributed training) if remove_folder_if_exists is True.

      """
      if (registry_uploader_callback is None) != (uploading_strategy is None):
@@ -340,10 +339,11 @@
          )

      if dirpath.exists():
-         if is_main_process():
+         if is_local_zero_rank():
              logger.warning(f"Checkpoint directory {dirpath} already exists.")
-             rmtree(dirpath)
-             logger.warning(f"Removed existing checkpoint directory {dirpath}.")
+             if remove_folder_if_exists:
+                 rmtree(dirpath)
+                 logger.warning(f"Removed existing checkpoint directory {dirpath}.")
      else:
          logger.info(f"Creating checkpoint directory {dirpath}.")
          dirpath.mkdir(parents=True, exist_ok=True)
@@ -356,7 +356,7 @@
              monitor=ckpt_cfg.monitor,
              mode=ckpt_cfg.mode,
              verbose=True,
-             save_weights_only=save_weights_only,
+             save_weights_only=ckpt_cfg.save_weights_only,
              registry_uploader_callback=registry_uploader_callback,
              uploading_mode=uploading_strategy,
          )
@@ -368,6 +368,6 @@
              monitor=ckpt_cfg.monitor,
              mode=ckpt_cfg.mode,
              verbose=True,
-             save_weights_only=save_weights_only,
+             save_weights_only=ckpt_cfg.save_weights_only,
          )
      return checkpoint_callback
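
Usage note: save_weights_only now comes from CheckpointConfig, and remove_folder_if_exists is a new keyword. A sketch of the updated call site, assuming the remaining CheckpointConfig fields have defaults and using an illustrative directory:

    from pathlib import Path

    from kostyl.ml.configs import CheckpointConfig
    from kostyl.ml.lightning.callbacks.checkpoint import setup_checkpoint_callback

    ckpt_cfg = CheckpointConfig(save_weights_only=False)  # keep optimizer/scheduler state
    callback = setup_checkpoint_callback(
        dirpath=Path("checkpoints/run-001"),
        ckpt_cfg=ckpt_cfg,
        remove_folder_if_exists=False,  # warn about, but keep, an existing directory
    )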
@@ -26,11 +26,6 @@ module_logger = setup_logger(fmt="only_message")
  class KostylLightningModule(L.LightningModule):
      """Custom PyTorch Lightning Module with logging, checkpointing, and distributed training utilities."""

-     @property
-     def process_group(self) -> ProcessGroup | None:
-         """Returns the data parallel process group for distributed training."""
-         return self.get_process_group()
-
      def get_process_group(self) -> ProcessGroup | None:
          """
          Retrieves the data parallel process group for distributed training.
@@ -12,12 +12,12 @@ from kostyl.utils.logging import setup_logger
  logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")


- class LightningCheckpointLoaderMixin(PreTrainedModel):
+ class LightningCheckpointLoaderMixin:
      """A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""

      @classmethod
-     def from_lightning_checkpoint[TModelInstance: LightningCheckpointLoaderMixin]( # noqa: C901
-         cls: type[TModelInstance],
+     def from_lightning_checkpoint[TModelInstance: PreTrainedModel]( # noqa: C901
+         cls: type[TModelInstance], # pyright: ignore[reportGeneralTypeIssues]
          checkpoint_path: str | Path,
          config_key: str = "config",
          weights_prefix: str | None = "model.",
@@ -78,7 +78,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
              mmap=True,
          )

-         # 1. Восстанавливаем конфиг
+         # Load config
          config_cls = cast(type[PretrainedConfig], cls.config_class)
          config_dict = checkpoint_dict[config_key]
          config_dict.update(kwargs)
@@ -91,6 +91,7 @@

          raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]

+         # Handle weights prefix
          if weights_prefix:
              if not weights_prefix.endswith("."):
                  weights_prefix = weights_prefix + "."
@@ -117,6 +118,7 @@
          else:
              state_dict = raw_state_dict

+         # Instantiate model and load state dict
          model = cls.from_pretrained(
              pretrained_model_name_or_path=None,
              config=config,
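
Usage note: the mixin no longer inherits from PreTrainedModel, so a concrete model now combines both bases explicitly. MyConfig, MyModel, and the checkpoint path below are hypothetical stand-ins, not part of the package:

    from pathlib import Path

    from transformers import PretrainedConfig
    from transformers import PreTrainedModel

    from kostyl.ml.lightning.extensions.pretrained_model import LightningCheckpointLoaderMixin


    class MyConfig(PretrainedConfig):
        model_type = "my-model"


    class MyModel(LightningCheckpointLoaderMixin, PreTrainedModel):
        config_class = MyConfig


    # Loads the config stored under "config" and weights prefixed with "model."
    model = MyModel.from_lightning_checkpoint(
        Path("checkpoints/best.ckpt"),
        weights_prefix="model.",
    )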
@@ -3,7 +3,7 @@ from shutil import rmtree

  from lightning.pytorch.loggers import TensorBoardLogger

- from kostyl.ml.dist_utils import is_main_process
+ from kostyl.ml.dist_utils import is_local_zero_rank
  from kostyl.utils.logging import setup_logger


@@ -15,7 +15,7 @@ def setup_tb_logger(
  ) -> TensorBoardLogger:
      """Sets up a TensorBoardLogger for PyTorch Lightning."""
      if runs_dir.exists():
-         if is_main_process():
+         if is_local_zero_rank():
              logger.warning(f"TensorBoard log directory {runs_dir} already exists.")
              rmtree(runs_dir)
              logger.warning(f"Removed existing TensorBoard log directory {runs_dir}.")
@@ -0,0 +1,58 @@
+ from typing import cast
+
+ import lightning as L
+ import torch.distributed as dist
+ from torch.distributed import ProcessGroup
+
+ from kostyl.ml.configs import DDPStrategyConfig
+ from kostyl.ml.configs import FSDP1StrategyConfig
+ from kostyl.ml.configs import SingleDeviceStrategyConfig
+ from kostyl.utils.logging import setup_logger
+
+
+ TRAINING_STRATEGIES = (
+     FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
+ )
+
+ logger = setup_logger()
+
+
+ def estimate_total_steps(
+     trainer: L.Trainer, dp_process_group: ProcessGroup | None = None
+ ) -> int:
+     """
+     Estimates the total number of training steps with respect to data parallelism and gradient accumulation.
+
+     Args:
+         trainer: The PyTorch Lightning Trainer instance.
+         dp_process_group: The data parallel process group. If None, the world process group will be used.
+
+     """
+     if dist.is_initialized():
+         world_size = dist.get_world_size(dp_process_group)
+     else:
+         world_size = 1
+
+     datamodule = trainer.datamodule # type: ignore
+     if datamodule is None:
+         raise ValueError("Trainer must have a datamodule to estimate total steps.")
+     datamodule = cast(L.LightningDataModule, datamodule)
+
+     logger.info("Loading `train_dataloader` to estimate number of stepping batches.")
+     datamodule.setup("fit")
+
+     dataloader_len = len(datamodule.train_dataloader())
+     steps_per_epoch = dataloader_len // trainer.accumulate_grad_batches // world_size
+
+     if trainer.max_epochs is None:
+         raise ValueError("Trainer must have `max_epochs` set to estimate total steps.")
+     total_steps = steps_per_epoch * trainer.max_epochs
+
+     logger.info(
+         f"Total steps: {total_steps} (per-epoch: {steps_per_epoch}) "
+         f"-> Dataloader len: {dataloader_len} "
+         f"-> Accumulate grad batches: {trainer.accumulate_grad_batches} "
+         f"-> Epochs: {trainer.max_epochs} "
+         f"-> DataParallel size: {world_size}"
+     )
+     return total_steps
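
Usage note: a sketch of calling estimate_total_steps from a LightningModule hook to size an LR schedule; the module class and attribute names are illustrative:

    import lightning as L

    from kostyl.ml.lightning.utils import estimate_total_steps


    class MyLitModule(L.LightningModule):
        def setup(self, stage: str) -> None:
            if stage == "fit":
                # accounts for world size and accumulate_grad_batches
                self.total_steps = estimate_total_steps(self.trainer)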
@@ -0,0 +1,126 @@
+ from abc import ABC
+ from abc import abstractmethod
+ from collections.abc import Callable
+ from functools import partial
+ from pathlib import Path
+ from typing import override
+
+ from clearml import OutputModel
+
+ from kostyl.utils.logging import setup_logger
+
+
+ logger = setup_logger()
+
+
+ class RegistryUploaderCallback(ABC):
+     """Abstract Lightning callback responsible for tracking and uploading the best-performing model checkpoint."""
+
+     @abstractmethod
+     def upload_checkpoint(self, path: str | Path) -> None:
+         """Upload the checkpoint located at the given path to the configured registry backend."""
+         raise NotImplementedError
+
+
+ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
+     """PyTorch Lightning callback to upload the best model checkpoint to ClearML."""
+
+     def __init__(
+         self,
+         model_name: str,
+         config_dict: dict[str, str] | None = None,
+         label_enumeration: dict[str, int] | None = None,
+         tags: list[str] | None = None,
+         comment: str | None = None,
+         framework: str | None = None,
+         base_model_id: str | None = None,
+         new_model_per_upload: bool = True,
+         verbose: bool = True,
+     ) -> None:
+         """
+         Initializes the ClearMLRegistryUploaderCallback.
+
+         Args:
+             model_name: The name for the newly created model.
+             label_enumeration: The label enumeration dictionary of string (label) to integer (value) pairs.
+             config_dict: Optional configuration dictionary to associate with the model.
+             tags: A list of strings which are tags for the model.
+             comment: A comment / description for the model.
+             framework: The framework of the model (e.g., "PyTorch", "TensorFlow").
+             base_model_id: Optional ClearML model ID to use as a base for the new model
+             new_model_per_upload: Whether to create a new ClearML model
+                 for every upload or update weights of the same model. When updating weights,
+                 the last uploaded checkpoint will be replaced (and deleted).
+             verbose: Whether to log messages during upload.
+
+         """
+         super().__init__()
+         if base_model_id is not None and new_model_per_upload:
+             raise ValueError(
+                 "Cannot set base_model_id when new_model_per_upload is True."
+             )
+
+         self.verbose = verbose
+         self.new_model_per_upload = new_model_per_upload
+         self.best_model_path: str = ""
+         self.config_dict = config_dict
+         self._output_model: OutputModel | None = None
+         self._last_uploaded_model_path: str = ""
+         self._upload_callback: Callable | None = None
+
+         self._validate_tags(tags)
+         self.model_fabric = partial(
+             OutputModel,
+             name=model_name,
+             label_enumeration=label_enumeration,
+             tags=tags,
+             comment=comment,
+             framework=framework,
+             base_model_id=base_model_id,
+         )
+         return
+
+     @staticmethod
+     def _validate_tags(tags: list[str] | None) -> None:
+         if tags is None:
+             return
+         if "LightningCheckpoint" not in tags:
+             tags.append("LightningCheckpoint")
+         return None
+
+     @property
+     def output_model_(self) -> OutputModel:
+         """Returns the OutputModel instance based on `new_model_per_upload` setting."""
+         if self.new_model_per_upload:
+             model = self.model_fabric()
+             self._output_model = self.model_fabric()
+         else:
+             if self._output_model is None:
+                 self._output_model = self.model_fabric()
+             model = self._output_model
+         return model
+
+     @override
+     def upload_checkpoint(
+         self,
+         path: str | Path,
+     ) -> None:
+         if isinstance(path, Path):
+             path = str(path)
+         if path == self._last_uploaded_model_path:
+             if self.verbose:
+                 logger.info("Model unchanged since last upload")
+             return
+
+         if self.verbose:
+             logger.info(f"Uploading model from {path}")
+
+         self.output_model_.update_weights(
+             path,
+             auto_delete_file=False,
+             async_enable=False,
+         )
+         self.output_model_.update_design(config_dict=self.config_dict)
+
+         self._last_uploaded_model_path = path
+         return
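
Usage note: wiring the ClearML uploader into the checkpoint callback from kostyl/ml/lightning/callbacks/checkpoint.py. This sketch assumes a configured ClearML environment and that CheckpointConfig is constructible with defaults; names and paths are illustrative:

    from pathlib import Path

    from kostyl.ml.configs import CheckpointConfig
    from kostyl.ml.lightning.callbacks.checkpoint import setup_checkpoint_callback
    from kostyl.ml.registry_uploader import ClearMLRegistryUploaderCallback

    uploader = ClearMLRegistryUploaderCallback(
        model_name="my-model",
        tags=["baseline"],  # "LightningCheckpoint" is appended automatically
        new_model_per_upload=False,  # reuse one ClearML model, replacing its weights
    )
    checkpoint_cb = setup_checkpoint_callback(
        dirpath=Path("checkpoints/run-001"),
        ckpt_cfg=CheckpointConfig(),
        registry_uploader_callback=uploader,
        uploading_strategy="only-best",
    )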
@@ -0,0 +1,18 @@
+ from .composite import CompositeScheduler
+ from .cosine import CosineParamScheduler
+ from .cosine import CosineScheduler
+ from .cosine_with_plateu import CosineWithPlateauParamScheduler
+ from .cosine_with_plateu import CosineWithPlateuScheduler
+ from .linear import LinearParamScheduler
+ from .linear import LinearScheduler
+
+
+ __all__ = [
+     "CompositeScheduler",
+     "CosineParamScheduler",
+     "CosineScheduler",
+     "CosineWithPlateauParamScheduler",
+     "CosineWithPlateuScheduler",
+     "LinearParamScheduler",
+     "LinearScheduler",
+ ]
@@ -6,18 +6,20 @@ from typing import Any
  class BaseScheduler(ABC):
      """Base class for learning rate schedulers."""

+     @abstractmethod
      def state_dict(self) -> dict[str, Any]:
          """Get the state as a state dictionary."""
-         return {
-             key: value
-             for key, value in self.__dict__.items()
-             if key not in ["optimizer", "scheduler_values"]
-         }
+         raise NotImplementedError

+     @abstractmethod
      def load_state_dict(self, state_dict: dict[str, Any]) -> None:
          """Load the state from a state dictionary."""
-         self.__dict__.update(state_dict)
-         return
+         raise NotImplementedError
+
+     @abstractmethod
+     def _verify(self) -> None:
+         """Verify the scheduler configuration."""
+         raise NotImplementedError

      def __getstate__(self) -> dict[str, Any]:
          """Get the state for pickling."""