kostyl-toolkit 0.1.35__py3-none-any.whl → 0.1.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kostyl/ml/configs/hyperparams.py CHANGED
@@ -1,3 +1,5 @@
1
+ from typing import Literal
2
+
1
3
  from pydantic import BaseModel
2
4
  from pydantic import Field
3
5
  from pydantic import model_validator
@@ -8,11 +10,25 @@ from kostyl.utils.logging import setup_logger
8
10
  logger = setup_logger(fmt="only_message")
9
11
 
10
12
 
11
- class Optimizer(BaseModel):
12
- """Optimizer hyperparameters configuration."""
13
+ class AdamConfig(BaseModel):
14
+ """AdamW optimizer hyperparameters configuration."""
15
+
16
+ type: Literal["AdamW"] = "AdamW"
17
+ betas: tuple[float, float] = (0.9, 0.999)
18
+ is_adamw: bool = True
19
+
20
+
21
+ class AdamWithPrecisionConfig(BaseModel):
22
+ """Adam optimizer with low-precision hyperparameters configuration."""
23
+
24
+ type: Literal["Adam8bit", "Adam4bit", "AdamFp8"]
25
+ betas: tuple[float, float] = (0.9, 0.999)
26
+ block_size: int
27
+ bf16_stochastic_round: bool = False
28
+ is_adamw: bool = True
29
+
13
30
 
14
- adamw_beta1: float = 0.9
15
- adamw_beta2: float = 0.999
31
+ Optimizer = AdamConfig | AdamWithPrecisionConfig
16
32
 
17
33
 
18
34
  class Lr(BaseModel):
@@ -73,6 +89,6 @@ class HyperparamsConfig(BaseModel):
73
89
  """Model training hyperparameters configuration."""
74
90
 
75
91
  grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
76
- optimizer: Optimizer = Optimizer()
92
+ optimizer: Optimizer
77
93
  lr: Lr
78
94
  weight_decay: WeightDecay
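
Editor's note: `Optimizer` is now a tagged union of the two config classes above, and `HyperparamsConfig.optimizer` is a required field rather than a defaulted one. A minimal sketch of how the union can be validated with pydantic's `TypeAdapter`; the parsing code itself is illustrative and not taken from the package, only the module path and class names come from this diff.

```python
from pydantic import TypeAdapter

from kostyl.ml.configs.hyperparams import Optimizer  # AdamConfig | AdamWithPrecisionConfig

# The `type` literal discriminates between the two configs when validating raw data.
optimizer = TypeAdapter(Optimizer).validate_python(
    {"type": "Adam8bit", "block_size": 256}
)
print(type(optimizer).__name__)  # AdamWithPrecisionConfig
```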
kostyl/ml/configs/training_settings.py CHANGED
@@ -25,21 +25,31 @@ PRECISION = Literal[
25
25
  "16",
26
26
  "bf16",
27
27
  ]
28
+ DTYPE = Literal["float32", "float16", "bfloat16", "float64"]
29
+
30
+
31
+ class SingleDeviceStrategyConfig(BaseModel):
32
+ """Single device strategy configuration."""
33
+
34
+ type: Literal["single_device"]
28
35
 
29
36
 
30
37
  class FSDP1StrategyConfig(BaseModel):
31
38
  """Fully Sharded Data Parallel (FSDP) strategy configuration."""
32
39
 
33
40
  type: Literal["fsdp1"]
34
- param_dtype: Literal["float32", "float16", "bfloat16"]
35
- reduce_dtype: Literal["float32", "float16", "bfloat16"]
36
- buffer_dtype: Literal["float32", "float16", "bfloat16"]
41
+ param_dtype: DTYPE | None
42
+ reduce_dtype: DTYPE | None
43
+ buffer_dtype: DTYPE | None
37
44
 
38
45
 
39
- class SingleDeviceStrategyConfig(BaseModel):
40
- """Single device strategy configuration."""
46
+ class FSDP2StrategyConfig(BaseModel):
47
+ """Fully Sharded Data Parallel (FSDP) strategy configuration."""
41
48
 
42
- type: Literal["single_device"]
49
+ type: Literal["fsdp2"]
50
+ param_dtype: DTYPE | None
51
+ reduce_dtype: DTYPE | None
52
+ buffer_dtype: DTYPE | None
43
53
 
44
54
 
45
55
  class DDPStrategyConfig(BaseModel):
@@ -82,6 +92,7 @@ class CheckpointConfig(BaseModel):
82
92
  monitor: str = "val_loss"
83
93
  mode: str = "min"
84
94
  filename: str = "{epoch:02d}-{val_loss:.2f}"
95
+ save_weights_only: bool = True
85
96
 
86
97
 
87
98
  class DataConfig(BaseModel):
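
Editor's note: the FSDP configs now share the `DTYPE` literal and allow `None`. The removed `training_utils.py` further down resolves these strings with `getattr(torch, ...)`; a small sketch of that mapping, where the `None` pass-through is an assumption for the new optional fields:

```python
import torch


def resolve_dtype(name: str | None) -> torch.dtype | None:
    # "float32" -> torch.float32, etc.; None is passed through so the strategy
    # can fall back to its own default (assumption for the new `DTYPE | None` fields).
    return getattr(torch, name) if name is not None else None


assert resolve_dtype("bfloat16") is torch.bfloat16
assert resolve_dtype(None) is None
```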
kostyl/ml/lightning/callbacks/checkpoint.py CHANGED
@@ -299,9 +299,9 @@ class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
299
299
  def setup_checkpoint_callback(
300
300
  dirpath: Path,
301
301
  ckpt_cfg: CheckpointConfig,
302
- save_weights_only: bool = True,
303
302
  registry_uploader_callback: RegistryUploaderCallback | None = None,
304
303
  uploading_strategy: Literal["only-best", "every-checkpoint"] | None = None,
304
+ remove_folder_if_exists: bool = True,
305
305
  ) -> ModelCheckpointWithRegistryUploader | ModelCheckpoint:
306
306
  """
307
307
  Create and configure a checkpoint callback for model saving.
@@ -313,14 +313,13 @@ def setup_checkpoint_callback(
313
313
  Args:
314
314
  dirpath: Path to the directory for saving checkpoints.
315
315
  ckpt_cfg: Checkpoint configuration (filename, monitor, mode, save_top_k).
316
- save_weights_only: If True, only model weights are saved without optimizer and lr-scheduler state.
317
- Defaults to True.
318
316
  registry_uploader_callback: Optional callback for uploading checkpoints to a remote registry.
319
317
  Must be specified together with uploading_strategy.
320
318
  uploading_strategy: Checkpoint upload mode:
321
319
  - "only-best": only the best checkpoint is uploaded
322
320
  - "every-checkpoint": every saved checkpoint is uploaded
323
321
  Must be specified together with registry_uploader_callback.
322
+ remove_folder_if_exists: If True, removes existing checkpoint directory before creating a new one.
324
323
 
325
324
  Returns:
326
325
  ModelCheckpointWithRegistryUploader if registry_uploader_callback is provided,
@@ -331,7 +330,7 @@ def setup_checkpoint_callback(
331
330
 
332
331
  Note:
333
332
  If the dirpath directory already exists, it will be removed and recreated
334
- (only on the main process in distributed training).
333
+ (only on the main process in distributed training) if remove_folder_if_exists is True.
335
334
 
336
335
  """
337
336
  if (registry_uploader_callback is None) != (uploading_strategy is None):
@@ -342,8 +341,9 @@ def setup_checkpoint_callback(
342
341
  if dirpath.exists():
343
342
  if is_main_process():
344
343
  logger.warning(f"Checkpoint directory {dirpath} already exists.")
345
- rmtree(dirpath)
346
- logger.warning(f"Removed existing checkpoint directory {dirpath}.")
344
+ if remove_folder_if_exists:
345
+ rmtree(dirpath)
346
+ logger.warning(f"Removed existing checkpoint directory {dirpath}.")
347
347
  else:
348
348
  logger.info(f"Creating checkpoint directory {dirpath}.")
349
349
  dirpath.mkdir(parents=True, exist_ok=True)
@@ -356,7 +356,7 @@ def setup_checkpoint_callback(
356
356
  monitor=ckpt_cfg.monitor,
357
357
  mode=ckpt_cfg.mode,
358
358
  verbose=True,
359
- save_weights_only=save_weights_only,
359
+ save_weights_only=ckpt_cfg.save_weights_only,
360
360
  registry_uploader_callback=registry_uploader_callback,
361
361
  uploading_mode=uploading_strategy,
362
362
  )
@@ -368,6 +368,6 @@ def setup_checkpoint_callback(
368
368
  monitor=ckpt_cfg.monitor,
369
369
  mode=ckpt_cfg.mode,
370
370
  verbose=True,
371
- save_weights_only=save_weights_only,
371
+ save_weights_only=ckpt_cfg.save_weights_only,
372
372
  )
373
373
  return checkpoint_callback
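
Editor's note: a hedged usage sketch of the updated signature — `save_weights_only` now comes from `CheckpointConfig`, and `remove_folder_if_exists` controls whether an existing directory is wiped. Defaults for `CheckpointConfig` fields other than those visible in this diff are assumed.

```python
from pathlib import Path

from kostyl.ml.configs import CheckpointConfig
from kostyl.ml.lightning.callbacks import setup_checkpoint_callback

# Keep optimizer/LR-scheduler state in the checkpoints and preserve any
# pre-existing checkpoint directory instead of removing it.
ckpt_cfg = CheckpointConfig(save_weights_only=False)
checkpoint_callback = setup_checkpoint_callback(
    dirpath=Path("checkpoints/run-001"),
    ckpt_cfg=ckpt_cfg,
    remove_folder_if_exists=False,
)
```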
kostyl/ml/lightning/utils.py ADDED
@@ -0,0 +1,58 @@
1
+ from typing import cast
2
+
3
+ import lightning as L
4
+ import torch.distributed as dist
5
+ from torch.distributed import ProcessGroup
6
+
7
+ from kostyl.ml.configs import DDPStrategyConfig
8
+ from kostyl.ml.configs import FSDP1StrategyConfig
9
+ from kostyl.ml.configs import SingleDeviceStrategyConfig
10
+ from kostyl.utils.logging import setup_logger
11
+
12
+
13
+ TRAINING_STRATEGIES = (
14
+ FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
15
+ )
16
+
17
+ logger = setup_logger(add_rank=True)
18
+
19
+
20
+ def estimate_total_steps(
21
+ trainer: L.Trainer, dp_process_group: ProcessGroup | None = None
22
+ ) -> int:
23
+ """
24
+ Estimates the total number of training steps with respect to data parallelism and gradient accumulation.
25
+
26
+ Args:
27
+ trainer: The PyTorch Lightning Trainer instance.
28
+ dp_process_group: The data parallel process group. If None, the world process group will be used.
29
+
30
+ """
31
+ if dist.is_initialized():
32
+ world_size = dist.get_world_size(dp_process_group)
33
+ else:
34
+ world_size = 1
35
+
36
+ datamodule = trainer.datamodule # type: ignore
37
+ if datamodule is None:
38
+ raise ValueError("Trainer must have a datamodule to estimate total steps.")
39
+ datamodule = cast(L.LightningDataModule, datamodule)
40
+
41
+ logger.info("Loading `train_dataloader` to estimate number of stepping batches.")
42
+ datamodule.setup("fit")
43
+
44
+ dataloader_len = len(datamodule.train_dataloader())
45
+ steps_per_epoch = dataloader_len // trainer.accumulate_grad_batches // world_size
46
+
47
+ if trainer.max_epochs is None:
48
+ raise ValueError("Trainer must have `max_epochs` set to estimate total steps.")
49
+ total_steps = steps_per_epoch * trainer.max_epochs
50
+
51
+ logger.info(
52
+ f"Total steps: {total_steps} (per-epoch: {steps_per_epoch}) "
53
+ f"-> Dataloader len: {dataloader_len} "
54
+ f"-> Accumulate grad batches: {trainer.accumulate_grad_batches} "
55
+ f"-> Epochs: {trainer.max_epochs} "
56
+ f"-> DataParallel size: {world_size}"
57
+ )
58
+ return total_steps
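
Editor's note: a worked instance of the arithmetic in `estimate_total_steps` (the numbers are illustrative):

```python
# 1000 batches per epoch, gradient accumulation of 4, a 2-rank data-parallel
# group, and 3 epochs:
dataloader_len, accumulate_grad_batches, world_size, max_epochs = 1000, 4, 2, 3
steps_per_epoch = dataloader_len // accumulate_grad_batches // world_size  # 125
total_steps = steps_per_epoch * max_epochs  # 375
```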
kostyl/ml/registry_uploader.py CHANGED
@@ -1,13 +1,12 @@
1
1
  from abc import ABC
2
2
  from abc import abstractmethod
3
3
  from collections.abc import Callable
4
+ from functools import partial
4
5
  from pathlib import Path
5
6
  from typing import override
6
7
 
7
8
  from clearml import OutputModel
8
9
 
9
- from kostyl.ml.clearml.logging_utils import find_version_in_tags
10
- from kostyl.ml.clearml.logging_utils import increment_version
11
10
  from kostyl.utils.logging import setup_logger
12
11
 
13
12
 
@@ -28,51 +27,79 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
28
27
 
29
28
  def __init__(
30
29
  self,
31
- output_model: OutputModel,
30
+ model_name: str,
32
31
  config_dict: dict[str, str] | None = None,
32
+ label_enumeration: dict[str, int] | None = None,
33
+ tags: list[str] | None = None,
34
+ comment: str | None = None,
35
+ framework: str | None = None,
36
+ base_model_id: str | None = None,
37
+ new_model_per_upload: bool = True,
33
38
  verbose: bool = True,
34
- enable_tag_versioning: bool = False,
35
39
  ) -> None:
36
40
  """
37
41
  Initializes the ClearMLRegistryUploaderCallback.
38
42
 
39
43
  Args:
40
- output_model: ClearML OutputModel instance representing the model to upload.
41
- verbose: Whether to log messages during upload.
44
+ model_name: The name for the newly created model.
45
+ label_enumeration: The label enumeration dictionary of string (label) to integer (value) pairs.
42
46
  config_dict: Optional configuration dictionary to associate with the model.
43
- enable_tag_versioning: Whether to enable versioning in tags. If True,
44
- the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".
47
+ tags: A list of strings which are tags for the model.
48
+ comment: A comment / description for the model.
49
+ framework: The framework of the model (e.g., "PyTorch", "TensorFlow").
50
+ base_model_id: Optional ClearML model ID to use as a base for the new model
51
+ new_model_per_upload: Whether to create a new ClearML model
52
+ for every upload or update weights of the same model. When updating weights,
53
+ the last uploaded checkpoint will be replaced (and deleted).
54
+ verbose: Whether to log messages during upload.
45
55
 
46
56
  """
47
57
  super().__init__()
48
- self.output_model = output_model
49
- self.config_dict = config_dict
50
- self.verbose = verbose
51
- self.enable_tag_versioning = enable_tag_versioning
58
+ if base_model_id is not None and new_model_per_upload:
59
+ raise ValueError(
60
+ "Cannot set base_model_id when new_model_per_upload is True."
61
+ )
52
62
 
63
+ self.verbose = verbose
64
+ self.new_model_per_upload = new_model_per_upload
53
65
  self.best_model_path: str = ""
54
-
66
+ self.config_dict = config_dict
67
+ self._output_model: OutputModel | None = None
55
68
  self._last_uploaded_model_path: str = ""
56
69
  self._upload_callback: Callable | None = None
57
70
 
58
- self._validate_tags()
71
+ self._validate_tags(tags)
72
+ self.model_fabric = partial(
73
+ OutputModel,
74
+ name=model_name,
75
+ label_enumeration=label_enumeration,
76
+ tags=tags,
77
+ comment=comment,
78
+ framework=framework,
79
+ base_model_id=base_model_id,
80
+ )
59
81
  return
60
82
 
61
- def _validate_tags(self) -> None:
62
- output_model_tags = self.output_model.tags or []
63
- if self.enable_tag_versioning:
64
- version = find_version_in_tags(output_model_tags)
65
- if version is None:
66
- output_model_tags.append("v1.0")
67
- else:
68
- new_version = increment_version(version)
69
- output_model_tags.remove(version)
70
- output_model_tags.append(new_version)
71
- if "LightningCheckpoint" not in output_model_tags:
72
- output_model_tags.append("LightningCheckpoint")
73
- self.output_model.tags = output_model_tags
83
+ @staticmethod
84
+ def _validate_tags(tags: list[str] | None) -> None:
85
+ if tags is None:
86
+ return
87
+ if "LightningCheckpoint" not in tags:
88
+ tags.append("LightningCheckpoint")
74
89
  return None
75
90
 
91
+ @property
92
+ def output_model_(self) -> OutputModel:
93
+ """Returns the OutputModel instance based on `new_model_per_upload` setting."""
94
+ if self.new_model_per_upload:
95
+ model = self.model_fabric()
96
+ self._output_model = self.model_fabric()
97
+ else:
98
+ if self._output_model is None:
99
+ self._output_model = self.model_fabric()
100
+ model = self._output_model
101
+ return model
102
+
76
103
  @override
77
104
  def upload_checkpoint(
78
105
  self,
@@ -88,12 +115,12 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
88
115
  if self.verbose:
89
116
  logger.info(f"Uploading model from {path}")
90
117
 
91
- self.output_model.update_weights(
118
+ self.output_model_.update_weights(
92
119
  path,
93
120
  auto_delete_file=False,
94
121
  async_enable=False,
95
122
  )
96
- self.output_model.update_design(config_dict=self.config_dict)
123
+ self.output_model_.update_design(config_dict=self.config_dict)
97
124
 
98
125
  self._last_uploaded_model_path = path
99
126
  return
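
Editor's note: the callback now builds its ClearML `OutputModel` internally instead of receiving one. A construction sketch that uses only arguments shown in this diff:

```python
from kostyl.ml.registry_uploader import ClearMLRegistryUploaderCallback

uploader = ClearMLRegistryUploaderCallback(
    model_name="my-classifier",
    tags=["baseline"],            # "LightningCheckpoint" is appended automatically
    framework="PyTorch",
    new_model_per_upload=False,   # reuse one ClearML model and replace its weights
)
```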
kostyl/ml/schedulers/__init__.py CHANGED
@@ -1,6 +1,18 @@
1
1
  from .composite import CompositeScheduler
2
2
  from .cosine import CosineParamScheduler
3
3
  from .cosine import CosineScheduler
4
+ from .cosine_with_plateu import CosineWithPlateauParamScheduler
5
+ from .cosine_with_plateu import CosineWithPlateuScheduler
6
+ from .linear import LinearParamScheduler
7
+ from .linear import LinearScheduler
4
8
 
5
9
 
6
- __all__ = ["CompositeScheduler", "CosineParamScheduler", "CosineScheduler"]
10
+ __all__ = [
11
+ "CompositeScheduler",
12
+ "CosineParamScheduler",
13
+ "CosineScheduler",
14
+ "CosineWithPlateauParamScheduler",
15
+ "CosineWithPlateuScheduler",
16
+ "LinearParamScheduler",
17
+ "LinearScheduler",
18
+ ]
kostyl/ml/schedulers/base.py CHANGED
@@ -6,18 +6,20 @@ from typing import Any
6
6
  class BaseScheduler(ABC):
7
7
  """Base class for learning rate schedulers."""
8
8
 
9
+ @abstractmethod
9
10
  def state_dict(self) -> dict[str, Any]:
10
11
  """Get the state as a state dictionary."""
11
- return {
12
- key: value
13
- for key, value in self.__dict__.items()
14
- if key not in ["optimizer", "scheduler_values"]
15
- }
12
+ raise NotImplementedError
16
13
 
14
+ @abstractmethod
17
15
  def load_state_dict(self, state_dict: dict[str, Any]) -> None:
18
16
  """Load the state from a state dictionary."""
19
- self.__dict__.update(state_dict)
20
- return
17
+ raise NotImplementedError
18
+
19
+ @abstractmethod
20
+ def _verify(self) -> None:
21
+ """Verify the scheduler configuration."""
22
+ raise NotImplementedError
21
23
 
22
24
  def __getstate__(self) -> dict[str, Any]:
23
25
  """Get the state for pickling."""
kostyl/ml/schedulers/cosine.py CHANGED
@@ -2,7 +2,6 @@ from typing import Any
2
2
  from typing import override
3
3
 
4
4
  import numpy as np
5
- import numpy.typing as npt
6
5
  import torch
7
6
 
8
7
  from .base import BaseScheduler
@@ -29,18 +28,24 @@ class _CosineSchedulerCore(BaseScheduler):
29
28
  if freeze_ratio is not None:
30
29
  if not (0 < freeze_ratio < 1):
31
30
  raise ValueError(f"Freeze ratio must be in (0, 1), got {freeze_ratio}.")
31
+ pre_annealing_ratio = (warmup_ratio if warmup_ratio is not None else 0) + (
32
+ freeze_ratio if freeze_ratio is not None else 0
33
+ )
34
+ if pre_annealing_ratio > 1:
35
+ raise ValueError(
36
+ "The sum of warmup_ratio and freeze_ratio must <= 1, got "
37
+ f"{pre_annealing_ratio}."
38
+ )
32
39
 
33
40
  self.param_name = param_name
34
41
  self.num_iters = num_iters
35
42
  self.base_value = base_value
36
43
  self.final_value = final_value
37
-
38
44
  self.warmup_ratio = warmup_ratio
39
45
  self.warmup_value = warmup_value
40
-
41
46
  self.freeze_ratio = freeze_ratio
42
47
 
43
- self.scheduler_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
48
+ self.scheduled_values: np.ndarray = np.array([], dtype=np.float64)
44
49
  self.current_value_ = self.base_value
45
50
  return
46
51
 
@@ -63,31 +68,29 @@ class _CosineSchedulerCore(BaseScheduler):
63
68
  warmup_iters = 0
64
69
  warmup_schedule = np.array([], dtype=np.float64)
65
70
 
71
+ # Create cosine annealing schedule
66
72
  cosine_annealing_iters = self.num_iters - warmup_iters - freeze_iters
67
- if cosine_annealing_iters <= 0:
68
- raise ValueError("Cosine annealing iters must be > 0.")
69
-
70
- # Create cosine schedule
71
- iters = np.arange(cosine_annealing_iters)
72
- schedule = self.final_value + 0.5 * (self.base_value - self.final_value) * (
73
- 1 + np.cos(np.pi * iters / len(iters))
74
- )
73
+ if cosine_annealing_iters > 0:
74
+ iters = np.arange(cosine_annealing_iters)
75
+ cosine_annealing_schedule = self.final_value + 0.5 * (
76
+ self.base_value - self.final_value
77
+ ) * (1 + np.cos(np.pi * iters / len(iters)))
78
+ else:
79
+ cosine_annealing_schedule = np.array([], dtype=np.float64)
75
80
 
76
81
  # Concatenate all parts of the schedule
77
- self.scheduler_values = np.concatenate(
78
- (freeze_schedule, warmup_schedule, schedule)
82
+ self.scheduled_values = np.concatenate(
83
+ (freeze_schedule, warmup_schedule, cosine_annealing_schedule)
79
84
  )
80
-
81
- if len(self.scheduler_values) != self.num_iters:
82
- raise ValueError(
83
- f"Scheduler length ({len(self.scheduler_values)}) does not match num_iters ({self.num_iters})."
84
- )
85
+ self._verify()
85
86
  return
86
87
 
87
88
  @override
88
- def load_state_dict(self, state_dict: dict[str, Any]) -> None:
89
- super().load_state_dict(state_dict)
90
- self.scheduler_values = np.array([], dtype=np.float64)
89
+ def _verify(self) -> None:
90
+ if len(self.scheduled_values) != self.num_iters:
91
+ raise ValueError(
92
+ f"Scheduler length ({len(self.scheduled_values)}) does not match num_iters ({self.num_iters})."
93
+ )
91
94
  return
92
95
 
93
96
  @override
@@ -95,13 +98,13 @@ class _CosineSchedulerCore(BaseScheduler):
95
98
  raise NotImplementedError
96
99
 
97
100
  def _get_value(self, it: int) -> float:
98
- if len(self.scheduler_values) == 0:
101
+ if len(self.scheduled_values) == 0:
99
102
  self._create_scheduler()
100
103
 
101
104
  if it >= self.num_iters:
102
105
  value: float = self.final_value
103
106
  else:
104
- value: float = self.scheduler_values[it]
107
+ value: float = self.scheduled_values[it]
105
108
  self.current_value_ = value
106
109
  return value
107
110
 
@@ -163,6 +166,21 @@ class CosineScheduler(_CosineSchedulerCore):
163
166
  self.param_group_field = param_group_field
164
167
  return
165
168
 
169
+ @override
170
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
171
+ self.__dict__.update(state_dict)
172
+ self.scheduled_values = np.array([], dtype=np.float64)
173
+ return
174
+
175
+ @override
176
+ def state_dict(self) -> dict[str, Any]:
177
+ state = {
178
+ k: v
179
+ for k, v in self.__dict__.items()
180
+ if k not in ["scheduled_values", "optimizer"]
181
+ }
182
+ return state
183
+
166
184
  @override
167
185
  def step(self, it: int) -> None:
168
186
  value = self._get_value(it)
@@ -209,3 +227,14 @@ class CosineParamScheduler(_CosineSchedulerCore):
209
227
  """
210
228
  value = self._get_value(it)
211
229
  return value
230
+
231
+ @override
232
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
233
+ self.__dict__.update(state_dict)
234
+ self.scheduled_values = np.array([], dtype=np.float64)
235
+ return
236
+
237
+ @override
238
+ def state_dict(self) -> dict[str, Any]:
239
+ state = {k: v for k, v in self.__dict__.items() if k != "scheduled_values"}
240
+ return state
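
Editor's note: a quick numerical check of the cosine-annealing expression used in `_create_scheduler` (illustrative only) — the schedule starts at `base_value` and decays toward `final_value`:

```python
import numpy as np

base_value, final_value = 1e-3, 1e-5
iters = np.arange(100)
schedule = final_value + 0.5 * (base_value - final_value) * (
    1 + np.cos(np.pi * iters / len(iters))
)
assert np.isclose(schedule[0], base_value)      # starts at the base value
assert abs(schedule[-1] - final_value) < 1e-6   # ends just above the final value
```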
kostyl/ml/schedulers/cosine_with_plateu.py ADDED
@@ -0,0 +1,277 @@
1
+ from typing import Any
2
+ from typing import override
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from .base import BaseScheduler
8
+
9
+
10
+ class _CosineWithPlateauSchedulerCore(BaseScheduler):
11
+ """Core cosine with plateau scheduler logic."""
12
+
13
+ def __init__(
14
+ self,
15
+ param_name: str,
16
+ num_iters: int,
17
+ base_value: float,
18
+ final_value: float,
19
+ plateau_ratio: float,
20
+ warmup_value: float | None = None,
21
+ warmup_ratio: float | None = None,
22
+ freeze_ratio: float | None = None,
23
+ ) -> None:
24
+ if warmup_ratio is not None:
25
+ if not (0 < warmup_ratio < 1):
26
+ raise ValueError(f"Warmup ratio must be in (0, 1), got {warmup_ratio}.")
27
+ if (warmup_value is None) != (warmup_ratio is None):
28
+ raise ValueError(
29
+ "Both warmup_ratio and warmup_value must be provided or neither."
30
+ )
31
+ if freeze_ratio is not None:
32
+ if not (0 < freeze_ratio < 1):
33
+ raise ValueError(f"Freeze ratio must be in (0, 1), got {freeze_ratio}.")
34
+ if not (0 < plateau_ratio < 1):
35
+ raise ValueError(f"Plateau ratio must be in (0, 1), got {plateau_ratio}.")
36
+
37
+ pre_annealing_ratio = (
38
+ plateau_ratio
39
+ + (warmup_ratio if warmup_ratio is not None else 0)
40
+ + (freeze_ratio if freeze_ratio is not None else 0)
41
+ )
42
+ if pre_annealing_ratio > 1:
43
+ raise ValueError(
44
+ "The sum of plateau_ratio, warmup_ratio, and freeze_ratio must <= 1, got "
45
+ f"{pre_annealing_ratio}."
46
+ )
47
+
48
+ self.param_name = param_name
49
+ self.num_iters = num_iters
50
+ self.base_value = base_value
51
+ self.final_value = final_value
52
+ self.cosine_annealing_ratio = 1 - pre_annealing_ratio
53
+ self.plateau_ratio = plateau_ratio
54
+ self.warmup_ratio = warmup_ratio
55
+ self.warmup_value = warmup_value
56
+ self.freeze_ratio = freeze_ratio
57
+
58
+ self.scheduled_values: np.ndarray = np.array([], dtype=np.float64)
59
+ self.current_value_ = self.base_value
60
+ return
61
+
62
+ def _create_scheduler(self) -> None:
63
+ # Create freeze schedule
64
+ if self.freeze_ratio is not None:
65
+ freeze_iters = int(self.num_iters * self.freeze_ratio)
66
+ freeze_schedule = np.zeros(freeze_iters, dtype=np.float64)
67
+ else:
68
+ freeze_iters = 0
69
+ freeze_schedule = np.array([], dtype=np.float64)
70
+
71
+ # Create linear warmup schedule
72
+ if self.warmup_ratio is not None and self.warmup_value is not None:
73
+ warmup_iters = int(self.num_iters * self.warmup_ratio)
74
+ warmup_schedule = np.linspace(
75
+ self.warmup_value, self.base_value, warmup_iters, dtype=np.float64
76
+ )
77
+ else:
78
+ warmup_iters = 0
79
+ warmup_schedule = np.array([], dtype=np.float64)
80
+
81
+ # Create cosine annealing schedule
82
+ if self.cosine_annealing_ratio > 0:
83
+ cosine_annealing_iters = int(self.num_iters * self.cosine_annealing_ratio)
84
+ iters = np.arange(cosine_annealing_iters)
85
+ cosine_annealing_schedule = self.final_value + 0.5 * (
86
+ self.base_value - self.final_value
87
+ ) * (1 + np.cos(np.pi * iters / len(iters)))
88
+ else:
89
+ cosine_annealing_iters = 0
90
+ cosine_annealing_schedule = np.array([], dtype=np.float64)
91
+
92
+ plateau_iters = (
93
+ self.num_iters - warmup_iters - freeze_iters - cosine_annealing_iters
94
+ )
95
+ if plateau_iters > 0:
96
+ plateau_schedule = np.full(plateau_iters, self.base_value, dtype=np.float64)
97
+ else:
98
+ plateau_schedule = np.array([], dtype=np.float64)
99
+
100
+ # Concatenate all parts of the schedule
101
+ self.scheduled_values = np.concatenate(
102
+ (
103
+ freeze_schedule,
104
+ warmup_schedule,
105
+ plateau_schedule,
106
+ cosine_annealing_schedule,
107
+ )
108
+ )
109
+ self._verify()
110
+ return
111
+
112
+ @override
113
+ def _verify(self) -> None:
114
+ if len(self.scheduled_values) != self.num_iters:
115
+ raise ValueError(
116
+ f"Scheduler length ({len(self.scheduled_values)}) does not match num_iters ({self.num_iters})."
117
+ )
118
+ return
119
+
120
+ @override
121
+ def step(self, it: int) -> None | float:
122
+ raise NotImplementedError
123
+
124
+ def _get_value(self, it: int) -> float:
125
+ if len(self.scheduled_values) == 0:
126
+ self._create_scheduler()
127
+
128
+ if it >= self.num_iters:
129
+ value: float = self.final_value
130
+ else:
131
+ value: float = self.scheduled_values[it]
132
+ self.current_value_ = value
133
+ return value
134
+
135
+ @override
136
+ def current_value(self) -> dict[str, float]:
137
+ return {self.param_name: self.current_value_}
138
+
139
+
140
+ class CosineWithPlateuScheduler(_CosineWithPlateauSchedulerCore):
141
+ """
142
+ Applies a cosine schedule with plateau to an optimizer param-group field.
143
+
144
+ Schedule phases: freeze (0) → warmup → plateau (base_value) → cosine annealing to final_value.
145
+ The plateau phase maintains the base_value before cosine annealing begins.
146
+ """
147
+
148
+ def __init__(
149
+ self,
150
+ optimizer: torch.optim.Optimizer,
151
+ param_group_field: str,
152
+ num_iters: int,
153
+ base_value: float,
154
+ final_value: float,
155
+ plateau_ratio: float,
156
+ warmup_value: float | None = None,
157
+ warmup_ratio: float | None = None,
158
+ freeze_ratio: float | None = None,
159
+ multiplier_field: str | None = None,
160
+ skip_if_zero: bool = False,
161
+ apply_if_field: str | None = None,
162
+ ignore_if_field: str | None = None,
163
+ ) -> None:
164
+ """
165
+ Configure cosine scheduling for matching optimizer groups.
166
+
167
+ Args:
168
+ optimizer: Optimizer whose param groups are updated in-place.
169
+ param_group_field: Name of the field that receives the scheduled value.
170
+ num_iters: Number of scheduler iterations before clamping at ``final_value``.
171
+ base_value: Value maintained during plateau phase and used as cosine start.
172
+ final_value: Value approached as iterations progress during cosine annealing.
173
+ plateau_ratio: Fraction of iterations to maintain ``base_value`` before cosine annealing.
174
+ warmup_ratio: Optional fraction of iterations to linearly ramp from ``warmup_value`` to ``base_value``.
175
+ warmup_value: Starting value for the warmup ramp.
176
+ freeze_ratio: Optional fraction of iterations to keep the value frozen at zero at the beginning.
177
+ multiplier_field: Optional per-group multiplier applied to the scheduled value.
178
+ skip_if_zero: Leave groups untouched when their target field equals zero.
179
+ apply_if_field: Require this flag to be present in a param group before updating.
180
+ ignore_if_field: Skip groups that declare this flag.
181
+
182
+ """
183
+ self.apply_if_field = apply_if_field
184
+ self.ignore_if_field = ignore_if_field
185
+ self.optimizer = optimizer
186
+ self.multiplier_field = multiplier_field
187
+ self.skip_if_zero = skip_if_zero
188
+ super().__init__(
189
+ param_name=param_group_field,
190
+ num_iters=num_iters,
191
+ base_value=base_value,
192
+ final_value=final_value,
193
+ plateau_ratio=plateau_ratio,
194
+ warmup_ratio=warmup_ratio,
195
+ warmup_value=warmup_value,
196
+ freeze_ratio=freeze_ratio,
197
+ )
198
+ self.param_group_field = param_group_field
199
+ return
200
+
201
+ @override
202
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
203
+ self.__dict__.update(state_dict)
204
+ self.scheduled_values = np.array([], dtype=np.float64)
205
+ return
206
+
207
+ @override
208
+ def state_dict(self) -> dict[str, Any]:
209
+ state = {
210
+ k: v
211
+ for k, v in self.__dict__.items()
212
+ if k not in ["scheduled_values", "optimizer"]
213
+ }
214
+ return state
215
+
216
+ @override
217
+ def step(self, it: int) -> None:
218
+ value = self._get_value(it)
219
+ for pg in self.optimizer.param_groups:
220
+ if self.param_group_field not in pg:
221
+ raise ValueError(
222
+ f"Parameter group field '{self.param_group_field}' not found in optimizer parameter groups."
223
+ )
224
+
225
+ if (self.apply_if_field is not None) and (self.apply_if_field not in pg):
226
+ continue
227
+
228
+ if (self.ignore_if_field is not None) and (self.ignore_if_field in pg):
229
+ continue
230
+
231
+ if self.skip_if_zero and pg[self.param_group_field] == 0:
232
+ continue
233
+
234
+ if self.multiplier_field is not None:
235
+ if self.multiplier_field not in pg:
236
+ multiplier = 1.0
237
+ else:
238
+ multiplier = pg[self.multiplier_field]
239
+ pg[self.param_group_field] = value * multiplier
240
+ else:
241
+ pg[self.param_group_field] = value
242
+ return
243
+
244
+
245
+ class CosineWithPlateauParamScheduler(_CosineWithPlateauSchedulerCore):
246
+ """
247
+ Standalone cosine scheduler with plateau for non-optimizer parameters.
248
+
249
+ Schedule phases: freeze (0) → warmup → plateau (base_value) → cosine annealing to final_value.
250
+ The plateau phase maintains the base_value before cosine annealing begins.
251
+ """
252
+
253
+ @override
254
+ def step(self, it: int) -> float:
255
+ """
256
+ Computes the value corresponding to the given iteration step.
257
+
258
+ Args:
259
+ it: The current iteration index used for value computation.
260
+
261
+ Returns:
262
+ The computed value for the provided iteration step as a float.
263
+
264
+ """
265
+ value = self._get_value(it)
266
+ return value
267
+
268
+ @override
269
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
270
+ self.__dict__.update(state_dict)
271
+ self.scheduled_values = np.array([], dtype=np.float64)
272
+ return
273
+
274
+ @override
275
+ def state_dict(self) -> dict[str, Any]:
276
+ state = {k: v for k, v in self.__dict__.items() if k != "scheduled_values"}
277
+ return state
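
Editor's note: a hedged usage sketch of the standalone variant added in this file; the argument names come from the code above, the concrete numbers are illustrative.

```python
from kostyl.ml.schedulers import CosineWithPlateauParamScheduler

# Hold 0.05 for the first 30% of 1000 steps, then cosine-anneal toward 0.0.
scheduler = CosineWithPlateauParamScheduler(
    param_name="weight_decay",
    num_iters=1000,
    base_value=0.05,
    final_value=0.0,
    plateau_ratio=0.3,
)
values = [scheduler.step(i) for i in range(1000)]
assert values[0] == 0.05   # plateau phase
assert values[-1] < 0.05   # deep into the annealing phase
```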
kostyl/ml/schedulers/linear.py CHANGED
@@ -21,24 +21,23 @@ class _LinearScheduleBase(BaseScheduler):
21
21
  self.start_value = start_value
22
22
  self.final_value = final_value
23
23
 
24
- self.scheduler_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
24
+ self.scheduled_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
25
25
  self.current_value_ = self.start_value
26
26
  return
27
27
 
28
28
  def _create_scheduler(self) -> None:
29
- self.scheduler_values = np.linspace(
29
+ self.scheduled_values = np.linspace(
30
30
  self.start_value, self.final_value, num=self.num_iters, dtype=np.float64
31
31
  )
32
- if len(self.scheduler_values) != self.num_iters:
33
- raise ValueError(
34
- f"Scheduler length ({len(self.scheduler_values)}) does not match total_iters ({self.num_iters})."
35
- )
32
+ self._verify()
36
33
  return
37
34
 
38
35
  @override
39
- def load_state_dict(self, state_dict: dict[str, Any]) -> None:
40
- super().load_state_dict(state_dict)
41
- self.scheduler_values = np.array([], dtype=np.float64)
36
+ def _verify(self) -> None:
37
+ if len(self.scheduled_values) != self.num_iters:
38
+ raise ValueError(
39
+ f"Scheduler length ({len(self.scheduled_values)}) does not match total_iters ({self.num_iters})."
40
+ )
42
41
  return
43
42
 
44
43
  @override
@@ -46,13 +45,13 @@ class _LinearScheduleBase(BaseScheduler):
46
45
  raise NotImplementedError
47
46
 
48
47
  def _get_value(self, it: int) -> float:
49
- if len(self.scheduler_values) == 0:
48
+ if len(self.scheduled_values) == 0:
50
49
  self._create_scheduler()
51
50
 
52
51
  if it >= self.num_iters:
53
52
  value: float = self.final_value
54
53
  else:
55
- value: float = self.scheduler_values[it]
54
+ value: float = self.scheduled_values[it]
56
55
  self.current_value_ = value
57
56
  return value
58
57
 
@@ -105,6 +104,21 @@ class LinearScheduler(_LinearScheduleBase):
105
104
  self.param_group_field = param_group_field
106
105
  return
107
106
 
107
+ @override
108
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
109
+ self.__dict__.update(state_dict)
110
+ self.scheduled_values = np.array([], dtype=np.float64)
111
+ return
112
+
113
+ @override
114
+ def state_dict(self) -> dict[str, Any]:
115
+ state = {
116
+ k: v
117
+ for k, v in self.__dict__.items()
118
+ if k not in ["scheduled_values", "optimizer"]
119
+ }
120
+ return state
121
+
108
122
  @override
109
123
  def step(self, it: int) -> None:
110
124
  value = self._get_value(it)
@@ -137,6 +151,17 @@ class LinearScheduler(_LinearScheduleBase):
137
151
  class LinearParamScheduler(_LinearScheduleBase):
138
152
  """LinearParamScheduler adjusts a parameter value using a linear scheduler."""
139
153
 
154
+ @override
155
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
156
+ self.__dict__.update(state_dict)
157
+ self.scheduled_values = np.array([], dtype=np.float64)
158
+ return
159
+
160
+ @override
161
+ def state_dict(self) -> dict[str, Any]:
162
+ state = {k: v for k, v in self.__dict__.items() if k != "scheduled_values"}
163
+ return state
164
+
140
165
  @override
141
166
  def step(self, it: int) -> float:
142
167
  """
kostyl/utils/logging.py CHANGED
@@ -94,7 +94,7 @@ _PRESETS = {"default": _DEFAULT_FMT, "only_message": _ONLY_MESSAGE_FMT}
94
94
 
95
95
  def setup_logger(
96
96
  name: str | None = None,
97
- fmt: Literal["default", "only_message"] | str = "default",
97
+ fmt: Literal["default", "only_message"] | str = "only_message",
98
98
  level: str = "INFO",
99
99
  add_rank: bool | None = None,
100
100
  sink=sys.stdout,
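
Editor's note: the default format flips from "default" to "only_message", so callers that relied on the old output now need to opt in explicitly. A sketch based only on the signature shown above:

```python
from kostyl.utils.logging import setup_logger

logger = setup_logger(fmt="default")   # previous default, now opt-in
quiet_logger = setup_logger()          # now uses the "only_message" preset
quiet_logger.info("message without the default prefix")
```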
kostyl_toolkit-0.1.36.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: kostyl-toolkit
3
- Version: 0.1.35
3
+ Version: 0.1.36
4
4
  Summary: Kickass Orchestration System for Training, Yielding & Logging
5
5
  Requires-Dist: case-converter>=1.2.0
6
6
  Requires-Dist: loguru>=0.7.3
kostyl_toolkit-0.1.36.dist-info/RECORD CHANGED
@@ -6,32 +6,33 @@ kostyl/ml/clearml/logging_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWa
6
6
  kostyl/ml/clearml/pulling_utils.py,sha256=jMlVXcYRumwWnPlELRlgEdfq5L6Wir_EcfTmOoWBLTA,4077
7
7
  kostyl/ml/configs/__init__.py,sha256=IetcivbqYGutowLqxdKp7QR4tkXKBr4m8t4Zkk9jHZU,911
8
8
  kostyl/ml/configs/base_model.py,sha256=Eofn14J9RsjpVx_J4rp6C19pDDCANU4hr3JtX-d0FpQ,4820
9
- kostyl/ml/configs/hyperparams.py,sha256=2S_VEZ07RWquNFSWjHBb3OUpBlTznbUpFSchzMpSBOc,2879
10
- kostyl/ml/configs/training_settings.py,sha256=Sq2tiRuwkbmi9zKDG2JghZLXo5DDt_eQqN_KYJSdcTY,2509
9
+ kostyl/ml/configs/hyperparams.py,sha256=lvtbvOFEoTBAJug7FR35xMQdPLgDQjRoP2fyDP-jD7E,3305
10
+ kostyl/ml/configs/training_settings.py,sha256=wT9CHuLaKrLwonsc87Ee421EyFis_c9fqOgn9bSClm8,2747
11
11
  kostyl/ml/data_processing_utils.py,sha256=jjEjV0S0wREgZkzg27ip0LpI8cQqkwe2QwATmAqm9-g,3832
12
12
  kostyl/ml/dist_utils.py,sha256=Onf0KHVLA8oeUgZTcTdmR9qiM22f2uYLoNwgLbMGJWk,3495
13
13
  kostyl/ml/lightning/__init__.py,sha256=R36PImjVvzBF9t_z9u6RYVnUFJJ-sNDUOdboWUojHmM,173
14
14
  kostyl/ml/lightning/callbacks/__init__.py,sha256=EnKkNwwNDZnEqKRlpY4FVrqP88ECPF6nlT2bSLUIKRk,194
15
- kostyl/ml/lightning/callbacks/checkpoint.py,sha256=sZ9OqudO-gXp7FqtWaOH46TXVpeCJxV-EowyRPN836k,18983
15
+ kostyl/ml/lightning/callbacks/checkpoint.py,sha256=COW7WErj4EMxJNMn97WQO-G2A3LbI6GQOCpIZu3Cblk,19060
16
16
  kostyl/ml/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
17
17
  kostyl/ml/lightning/extensions/__init__.py,sha256=OY6QGv1agYgqqKf1xJBrxgp_i8FunVfPzYezfaRrGXU,182
18
18
  kostyl/ml/lightning/extensions/custom_module.py,sha256=iQrnPz-WTmRfvLo94C5fQc2Qwa1IpHtUh1sCpVwTSFM,6602
19
19
  kostyl/ml/lightning/extensions/pretrained_model.py,sha256=eRfQBzAjVernHl9A4PP5uTLvjjmcNKPdTu7ABFLq7HI,5196
20
20
  kostyl/ml/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
21
21
  kostyl/ml/lightning/loggers/tb_logger.py,sha256=j02HK5ue8yzXXV8FWKmmXyHkFlIxgHx-ahHWk_rFCZs,893
22
- kostyl/ml/lightning/training_utils.py,sha256=u7X9ysF9Gqy8CdwacdcDlNQNsbagYAhslbv-1WLJ45k,9052
22
+ kostyl/ml/lightning/utils.py,sha256=imvMbgOKRtCUiiRGEcVtN-hxw-aEFKHdCWc0J_CIZp4,1980
23
23
  kostyl/ml/metrics_formatting.py,sha256=U6vdNENZLvp2dT1L3HqFKtXrHwGKoDXN93hvamPGHjM,1341
24
24
  kostyl/ml/params_groups.py,sha256=nUyw5d06Pvy9QPiYtZzLYR87xwXqJLxbHthgQH8oSCM,3583
25
- kostyl/ml/registry_uploader.py,sha256=W90TYo_WKv2oBE6nqEJl4hecYmJyyuKwQJ9_uUPGnJQ,3346
26
- kostyl/ml/schedulers/__init__.py,sha256=bxXbsU_WYnVbhvNNnuI7cOAh2Axz7D25TaleBTZhYfc,197
27
- kostyl/ml/schedulers/base.py,sha256=9M2iOoOVSRojR_liPX1qo3Nn4iMXSM5ZJuAFWZTulUk,1327
25
+ kostyl/ml/registry_uploader.py,sha256=BbyLXvF8AL145k7g6MRkJ7gf_3Um53p3Pn5280vVD9U,4384
26
+ kostyl/ml/schedulers/__init__.py,sha256=_EtZu8DwTCSv4-eR84kRstEZblHylVqda7WQUOXIKfw,534
27
+ kostyl/ml/schedulers/base.py,sha256=bjmwgdZpnSqpCnHPnKC6MEiRO79cwxMJpZq-eQVNs2M,1353
28
28
  kostyl/ml/schedulers/composite.py,sha256=ee4xlMDMMtjKPkbTF2ue9GTr9DuGCGjZWf11mHbi6aE,2387
29
- kostyl/ml/schedulers/cosine.py,sha256=t74_ByT22L5NQKpnBVU9UGzBVx1ZM2GTylb9ct3_PVg,7627
30
- kostyl/ml/schedulers/linear.py,sha256=7HPkVWcPa0lbaZywutXSDdVLLSihAyWk5XIE2Dzj_5Q,5168
29
+ kostyl/ml/schedulers/cosine.py,sha256=y8ylrgVOkVcr2-ExoqqNW--tdDX88TBYPQCOppIf2_M,8685
30
+ kostyl/ml/schedulers/cosine_with_plateu.py,sha256=0-X6wl3HgsTiLIbISb9lOxIVWXHDEND7rILitMWtIiM,10195
31
+ kostyl/ml/schedulers/linear.py,sha256=RnnnblRuRXP3LT03QVIHUaK2kNsiMP1AedrMoeyh3qk,5843
31
32
  kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
32
33
  kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
33
34
  kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
34
- kostyl/utils/logging.py,sha256=Vye0u4-yeOSUc-f03gpQbxSktTbFiilTWLEVr00ZHvc,5796
35
- kostyl_toolkit-0.1.35.dist-info/WHEEL,sha256=ZyFSCYkV2BrxH6-HRVRg3R9Fo7MALzer9KiPYqNxSbo,79
36
- kostyl_toolkit-0.1.35.dist-info/METADATA,sha256=KL4-Z421DpchI6KUZ6tVATy99urk1OP2OY4Uf5r9R3U,4269
37
- kostyl_toolkit-0.1.35.dist-info/RECORD,,
35
+ kostyl/utils/logging.py,sha256=LSbyQFLAIa89xPb4tcobE2BwVIHHUSaDXqOIKVzLoWs,5801
36
+ kostyl_toolkit-0.1.36.dist-info/WHEEL,sha256=eycQt0QpYmJMLKpE3X9iDk8R04v2ZF0x82ogq-zP6bQ,79
37
+ kostyl_toolkit-0.1.36.dist-info/METADATA,sha256=Lfyx6u3LKZ6co4s7GZgJp31zoy-NViriSGqwjIzOQFA,4269
38
+ kostyl_toolkit-0.1.36.dist-info/RECORD,,
kostyl_toolkit-0.1.36.dist-info/WHEEL CHANGED
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: uv 0.9.18
2
+ Generator: uv 0.9.24
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
kostyl/ml/lightning/training_utils.py REMOVED
@@ -1,241 +0,0 @@
1
- from dataclasses import dataclass
2
- from dataclasses import fields
3
- from pathlib import Path
4
- from typing import Literal
5
- from typing import cast
6
-
7
- import lightning as L
8
- import torch
9
- import torch.distributed as dist
10
- from clearml import OutputModel
11
- from clearml import Task
12
- from lightning.pytorch.callbacks import Callback
13
- from lightning.pytorch.callbacks import EarlyStopping
14
- from lightning.pytorch.callbacks import LearningRateMonitor
15
- from lightning.pytorch.callbacks import ModelCheckpoint
16
- from lightning.pytorch.loggers import TensorBoardLogger
17
- from lightning.pytorch.strategies import DDPStrategy
18
- from lightning.pytorch.strategies import FSDPStrategy
19
- from torch.distributed import ProcessGroup
20
- from torch.distributed.fsdp import MixedPrecision
21
- from torch.nn import Module
22
-
23
- from kostyl.ml.configs import CheckpointConfig
24
- from kostyl.ml.configs import DDPStrategyConfig
25
- from kostyl.ml.configs import EarlyStoppingConfig
26
- from kostyl.ml.configs import FSDP1StrategyConfig
27
- from kostyl.ml.configs import SingleDeviceStrategyConfig
28
- from kostyl.ml.lightning.callbacks import setup_checkpoint_callback
29
- from kostyl.ml.lightning.callbacks import setup_early_stopping_callback
30
- from kostyl.ml.lightning.loggers import setup_tb_logger
31
- from kostyl.ml.registry_uploader import ClearMLRegistryUploaderCallback
32
- from kostyl.utils.logging import setup_logger
33
-
34
-
35
- TRAINING_STRATEGIES = (
36
- FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
37
- )
38
-
39
- logger = setup_logger(add_rank=True)
40
-
41
-
42
- def estimate_total_steps(
43
- trainer: L.Trainer, process_group: ProcessGroup | None = None
44
- ) -> int:
45
- """
46
- Estimates the total number of training steps based on the
47
- dataloader length, accumulation steps, and distributed world size.
48
- """ # noqa: D205
49
- if dist.is_initialized():
50
- world_size = dist.get_world_size(process_group)
51
- else:
52
- world_size = 1
53
-
54
- datamodule = trainer.datamodule # type: ignore
55
- if datamodule is None:
56
- raise ValueError("Trainer must have a datamodule to estimate total steps.")
57
- datamodule = cast(L.LightningDataModule, datamodule)
58
-
59
- logger.info("Loading `train_dataloader` to estimate number of stepping batches.")
60
- datamodule.setup("fit")
61
-
62
- dataloader_len = len(datamodule.train_dataloader())
63
- steps_per_epoch = dataloader_len // trainer.accumulate_grad_batches // world_size
64
-
65
- if trainer.max_epochs is None:
66
- raise ValueError("Trainer must have `max_epochs` set to estimate total steps.")
67
- total_steps = steps_per_epoch * trainer.max_epochs
68
-
69
- logger.info(
70
- f"Total steps: {total_steps} (per-epoch: {steps_per_epoch})\n"
71
- f"-> Dataloader len: {dataloader_len}\n"
72
- f"-> Accumulate grad batches: {trainer.accumulate_grad_batches}\n"
73
- f"-> Epochs: {trainer.max_epochs}\n "
74
- f"-> World size: {world_size}"
75
- )
76
- return total_steps
77
-
78
-
79
- @dataclass
80
- class Callbacks:
81
- """Dataclass to hold PyTorch Lightning callbacks."""
82
-
83
- checkpoint: ModelCheckpoint
84
- lr_monitor: LearningRateMonitor
85
- early_stopping: EarlyStopping | None = None
86
-
87
- def to_list(self) -> list[Callback]:
88
- """Convert dataclass fields to a list of Callbacks. None values are omitted."""
89
- callbacks: list[Callback] = [
90
- getattr(self, field.name)
91
- for field in fields(self)
92
- if getattr(self, field.name) is not None
93
- ]
94
- return callbacks
95
-
96
-
97
- def setup_callbacks(
98
- task: Task,
99
- root_path: Path,
100
- checkpoint_cfg: CheckpointConfig,
101
- early_stopping_cfg: EarlyStoppingConfig | None,
102
- output_model: OutputModel,
103
- checkpoint_upload_strategy: Literal["only-best", "every-checkpoint"],
104
- config_dict: dict[str, str] | None = None,
105
- enable_tag_versioning: bool = False,
106
- ) -> Callbacks:
107
- """
108
- Set up PyTorch Lightning callbacks for training.
109
-
110
- Creates and configures a set of callbacks including checkpoint saving,
111
- learning rate monitoring, model registry uploading, and optional early stopping.
112
-
113
- Args:
114
- task: ClearML task for organizing checkpoints by task name and ID.
115
- root_path: Root directory for saving checkpoints.
116
- checkpoint_cfg: Configuration for checkpoint saving behavior.
117
- checkpoint_upload_strategy: Model upload strategy:
118
- - `"only-best"`: Upload only the best checkpoint based on monitored metric.
119
- - `"every-checkpoint"`: Upload every saved checkpoint.
120
- output_model: ClearML OutputModel instance for model registry integration.
121
- early_stopping_cfg: Configuration for early stopping. If None, early stopping
122
- is disabled.
123
- config_dict: Optional configuration dictionary to store with the model
124
- in the registry.
125
- enable_tag_versioning: Whether to auto-increment version tags (e.g., "v1.0")
126
- on the uploaded model.
127
-
128
- Returns:
129
- Callbacks dataclass containing configured checkpoint, lr_monitor,
130
- and optionally early_stopping callbacks.
131
-
132
- """
133
- lr_monitor = LearningRateMonitor(
134
- logging_interval="step", log_weight_decay=True, log_momentum=False
135
- )
136
- model_uploader = ClearMLRegistryUploaderCallback(
137
- output_model=output_model,
138
- config_dict=config_dict,
139
- verbose=True,
140
- enable_tag_versioning=enable_tag_versioning,
141
- )
142
- checkpoint_callback = setup_checkpoint_callback(
143
- root_path / "checkpoints" / task.name / task.id,
144
- checkpoint_cfg,
145
- registry_uploader_callback=model_uploader,
146
- uploading_strategy=checkpoint_upload_strategy,
147
- )
148
- if early_stopping_cfg is not None:
149
- early_stopping_callback = setup_early_stopping_callback(early_stopping_cfg)
150
- else:
151
- early_stopping_callback = None
152
-
153
- callbacks = Callbacks(
154
- checkpoint=checkpoint_callback,
155
- lr_monitor=lr_monitor,
156
- early_stopping=early_stopping_callback,
157
- )
158
- return callbacks
159
-
160
-
161
- def setup_loggers(task: Task, root_path: Path) -> list[TensorBoardLogger]:
162
- """
163
- Set up PyTorch Lightning loggers for training.
164
-
165
- Args:
166
- task: ClearML task used to organize log directories by task name and ID.
167
- root_path: Root directory for storing TensorBoard logs.
168
-
169
- Returns:
170
- List of configured TensorBoard loggers.
171
-
172
- """
173
- loggers = [
174
- setup_tb_logger(root_path / "runs" / task.name / task.id),
175
- ]
176
- return loggers
177
-
178
-
179
- def setup_strategy(
180
- strategy_settings: TRAINING_STRATEGIES,
181
- devices: list[int] | int,
182
- auto_wrap_policy: set[type[Module]] | None = None,
183
- ) -> Literal["auto"] | FSDPStrategy | DDPStrategy:
184
- """
185
- Configure and return a PyTorch Lightning training strategy.
186
-
187
- Args:
188
- strategy_settings: Strategy configuration object. Must be one of:
189
- - `FSDP1StrategyConfig`: Fully Sharded Data Parallel strategy (requires 2+ devices).
190
- - `DDPStrategyConfig`: Distributed Data Parallel strategy (requires 2+ devices).
191
- - `SingleDeviceStrategyConfig`: Single device training (requires exactly 1 device).
192
- devices: Device(s) to use for training. Either a list of device IDs or
193
- a single integer representing the number of devices.
194
- auto_wrap_policy: Set of module types that should be wrapped for FSDP.
195
- Required when using `FSDP1StrategyConfig`, ignored otherwise.
196
-
197
- Returns:
198
- Configured strategy: `FSDPStrategy`, `DDPStrategy`, or `"auto"` for single device.
199
-
200
- Raises:
201
- ValueError: If device count doesn't match strategy requirements or
202
- if `auto_wrap_policy` is missing for FSDP.
203
-
204
- """
205
- if isinstance(devices, list):
206
- num_devices = len(devices)
207
- else:
208
- num_devices = devices
209
-
210
- match strategy_settings:
211
- case FSDP1StrategyConfig():
212
- if num_devices < 2:
213
- raise ValueError("FSDP strategy requires multiple devices.")
214
-
215
- if auto_wrap_policy is None:
216
- raise ValueError("auto_wrap_policy must be provided for FSDP strategy.")
217
-
218
- mixed_precision_config = MixedPrecision(
219
- param_dtype=getattr(torch, strategy_settings.param_dtype),
220
- reduce_dtype=getattr(torch, strategy_settings.reduce_dtype),
221
- buffer_dtype=getattr(torch, strategy_settings.buffer_dtype),
222
- )
223
- strategy = FSDPStrategy(
224
- auto_wrap_policy=auto_wrap_policy,
225
- mixed_precision=mixed_precision_config,
226
- )
227
- case DDPStrategyConfig():
228
- if num_devices < 2:
229
- raise ValueError("DDP strategy requires at least two devices.")
230
- strategy = DDPStrategy(
231
- find_unused_parameters=strategy_settings.find_unused_parameters
232
- )
233
- case SingleDeviceStrategyConfig():
234
- if num_devices != 1:
235
- raise ValueError("SingleDevice strategy requires exactly one device.")
236
- strategy = "auto"
237
- case _:
238
- raise ValueError(
239
- f"Unsupported strategy type: {type(strategy_settings.trainer.strategy)}"
240
- )
241
- return strategy