kostyl-toolkit 0.1.36__py3-none-any.whl → 0.1.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. kostyl/ml/base_uploader.py +17 -0
  2. kostyl/ml/configs/__init__.py +2 -2
  3. kostyl/ml/configs/mixins.py +50 -0
  4. kostyl/ml/{data_processing_utils.py → data_collator.py} +6 -3
  5. kostyl/ml/dist_utils.py +53 -33
  6. kostyl/ml/integrations/clearml/__init__.py +7 -0
  7. kostyl/ml/{registry_uploader.py → integrations/clearml/checkpoint_uploader.py} +3 -13
  8. kostyl/ml/{configs/base_model.py → integrations/clearml/config_mixin.py} +7 -63
  9. kostyl/ml/{clearml/pulling_utils.py → integrations/clearml/loading_utils.py} +32 -5
  10. kostyl/ml/integrations/lightning/__init__.py +14 -0
  11. kostyl/ml/{lightning → integrations/lightning}/callbacks/checkpoint.py +27 -42
  12. kostyl/ml/{lightning → integrations/lightning}/loggers/tb_logger.py +2 -2
  13. kostyl/ml/{lightning/extensions/pretrained_model.py → integrations/lightning/mixins.py} +6 -4
  14. kostyl/ml/{lightning/extensions/custom_module.py → integrations/lightning/module.py} +2 -38
  15. kostyl/ml/{lightning → integrations/lightning}/utils.py +1 -1
  16. kostyl/ml/schedulers/__init__.py +4 -4
  17. kostyl/ml/schedulers/{cosine_with_plateu.py → plateau.py} +59 -36
  18. kostyl/utils/logging.py +67 -52
  19. {kostyl_toolkit-0.1.36.dist-info → kostyl_toolkit-0.1.38.dist-info}/METADATA +1 -1
  20. kostyl_toolkit-0.1.38.dist-info/RECORD +40 -0
  21. {kostyl_toolkit-0.1.36.dist-info → kostyl_toolkit-0.1.38.dist-info}/WHEEL +2 -2
  22. kostyl/ml/lightning/__init__.py +0 -5
  23. kostyl/ml/lightning/extensions/__init__.py +0 -5
  24. kostyl_toolkit-0.1.36.dist-info/RECORD +0 -38
  25. /kostyl/ml/{clearml → integrations}/__init__.py +0 -0
  26. /kostyl/ml/{clearml → integrations/clearml}/dataset_utils.py +0 -0
  27. /kostyl/ml/{clearml/logging_utils.py → integrations/clearml/version_utils.py} +0 -0
  28. /kostyl/ml/{lightning → integrations/lightning}/callbacks/__init__.py +0 -0
  29. /kostyl/ml/{lightning → integrations/lightning}/callbacks/early_stopping.py +0 -0
  30. /kostyl/ml/{lightning → integrations/lightning}/loggers/__init__.py +0 -0
  31. /kostyl/ml/{metrics_formatting.py → integrations/lightning/metrics_formatting.py} +0 -0
kostyl/ml/base_uploader.py ADDED
@@ -0,0 +1,17 @@
+ from abc import ABC
+ from abc import abstractmethod
+ from pathlib import Path
+
+ from kostyl.utils.logging import setup_logger
+
+
+ logger = setup_logger()
+
+
+ class ModelCheckpointUploader(ABC):
+     """Abstract base class for uploading model checkpoints to a registry backend."""
+
+     @abstractmethod
+     def upload_checkpoint(self, path: str | Path) -> None:
+         """Upload the checkpoint located at the given path to the configured registry backend."""
+         raise NotImplementedError
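
A quick orientation sketch (not part of the diff): any backend can now plug into the checkpoint callbacks by subclassing the new `ModelCheckpointUploader`. The `LocalDirUploader` below is hypothetical and only illustrates the contract.

```python
import shutil
from pathlib import Path

from kostyl.ml.base_uploader import ModelCheckpointUploader


class LocalDirUploader(ModelCheckpointUploader):
    """Hypothetical uploader that copies checkpoints into a local 'registry' directory."""

    def __init__(self, target_dir: str | Path) -> None:
        self.target_dir = Path(target_dir)
        self.target_dir.mkdir(parents=True, exist_ok=True)

    def upload_checkpoint(self, path: str | Path) -> None:
        # Satisfies the abstract contract: push the checkpoint at `path` to the backend.
        shutil.copy2(path, self.target_dir)
```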
kostyl/ml/configs/__init__.py CHANGED
@@ -1,8 +1,8 @@
- from .base_model import KostylBaseModel
  from .hyperparams import HyperparamsConfig
  from .hyperparams import Lr
  from .hyperparams import Optimizer
  from .hyperparams import WeightDecay
+ from .mixins import ConfigLoadingMixin
  from .training_settings import CheckpointConfig
  from .training_settings import DataConfig
  from .training_settings import DDPStrategyConfig
@@ -15,12 +15,12 @@ from .training_settings import TrainingSettings

  __all__ = [
      "CheckpointConfig",
+     "ConfigLoadingMixin",
      "DDPStrategyConfig",
      "DataConfig",
      "EarlyStoppingConfig",
      "FSDP1StrategyConfig",
      "HyperparamsConfig",
-     "KostylBaseModel",
      "LightningTrainerParameters",
      "Lr",
      "Optimizer",
kostyl/ml/configs/mixins.py ADDED
@@ -0,0 +1,50 @@
+ from pathlib import Path
+
+ from pydantic import BaseModel as PydanticBaseModel
+
+ from kostyl.utils.fs import load_config
+
+
+ class ConfigLoadingMixin[TConfig: PydanticBaseModel]:
+     """Mixin providing configuration loading functionality for Pydantic models."""
+
+     @classmethod
+     def from_file(
+         cls: type[TConfig],  # pyright: ignore
+         path: str | Path,
+     ) -> TConfig:
+         """
+         Create an instance of the class from a configuration file.
+
+         Args:
+             cls_: The class type to instantiate.
+             path (str | Path): Path to the configuration file.
+
+         Returns:
+             An instance of the class created from the configuration file.
+
+         """
+         config = load_config(path)
+         instance = cls.model_validate(config)
+         return instance
+
+     @classmethod
+     def from_dict(
+         cls: type[TConfig],  # pyright: ignore
+         state_dict: dict,
+     ) -> TConfig:
+         """
+         Creates an instance from a dictionary.
+
+         Args:
+             cls_: The class type to instantiate.
+             state_dict (dict): A dictionary representing the state of the
+                 class that must be validated and used for initialization.
+
+         Returns:
+             An initialized instance of the class based on the
+             provided state dictionary.
+
+         """
+         instance = cls.model_validate(state_dict)
+         return instance
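
A minimal usage sketch (assumption, not from the diff): `ConfigLoadingMixin` is meant to sit next to a pydantic `BaseModel`; `TrainConfig` and its fields below are illustrative only.

```python
from pydantic import BaseModel

from kostyl.ml.configs import ConfigLoadingMixin


class TrainConfig(BaseModel, ConfigLoadingMixin):
    """Hypothetical config model used only to illustrate the mixin."""

    lr: float = 3e-4
    batch_size: int = 32


# from_dict validates a plain dict; from_file loads and validates a config file instead.
cfg = TrainConfig.from_dict({"lr": 1e-3, "batch_size": 64})
print(cfg.lr)  # 0.001
```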
kostyl/ml/{data_processing_utils.py → data_collator.py} RENAMED
@@ -36,6 +36,7 @@ class BatchCollatorWithKeyAlignment:
          keys_mapping: A dictionary mapping original keys to new keys.
          keys_to_keep: A set of keys to retain as-is from the original items.
          max_length: If provided, truncates "input_ids" and "attention_mask" to this length.
+             Only 1D tensors/lists are supported.

      Raises:
          ValueError: If both `keys_mapping` and `keys_to_keep` are None.
@@ -59,14 +60,16 @@ class BatchCollatorWithKeyAlignment:
      def _truncate_data(self, key: str, value: Any) -> Any:
          match value:
              case torch.Tensor():
-                 if value.dim() > 2:
+                 if value.dim() >= 2:
                      raise ValueError(
-                         f"Expected value with dim <= 2 for key {key}, got {value.dim()}"
+                         f"Expected tensor with dim < 2 for key {key}, got {value.dim()}. "
+                         "Check your data or disable truncation with `max_length=None`."
                      )
              case list():
                  if isinstance(value[0], list):
                      raise ValueError(
-                         f"Expected value with dim <= 2 for key {key}, got nested lists"
+                         f"Expected value with dim <= 2 for key {key}, got nested lists. "
+                         "Check your data or disable truncation with `max_length=None`."
                      )
          value = value[: self.max_length]
          return value
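
To make the behavioral change concrete (illustration only, not package code): truncation is applied per item, so 1D tensors are sliced to `max_length`, while batched 2D tensors, previously allowed by the `dim() > 2` check, now raise a `ValueError`.

```python
import torch

max_length = 4

ids = torch.tensor([1, 2, 3, 4, 5, 6])      # 1D item: accepted and sliced
print(ids[:max_length])                      # tensor([1, 2, 3, 4])

batch = torch.ones(2, 6, dtype=torch.long)   # 2D item: the collator now raises
# ValueError: Expected tensor with dim < 2 ... disable truncation with `max_length=None`.
```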
kostyl/ml/dist_utils.py CHANGED
@@ -4,47 +4,69 @@ from typing import Literal

  import torch.distributed as dist

+ from kostyl.utils.logging import KostylLogger
  from kostyl.utils.logging import setup_logger


- logger = setup_logger(add_rank=True)
+ module_logger = setup_logger()


- def log_dist(msg: str, how: Literal["only-zero-rank", "world"]) -> None:
+ def log_dist(
+     msg: str,
+     logger: KostylLogger | None = None,
+     level: Literal["info", "warning", "error", "warning_once", "debug"] = "info",
+     log_scope: Literal["only-zero-rank", "world"] = "world",
+     group: dist.ProcessGroup | None = None,
+ ) -> None:
      """
      Log a message in a distributed environment based on the specified verbosity level.

      Args:
          msg (str): The message to log.
-         how (Literal["only-zero-rank", "world"]): The verbosity level for logging.
+         log_scope (Literal["only-zero-rank", "world"]): The verbosity level for logging.
              - "only-zero-rank": Log only from the main process (rank 0).
              - "world": Log from all processes in the distributed environment.
+         logger (KostylLogger | None): The logger instance to use. If None, the module logger is used.
+         level (Literal["info", "warning", "error", "warning_once", "debug"]): The logging level.
+         group (dist.ProcessGroup | None): Optional process group used to determine ranks. Defaults to the global process group.

      """
-     match how:
-         case _ if not dist.is_initialized():
-             logger.warning_once(
-                 "Distributed logging requested but torch.distributed is not initialized."
-             )
-             logger.info(msg)
+     if logger is None:
+         logger = module_logger
+
+     log_attr = getattr(logger, level, None)
+     if log_attr is None:
+         raise ValueError(f"Invalid logging level: {level}")
+
+     if not dist.is_initialized():
+         module_logger.warning_once(
+             "Distributed process group is not initialized. Logging from the current process only."
+         )
+         log_attr(msg)
+         return
+
+     match log_scope:
          case "only-zero-rank":
-             if is_main_process():
-                 logger.info(msg)
+             if group is None:
+                 module_logger.debug(
+                     "No process group provided; assuming global group for rank check."
+                 )
+                 group = dist.group.WORLD
+             group_rank = dist.get_rank(group=group)
+             if dist.get_global_rank(group=group, group_rank=group_rank) == 0:  # pyright: ignore[reportArgumentType]
+                 log_attr(msg)
          case "world":
-             logger.info(msg)
+             log_attr(msg)
          case _:
-             logger.warning_once(
-                 f"Invalid logging verbosity level requested: {how}. Message not logged."
-             )
+             raise ValueError(f"Invalid logging verbosity level: {log_scope}")
      return


  def scale_lrs_by_world_size(
      lrs: dict[str, float],
      group: dist.ProcessGroup | None = None,
-     config_name: str = "",
      inv_scale: bool = False,
-     verbose: Literal["only-zero-rank", "world"] | None = None,
+     verbose_level: Literal["only-zero-rank", "world"] | None = None,
  ) -> dict[str, float]:
      """
      Scale learning-rate configuration values to match the active distributed world size.
@@ -56,9 +78,8 @@ def scale_lrs_by_world_size(
          lrs (dict[str, float]): A dictionary of learning rate names and their corresponding values to be scaled.
          group (dist.ProcessGroup | None): Optional process group used to determine
              the target world size. Defaults to the global process group.
-         config_name (str): Human-readable identifier included in log messages.
          inv_scale (bool): If True, use the inverse square-root scale factor.
-         verbose (Literal["only-zero-rank", "world"] | None): Verbosity level for logging scaled values.
+         verbose_level (Literal["only-zero-rank", "world"] | None): Verbosity level for logging scaled values.
              - "only-zero-rank": Log only from the main process (rank 0).
              - "world": Log from all processes in the distributed environment.
              - None: No logging.
@@ -77,31 +98,30 @@
      for name, value in lrs.items():
          old_value = value
          new_value = value * scale
-         if verbose is not None:
+         if verbose_level is not None:
              log_dist(
-                 f"New {config_name} lr {name.upper()}: {new_value}; OLD: {old_value}",
-                 verbose,
+                 f"lr {name.upper()}: {new_value}; OLD: {old_value}",
+                 log_scope=verbose_level,
+                 group=group,
              )
          lrs[name] = new_value
      return lrs


- def get_rank() -> int:
-     """Gets the rank of the current process in a distributed setting."""
-     if dist.is_initialized():
-         return dist.get_rank()
-     if "RANK" in os.environ:
-         return int(os.environ["RANK"])
-     if "SLURM_PROCID" in os.environ:
-         return int(os.environ["SLURM_PROCID"])
+ def get_local_rank(group: dist.ProcessGroup | None = None) -> int:
+     """Gets the local rank of the current process in a distributed setting."""
+     if dist.is_initialized() and group is not None:
+         return dist.get_rank(group=group)
+     if "SLURM_LOCALID" in os.environ:
+         return int(os.environ["SLURM_LOCALID"])
      if "LOCAL_RANK" in os.environ:
          return int(os.environ["LOCAL_RANK"])
      return 0


- def is_main_process() -> bool:
-     """Checks if the current process is the main process (rank 0) in a distributed setting."""
-     rank = get_rank()
+ def is_local_zero_rank() -> bool:
+     """Checks if the current process is the main process (rank 0) for the local node in a distributed setting."""
+     rank = get_local_rank()
      if rank != 0:
          return False
      return True
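
A usage sketch based only on the new signatures above; the surrounding training-script context is hypothetical.

```python
from kostyl.ml.dist_utils import is_local_zero_rank, log_dist, scale_lrs_by_world_size

# Logged from every rank by default; here only rank 0 of the global group prints.
log_dist("starting training", level="info", log_scope="only-zero-rank")

# Scale learning rates by world size and report old/new values from rank 0 only.
lrs = scale_lrs_by_world_size(
    {"encoder": 1e-4, "head": 1e-3},
    verbose_level="only-zero-rank",
)

if is_local_zero_rank():
    # Runs once per node, e.g. to create local output directories.
    ...
```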
kostyl/ml/integrations/clearml/__init__.py ADDED
@@ -0,0 +1,7 @@
+ try:
+     import clearml  # noqa: F401
+ except ImportError as e:
+     raise ImportError(
+         "ClearML integration requires the 'clearml' package. "
+         "Please install it via 'pip install clearml'."
+     ) from e
kostyl/ml/{registry_uploader.py → integrations/clearml/checkpoint_uploader.py} RENAMED
@@ -1,5 +1,3 @@
- from abc import ABC
- from abc import abstractmethod
  from collections.abc import Callable
  from functools import partial
  from pathlib import Path
@@ -7,22 +5,14 @@ from typing import override

  from clearml import OutputModel

+ from kostyl.ml.base_uploader import ModelCheckpointUploader
  from kostyl.utils.logging import setup_logger


  logger = setup_logger()


- class RegistryUploaderCallback(ABC):
-     """Abstract Lightning callback responsible for tracking and uploading the best-performing model checkpoint."""
-
-     @abstractmethod
-     def upload_checkpoint(self, path: str | Path) -> None:
-         """Upload the checkpoint located at the given path to the configured registry backend."""
-         raise NotImplementedError
-
-
- class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
+ class ClearMLCheckpointUploader(ModelCheckpointUploader):
      """PyTorch Lightning callback to upload the best model checkpoint to ClearML."""

      def __init__(
@@ -38,7 +28,7 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
          verbose: bool = True,
      ) -> None:
          """
-         Initializes the ClearMLRegistryUploaderCallback.
+         Initializes the ClearMLRegistryUploader.

          Args:
              model_name: The name for the newly created model.
kostyl/ml/{configs/base_model.py → integrations/clearml/config_mixin.py} RENAMED
@@ -1,75 +1,25 @@
  from pathlib import Path
- from typing import Self
- from typing import TypeVar

  from caseconverter import pascalcase
  from caseconverter import snakecase
  from clearml import Task
- from pydantic import BaseModel as PydanticBaseModel

+ from kostyl.ml.configs import ConfigLoadingMixin
  from kostyl.utils.dict_manipulations import convert_to_flat_dict
  from kostyl.utils.dict_manipulations import flattened_dict_to_nested
  from kostyl.utils.fs import load_config


- TConfig = TypeVar("TConfig", bound=PydanticBaseModel)
-
-
- class BaseModelWithConfigLoading(PydanticBaseModel):
-     """Pydantic class providing basic configuration loading functionality."""
-
-     @classmethod
-     def from_file(
-         cls: type[Self],  # pyright: ignore
-         path: str | Path,
-     ) -> Self:
-         """
-         Create an instance of the class from a configuration file.
-
-         Args:
-             cls_: The class type to instantiate.
-             path (str | Path): Path to the configuration file.
-
-         Returns:
-             An instance of the class created from the configuration file.
-
-         """
-         config = load_config(path)
-         instance = cls.model_validate(config)
-         return instance
-
-     @classmethod
-     def from_dict(
-         cls: type[Self],  # pyright: ignore
-         state_dict: dict,
-     ) -> Self:
-         """
-         Creates an instance from a dictionary.
-
-         Args:
-             cls_: The class type to instantiate.
-             state_dict (dict): A dictionary representing the state of the
-                 class that must be validated and used for initialization.
-
-         Returns:
-             An initialized instance of the class based on the
-             provided state dictionary.
-
-         """
-         instance = cls.model_validate(state_dict)
-         return instance
-
-
- class BaseModelWithClearmlSyncing(BaseModelWithConfigLoading):
-     """Pydantic class providing ClearML configuration loading and syncing functionality."""
+ class BaseModelWithClearmlSyncing[TConfig: ConfigLoadingMixin]:
+     """Mixin providing ClearML task configuration syncing functionality for Pydantic models."""

      @classmethod
      def connect_as_file(
-         cls: type[Self],  # pyright: ignore
+         cls: type[TConfig],  # pyright: ignore
          task: Task,
          path: str | Path,
          alias: str | None = None,
-     ) -> Self:
+     ) -> TConfig:
          """
          Connects the configuration file to a ClearML task and creates an instance of the class from it.

@@ -104,11 +54,11 @@ class BaseModelWithClearmlSyncing(BaseModelWithConfigLoading):

      @classmethod
      def connect_as_dict(
-         cls: type[Self],  # pyright: ignore
+         cls: type[TConfig],  # pyright: ignore
          task: Task,
          path: str | Path,
          alias: str | None = None,
-     ) -> Self:
+     ) -> TConfig:
          """
          Connects configuration from a file as a dictionary to a ClearML task and creates an instance of the class.

@@ -135,9 +85,3 @@ class BaseModelWithClearmlSyncing(BaseModelWithConfigLoading):

          model = cls.from_dict(state_dict=config)
          return model
-
-
- class KostylBaseModel(BaseModelWithClearmlSyncing):
-     """A Pydantic model class with basic configuration loading functionality."""
-
-     pass
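
A composition sketch (assumption: this mirrors what the removed KostylBaseModel provided through inheritance; ExperimentConfig, its fields, and params.yaml are illustrative only):

```python
from clearml import Task
from pydantic import BaseModel

from kostyl.ml.configs import ConfigLoadingMixin
from kostyl.ml.integrations.clearml.config_mixin import BaseModelWithClearmlSyncing


class ExperimentConfig(BaseModel, ConfigLoadingMixin, BaseModelWithClearmlSyncing):
    """Hypothetical experiment config combining loading and ClearML syncing."""

    lr: float = 3e-4
    epochs: int = 10


task = Task.init(project_name="demo", task_name="train")
# Registers the file's contents with the task, then validates them via from_dict().
cfg = ExperimentConfig.connect_as_dict(task, "params.yaml")
```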
kostyl/ml/{clearml/pulling_utils.py → integrations/clearml/loading_utils.py} RENAMED
@@ -9,9 +9,26 @@ from transformers import AutoTokenizer
  from transformers import PreTrainedModel
  from transformers import PreTrainedTokenizerBase

- from kostyl.ml.lightning.extensions.pretrained_model import (
-     LightningCheckpointLoaderMixin,
- )
+
+ try:
+     from kostyl.ml.integrations.lightning import (
+         LightningCheckpointLoaderMixin,  # pyright: ignore[reportAssignmentType]
+     )
+
+     LIGHTING_MIXIN_AVAILABLE = True
+ except ImportError:
+
+     class LightningCheckpointLoaderMixin(PreTrainedModel):  # noqa: D101
+         pass  # type: ignore
+
+         @classmethod
+         def from_lightning_checkpoint(cls, *args: Any, **kwargs: Any) -> Any:  # noqa: D103
+             raise ImportError(
+                 "Loading from Lightning checkpoints requires lightning integration. "
+                 "Please package install via 'pip install lightning' to enable this functionality."
+             )
+
+     LIGHTING_MIXIN_AVAILABLE = False


  def get_tokenizer_from_clearml(
@@ -89,13 +106,23 @@ def get_model_from_clearml[
      local_path = Path(input_model.get_local_copy(raise_on_error=True))

      if local_path.is_dir() and input_model._is_package():
+         if not issubclass(model, (PreTrainedModel, AutoModel)):
+             raise ValueError(
+                 f"Model class {model.__name__} must be a subclass of PreTrainedModel or AutoModel for directory loads."
+             )
          model_instance = model.from_pretrained(local_path, **kwargs)
      elif local_path.suffix == ".ckpt":
+         if not LIGHTING_MIXIN_AVAILABLE:
+             raise ImportError(
+                 "Loading from Lightning checkpoints requires lightning integration. "
+                 "Please package install via 'pip install lightning' to enable this functionality."
+             )
          if not issubclass(model, LightningCheckpointLoaderMixin):
              raise ValueError(
-                 f"Model class {model.__name__} is not compatible with Lightning checkpoints."
+                 f"Model class {model.__name__} is not compatible with Lightning checkpoints "
+                 "(must inherit from LightningCheckpointLoaderMixin)."
              )
-         model_instance = model.from_lightning_checkpoint(local_path, **kwargs)
+         model_instance = model.from_lightning_checkpoint(local_path, **kwargs)  # type: ignore
      else:
          raise ValueError(
              f"Unsupported model format for path: {local_path}. "
kostyl/ml/integrations/lightning/__init__.py ADDED
@@ -0,0 +1,14 @@
+ try:
+     import lightning  # noqa: F401
+ except ImportError as e:
+     raise ImportError(
+         "Lightning integration requires the 'lightning' package. "
+         "Please install it via 'pip install lightning'."
+     ) from e
+
+
+ from .mixins import LightningCheckpointLoaderMixin
+ from .module import KostylLightningModule
+
+
+ __all__ = ["KostylLightningModule", "LightningCheckpointLoaderMixin"]
kostyl/ml/{lightning → integrations/lightning}/callbacks/checkpoint.py RENAMED
@@ -9,17 +9,16 @@ import torch.distributed as dist
  from lightning.fabric.utilities.types import _PATH
  from lightning.pytorch.callbacks import ModelCheckpoint

+ from kostyl.ml.base_uploader import ModelCheckpointUploader
  from kostyl.ml.configs import CheckpointConfig
- from kostyl.ml.dist_utils import is_main_process
- from kostyl.ml.lightning import KostylLightningModule
- from kostyl.ml.registry_uploader import RegistryUploaderCallback
+ from kostyl.ml.dist_utils import is_local_zero_rank
  from kostyl.utils import setup_logger


  logger = setup_logger("callbacks/checkpoint.py")


- class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
+ class ModelCheckpointWithCheckpointUploader(ModelCheckpoint):
      r"""
      Save the model after every epoch by monitoring a quantity. Every logged metrics are passed to the
      :class:`~lightning.pytorch.loggers.logger.Logger` for the version it gets saved in the same directory as the
@@ -229,8 +228,8 @@ class ModelCheckpointWithRegistryUploader(ModelCheckpoint):

      def __init__(  # noqa: D107
          self,
-         registry_uploader_callback: RegistryUploaderCallback,
-         uploading_mode: Literal["only-best", "every-checkpoint"] = "only-best",
+         checkpoint_uploader: ModelCheckpointUploader,
+         upload_strategy: Literal["only-best", "every-checkpoint"] = "only-best",
          dirpath: _PATH | None = None,
          filename: str | None = None,
          monitor: str | None = None,
@@ -247,9 +246,9 @@ class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
          save_on_train_epoch_end: bool | None = None,
          enable_version_counter: bool = True,
      ) -> None:
-         self.registry_uploader_callback = registry_uploader_callback
+         self.registry_uploader = checkpoint_uploader
          self.process_group: dist.ProcessGroup | None = None
-         self.uploading_mode = uploading_mode
+         self.upload_strategy = upload_strategy
          super().__init__(
              dirpath=dirpath,
              filename=filename,
@@ -269,40 +268,26 @@ class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
          )
          return

-     @override
-     def setup(
-         self,
-         trainer: pl.Trainer,
-         pl_module: pl.LightningModule | KostylLightningModule,
-         stage: str,
-     ) -> None:
-         super().setup(trainer, pl_module, stage)
-         if isinstance(pl_module, KostylLightningModule):
-             self.process_group = pl_module.get_process_group()
-         return
-
      @override
      def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
          super()._save_checkpoint(trainer, filepath)
-         if dist.is_initialized():
-             dist.barrier(group=self.process_group)
-         if trainer.is_global_zero and self.registry_uploader_callback is not None:
-             match self.uploading_mode:
+         if trainer.is_global_zero and self.registry_uploader is not None:
+             match self.upload_strategy:
                  case "every-checkpoint":
-                     self.registry_uploader_callback.upload_checkpoint(filepath)
+                     self.registry_uploader.upload_checkpoint(filepath)
                  case "only-best":
                      if filepath == self.best_model_path:
-                         self.registry_uploader_callback.upload_checkpoint(filepath)
+                         self.registry_uploader.upload_checkpoint(filepath)
          return


  def setup_checkpoint_callback(
      dirpath: Path,
      ckpt_cfg: CheckpointConfig,
-     registry_uploader_callback: RegistryUploaderCallback | None = None,
-     uploading_strategy: Literal["only-best", "every-checkpoint"] | None = None,
+     checkpoint_uploader: ModelCheckpointUploader | None = None,
+     upload_strategy: Literal["only-best", "every-checkpoint"] | None = None,
      remove_folder_if_exists: bool = True,
- ) -> ModelCheckpointWithRegistryUploader | ModelCheckpoint:
+ ) -> ModelCheckpointWithCheckpointUploader | ModelCheckpoint:
      """
      Create and configure a checkpoint callback for model saving.

@@ -313,33 +298,33 @@ def setup_checkpoint_callback(
      Args:
          dirpath: Path to the directory for saving checkpoints.
          ckpt_cfg: Checkpoint configuration (filename, monitor, mode, save_top_k).
-         registry_uploader_callback: Optional callback for uploading checkpoints to a remote registry.
-             Must be specified together with uploading_strategy.
-         uploading_strategy: Checkpoint upload mode:
+         checkpoint_uploader: Optional checkpoint uploader instance. If provided, enables
+             uploading of checkpoints to a remote registry.
+         upload_strategy: Checkpoint upload mode:
              - "only-best": only the best checkpoint is uploaded
              - "every-checkpoint": every saved checkpoint is uploaded
-             Must be specified together with registry_uploader_callback.
+             Must be specified together with checkpoint_uploader.
          remove_folder_if_exists: If True, removes existing checkpoint directory before creating a new one.

      Returns:
-         ModelCheckpointWithRegistryUploader if registry_uploader_callback is provided,
+         ModelCheckpointWithCheckpointUploader if checkpoint_uploader is provided,
          otherwise standard ModelCheckpoint.

      Raises:
-         ValueError: If only one of registry_uploader_callback or uploading_mode is None.
+         ValueError: If only one of checkpoint_uploader or uploading_mode is None.

      Note:
          If the dirpath directory already exists, it will be removed and recreated
          (only on the main process in distributed training) if remove_folder_if_exists is True.

      """
-     if (registry_uploader_callback is None) != (uploading_strategy is None):
+     if (checkpoint_uploader is None) != (upload_strategy is None):
          raise ValueError(
-             "Both registry_uploader_callback and uploading_mode must be provided or neither."
+             "Both checkpoint_uploader and upload_strategy must be provided or neither."
          )

      if dirpath.exists():
-         if is_main_process():
+         if is_local_zero_rank():
              logger.warning(f"Checkpoint directory {dirpath} already exists.")
              if remove_folder_if_exists:
                  rmtree(dirpath)
@@ -348,8 +333,8 @@
          logger.info(f"Creating checkpoint directory {dirpath}.")
      dirpath.mkdir(parents=True, exist_ok=True)

-     if (registry_uploader_callback is not None) and (uploading_strategy is not None):
-         checkpoint_callback = ModelCheckpointWithRegistryUploader(
+     if (checkpoint_uploader is not None) and (upload_strategy is not None):
+         checkpoint_callback = ModelCheckpointWithCheckpointUploader(
              dirpath=dirpath,
              filename=ckpt_cfg.filename,
              save_top_k=ckpt_cfg.save_top_k,
@@ -357,8 +342,8 @@
              mode=ckpt_cfg.mode,
              verbose=True,
              save_weights_only=ckpt_cfg.save_weights_only,
-             registry_uploader_callback=registry_uploader_callback,
-             uploading_mode=uploading_strategy,
+             checkpoint_uploader=checkpoint_uploader,
+             upload_strategy=upload_strategy,
          )
      else:
          checkpoint_callback = ModelCheckpoint(
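
Putting the renamed pieces together (a sketch under assumptions: the ClearMLCheckpointUploader and CheckpointConfig constructors are only partially visible in this diff, so the arguments shown are illustrative and may be incomplete):

```python
from pathlib import Path

from kostyl.ml.configs import CheckpointConfig
from kostyl.ml.integrations.clearml.checkpoint_uploader import ClearMLCheckpointUploader
from kostyl.ml.integrations.lightning.callbacks.checkpoint import setup_checkpoint_callback

# Constructor arguments below are illustrative; both classes accept more options than shown.
uploader = ClearMLCheckpointUploader(model_name="my-model")
ckpt_cfg = CheckpointConfig(
    filename="epoch{epoch:02d}",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    save_weights_only=False,
)

# checkpoint_uploader and upload_strategy must be passed together (or both omitted),
# otherwise setup_checkpoint_callback raises ValueError.
callback = setup_checkpoint_callback(
    dirpath=Path("checkpoints/run-001"),
    ckpt_cfg=ckpt_cfg,
    checkpoint_uploader=uploader,
    upload_strategy="only-best",
)
```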
kostyl/ml/{lightning → integrations/lightning}/loggers/tb_logger.py RENAMED
@@ -3,7 +3,7 @@ from shutil import rmtree

  from lightning.pytorch.loggers import TensorBoardLogger

- from kostyl.ml.dist_utils import is_main_process
+ from kostyl.ml.dist_utils import is_local_zero_rank
  from kostyl.utils.logging import setup_logger


@@ -15,7 +15,7 @@ def setup_tb_logger(
  ) -> TensorBoardLogger:
      """Sets up a TensorBoardLogger for PyTorch Lightning."""
      if runs_dir.exists():
-         if is_main_process():
+         if is_local_zero_rank():
              logger.warning(f"TensorBoard log directory {runs_dir} already exists.")
              rmtree(runs_dir)
              logger.warning(f"Removed existing TensorBoard log directory {runs_dir}.")