kostyl-toolkit 0.1.36__tar.gz → 0.1.37__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/PKG-INFO +1 -1
  2. kostyl_toolkit-0.1.37/kostyl/ml/dist_utils.py +129 -0
  3. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/callbacks/checkpoint.py +2 -2
  4. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/extensions/custom_module.py +0 -5
  5. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/extensions/pretrained_model.py +6 -4
  6. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/loggers/tb_logger.py +2 -2
  7. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/utils.py +1 -1
  8. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/utils/logging.py +67 -52
  9. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/pyproject.toml +1 -1
  10. kostyl_toolkit-0.1.36/kostyl/ml/dist_utils.py +0 -107
  11. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/README.md +0 -0
  12. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/__init__.py +0 -0
  13. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/__init__.py +0 -0
  14. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/clearml/__init__.py +0 -0
  15. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/clearml/dataset_utils.py +0 -0
  16. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/clearml/logging_utils.py +0 -0
  17. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/clearml/pulling_utils.py +0 -0
  18. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/configs/__init__.py +0 -0
  19. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/configs/base_model.py +0 -0
  20. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/configs/hyperparams.py +0 -0
  21. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/configs/training_settings.py +0 -0
  22. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/data_processing_utils.py +0 -0
  23. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/__init__.py +0 -0
  24. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/callbacks/__init__.py +0 -0
  25. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
  26. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/extensions/__init__.py +0 -0
  27. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/lightning/loggers/__init__.py +0 -0
  28. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/metrics_formatting.py +0 -0
  29. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/params_groups.py +0 -0
  30. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/registry_uploader.py +0 -0
  31. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/schedulers/__init__.py +0 -0
  32. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/schedulers/base.py +0 -0
  33. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/schedulers/composite.py +0 -0
  34. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/schedulers/cosine.py +0 -0
  35. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/schedulers/cosine_with_plateu.py +0 -0
  36. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/ml/schedulers/linear.py +0 -0
  37. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/utils/__init__.py +0 -0
  38. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/utils/dict_manipulations.py +0 -0
  39. {kostyl_toolkit-0.1.36 → kostyl_toolkit-0.1.37}/kostyl/utils/fs.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: kostyl-toolkit
-Version: 0.1.36
+Version: 0.1.37
 Summary: Kickass Orchestration System for Training, Yielding & Logging
 Requires-Dist: case-converter>=1.2.0
 Requires-Dist: loguru>=0.7.3
kostyl_toolkit-0.1.37/kostyl/ml/dist_utils.py (new file)
@@ -0,0 +1,129 @@
+import math
+import os
+from typing import Literal
+
+import torch.distributed as dist
+
+from kostyl.utils.logging import KostylLogger
+from kostyl.utils.logging import setup_logger
+
+
+module_logger = setup_logger()
+
+
+def log_dist(
+    msg: str,
+    logger: KostylLogger | None = None,
+    level: Literal["info", "warning", "error", "warning_once", "debug"] = "info",
+    log_scope: Literal["only-zero-rank", "world"] = "world",
+    group: dist.ProcessGroup | None = None,
+) -> None:
+    """
+    Log a message in a distributed environment based on the specified verbosity level.
+
+    Args:
+        msg (str): The message to log.
+        log_scope (Literal["only-zero-rank", "world"]): The verbosity level for logging.
+            - "only-zero-rank": Log only from the main process (rank 0).
+            - "world": Log from all processes in the distributed environment.
+        logger (KostylLogger | None): The logger instance to use. If None, the module logger is used.
+        level (Literal["info", "warning", "error", "warning_once", "debug"]): The logging level.
+        group (dist.ProcessGroup | None): Optional process group used to determine ranks. Defaults to the global process group.
+
+    """
+    if logger is None:
+        logger = module_logger
+
+    log_attr = getattr(logger, level, None)
+    if log_attr is None:
+        raise ValueError(f"Invalid logging level: {level}")
+
+    if not dist.is_initialized():
+        module_logger.warning_once(
+            "Distributed process group is not initialized; logging from all ranks."
+        )
+        log_attr(msg)
+        return
+
+    match log_scope:
+        case "only-zero-rank":
+            if group is None:
+                module_logger.debug(
+                    "No process group provided; assuming global group for rank check."
+                )
+                group = dist.group.WORLD
+            group_rank = dist.get_rank(group=group)
+            if dist.get_global_rank(group=group, group_rank=group_rank) == 0:  # pyright: ignore[reportArgumentType]
+                log_attr(msg)
+        case "world":
+            log_attr(msg)
+        case _:
+            raise ValueError(f"Invalid logging verbosity level: {log_scope}")
+    return
+
+
+def scale_lrs_by_world_size(
+    lrs: dict[str, float],
+    group: dist.ProcessGroup | None = None,
+    config_name: str = "",
+    inv_scale: bool = False,
+    verbose_level: Literal["only-zero-rank", "world"] | None = None,
+) -> dict[str, float]:
+    """
+    Scale learning-rate configuration values to match the active distributed world size.
+
+    Note:
+        The value in the `lrs` will be modified in place.
+
+    Args:
+        lrs (dict[str, float]): A dictionary of learning rate names and their corresponding values to be scaled.
+        group (dist.ProcessGroup | None): Optional process group used to determine
+            the target world size. Defaults to the global process group.
+        config_name (str): Human-readable identifier included in log messages.
+        inv_scale (bool): If True, use the inverse square-root scale factor.
+        verbose_level (Literal["only-zero-rank", "world"] | None): Verbosity level for logging scaled values.
+            - "only-zero-rank": Log only from the main process (rank 0).
+            - "world": Log from all processes in the distributed environment.
+            - None: No logging.
+
+    Returns:
+        dict[str, float]: The learning-rate configuration with scaled values.
+
+    """
+    world_size = dist.get_world_size(group=group)
+
+    if inv_scale:
+        scale = 1 / math.sqrt(world_size)
+    else:
+        scale = math.sqrt(world_size)
+
+    for name, value in lrs.items():
+        old_value = value
+        new_value = value * scale
+        if verbose_level is not None:
+            log_dist(
+                f"New {config_name} lr {name.upper()}: {new_value}; OLD: {old_value}",
+                log_scope=verbose_level,
+                group=group,
+            )
+        lrs[name] = new_value
+    return lrs
+
+
+def get_local_rank(group: dist.ProcessGroup | None = None) -> int:
+    """Gets the local rank of the current process in a distributed setting."""
+    if dist.is_initialized() and group is not None:
+        return dist.get_rank(group=group)
+    if "SLURM_LOCALID" in os.environ:
+        return int(os.environ["SLURM_LOCALID"])
+    if "LOCAL_RANK" in os.environ:
+        return int(os.environ["LOCAL_RANK"])
+    return 0
+
+
+def is_local_zero_rank() -> bool:
+    """Checks if the current process is the main process (rank 0) for the local node in a distributed setting."""
+    rank = get_local_rank()
+    if rank != 0:
+        return False
+    return True
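To make the new `dist_utils` API above concrete, here is a minimal, hedged usage sketch. It spins up a single-process gloo group purely for illustration (real jobs would be launched via torchrun or SLURM); the learning-rate names and messages are invented.

```python
# Usage sketch for the rewritten kostyl.ml.dist_utils (values are illustrative).
import torch.distributed as dist

from kostyl.ml.dist_utils import is_local_zero_rank, log_dist, scale_lrs_by_world_size

# Single-process group just so the helpers have something to query.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

log_dist("starting training", log_scope="world")                       # every rank logs
log_dist("config loaded", log_scope="only-zero-rank", level="debug")   # global rank 0 only

# Scales values in place by sqrt(world_size); with 4 ranks 1e-3 would become 2e-3,
# with this single-process group the factor is 1.0.
lrs = {"encoder": 1e-3, "head": 5e-4}
scale_lrs_by_world_size(lrs, config_name="optimizer", verbose_level="only-zero-rank")

if is_local_zero_rank():  # node-local rank 0, e.g. for filesystem work
    print("preparing output directory on this node")

dist.destroy_process_group()
```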
kostyl/ml/lightning/callbacks/checkpoint.py
@@ -10,7 +10,7 @@ from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.callbacks import ModelCheckpoint
 
 from kostyl.ml.configs import CheckpointConfig
-from kostyl.ml.dist_utils import is_main_process
+from kostyl.ml.dist_utils import is_local_zero_rank
 from kostyl.ml.lightning import KostylLightningModule
 from kostyl.ml.registry_uploader import RegistryUploaderCallback
 from kostyl.utils import setup_logger
@@ -339,7 +339,7 @@ def setup_checkpoint_callback(
     )
 
     if dirpath.exists():
-        if is_main_process():
+        if is_local_zero_rank():
             logger.warning(f"Checkpoint directory {dirpath} already exists.")
             if remove_folder_if_exists:
                 rmtree(dirpath)
kostyl/ml/lightning/extensions/custom_module.py
@@ -26,11 +26,6 @@ module_logger = setup_logger(fmt="only_message")
 class KostylLightningModule(L.LightningModule):
     """Custom PyTorch Lightning Module with logging, checkpointing, and distributed training utilities."""
 
-    @property
-    def process_group(self) -> ProcessGroup | None:
-        """Returns the data parallel process group for distributed training."""
-        return self.get_process_group()
-
     def get_process_group(self) -> ProcessGroup | None:
         """
         Retrieves the data parallel process group for distributed training.
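With the `process_group` convenience property removed, call sites go through `get_process_group()` directly. A minimal sketch, assuming a hypothetical subclass:

```python
# MyModule is a hypothetical KostylLightningModule subclass, shown only to
# illustrate the 0.1.36 -> 0.1.37 change from the property to the method.
from kostyl.ml.lightning import KostylLightningModule


class MyModule(KostylLightningModule):
    def on_train_start(self) -> None:
        # 0.1.36: pg = self.process_group
        pg = self.get_process_group()  # 0.1.37: call the method directly
        if pg is not None:
            self.print(f"data-parallel group size: {pg.size()}")
```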
kostyl/ml/lightning/extensions/pretrained_model.py
@@ -12,12 +12,12 @@ from kostyl.utils.logging import setup_logger
 logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
 
 
-class LightningCheckpointLoaderMixin(PreTrainedModel):
+class LightningCheckpointLoaderMixin:
     """A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""
 
     @classmethod
-    def from_lightning_checkpoint[TModelInstance: LightningCheckpointLoaderMixin](  # noqa: C901
-        cls: type[TModelInstance],
+    def from_lightning_checkpoint[TModelInstance: PreTrainedModel](  # noqa: C901
+        cls: type[TModelInstance],  # pyright: ignore[reportGeneralTypeIssues]
         checkpoint_path: str | Path,
         config_key: str = "config",
         weights_prefix: str | None = "model.",
@@ -78,7 +78,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
             mmap=True,
         )
 
-        # 1. Восстанавливаем конфиг
+        # Load config
         config_cls = cast(type[PretrainedConfig], cls.config_class)
         config_dict = checkpoint_dict[config_key]
         config_dict.update(kwargs)
@@ -91,6 +91,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
 
         raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
 
+        # Handle weights prefix
         if weights_prefix:
             if not weights_prefix.endswith("."):
                 weights_prefix = weights_prefix + "."
@@ -117,6 +118,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
         else:
             state_dict = raw_state_dict
 
+        # Instantiate model and load state dict
         model = cls.from_pretrained(
             pretrained_model_name_or_path=None,
             config=config,
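Because `LightningCheckpointLoaderMixin` no longer subclasses `PreTrainedModel`, it is now meant to be mixed into a concrete `PreTrainedModel` class, with the type parameter of `from_lightning_checkpoint` bound to `PreTrainedModel`. A hedged composition sketch; the model class and checkpoint path are hypothetical:

```python
# Hypothetical composition of the mixin with a Hugging Face model class.
from transformers import BertModel

from kostyl.ml.lightning.extensions.pretrained_model import LightningCheckpointLoaderMixin


class MyBertModel(LightningCheckpointLoaderMixin, BertModel):
    """BertModel already defines config_class, which the mixin reads."""


# Restores the config stored under "config" in the Lightning checkpoint, strips
# the "model." prefix from state-dict keys, then delegates to from_pretrained().
model = MyBertModel.from_lightning_checkpoint(
    checkpoint_path="checkpoints/last.ckpt",  # illustrative path
    weights_prefix="model.",
)
```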
kostyl/ml/lightning/loggers/tb_logger.py
@@ -3,7 +3,7 @@ from shutil import rmtree
 
 from lightning.pytorch.loggers import TensorBoardLogger
 
-from kostyl.ml.dist_utils import is_main_process
+from kostyl.ml.dist_utils import is_local_zero_rank
 from kostyl.utils.logging import setup_logger
 
 
@@ -15,7 +15,7 @@ def setup_tb_logger(
 ) -> TensorBoardLogger:
     """Sets up a TensorBoardLogger for PyTorch Lightning."""
     if runs_dir.exists():
-        if is_main_process():
+        if is_local_zero_rank():
             logger.warning(f"TensorBoard log directory {runs_dir} already exists.")
             rmtree(runs_dir)
             logger.warning(f"Removed existing TensorBoard log directory {runs_dir}.")
kostyl/ml/lightning/utils.py
@@ -14,7 +14,7 @@ TRAINING_STRATEGIES = (
     FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
 )
 
-logger = setup_logger(add_rank=True)
+logger = setup_logger()
 
 
 def estimate_total_steps(
kostyl/utils/logging.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import inspect
+import os
 import sys
 import uuid
 from collections import namedtuple
@@ -18,32 +19,18 @@ from loguru import logger as _base_logger
 if TYPE_CHECKING:
     from loguru import Logger
 
-    class CustomLogger(Logger):  # noqa: D101
+    class KostylLogger(Logger):  # noqa: D101
         def log_once(self, level: str, message: str, *args, **kwargs) -> None: ...  # noqa: ANN003, D102
         def warning_once(self, message: str, *args, **kwargs) -> None: ...  # noqa: ANN003, D102
 else:
-    CustomLogger = type(_base_logger)
+    KostylLogger = type(_base_logger)
 
 try:
-    import torch.distributed as dist
     from torch.nn.modules.module import (
         _IncompatibleKeys,  # pyright: ignore[reportAssignmentType]
     )
 except Exception:
 
-    class _Dummy:
-        @staticmethod
-        def is_available() -> bool:
-            return False
-
-        @staticmethod
-        def is_initialized() -> bool:
-            return False
-
-        @staticmethod
-        def get_rank() -> int:
-            return 0
-
     class _IncompatibleKeys(
         namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
     ):
@@ -56,14 +43,13 @@ except Exception:
 
         __str__ = __repr__
 
-    dist = _Dummy()
     _IncompatibleKeys = _IncompatibleKeys
 
 _once_lock = Lock()
 _once_keys: set[tuple[str, str]] = set()
 
 
-def _log_once(self: CustomLogger, level: str, message: str, *args, **kwargs) -> None:  # noqa: ANN003
+def _log_once(self: KostylLogger, level: str, message: str, *args, **kwargs) -> None:  # noqa: ANN003
     key = (message, level)
 
     with _once_lock:
@@ -75,7 +61,7 @@ def _log_once(self: CustomLogger, level: str, message: str, *args, **kwargs) ->
     return
 
 
-_base_logger = cast(CustomLogger, _base_logger)
+_base_logger = cast(KostylLogger, _base_logger)
 _base_logger.log_once = _log_once  # pyright: ignore[reportAttributeAccessIssue]
 _base_logger.warning_once = partialmethod(_log_once, "WARNING")  # pyright: ignore[reportAttributeAccessIssue]
 
@@ -91,44 +77,83 @@ _DEFAULT_FMT = "<level>{level: <8}</level> {time:HH:mm:ss.SSS} [{extra[channel]}
 _ONLY_MESSAGE_FMT = "<level>{message}</level>"
 _PRESETS = {"default": _DEFAULT_FMT, "only_message": _ONLY_MESSAGE_FMT}
 
+KOSTYL_LOG_LEVEL = os.getenv("KOSTYL_LOG_LEVEL", "INFO")
+
 
 def setup_logger(
     name: str | None = None,
     fmt: Literal["default", "only_message"] | str = "only_message",
-    level: str = "INFO",
-    add_rank: bool | None = None,
+    level: str | None = None,
     sink=sys.stdout,
     colorize: bool = True,
     serialize: bool = False,
-) -> CustomLogger:
+) -> KostylLogger:
     """
-    Returns a bound logger with its own sink and formatting.
+    Creates and configures a logger with custom formatting and output.
+
+    The function automatically removes the default sink on first call and creates
+    an isolated logger with a unique identifier for message filtering.
+
+    Args:
+        name (str | None, optional): Logger channel name. If None, automatically
+            uses the calling function's filename. Defaults to None.
+        fmt (Literal["default", "only_message"] | str, optional): Log message format.
+            Available presets:
+            - "default": includes level, time, and channel
+            - "only_message": outputs only the message itself
+            Custom format strings are also supported. Defaults to "only_message".
+        level (str | None, optional): Logging level (TRACE, DEBUG, INFO, SUCCESS,
+            WARNING, ERROR, CRITICAL). If None, uses the KOSTYL_LOG_LEVEL environment
+            variable or "INFO" by default. Defaults to None.
+        sink: Output object for logs (file, sys.stdout, sys.stderr, etc.).
+            Defaults to sys.stdout.
+        colorize (bool, optional): Enable colored output formatting.
+            Defaults to True.
+        serialize (bool, optional): Serialize logs to JSON format.
+            Defaults to False.
+
+    Returns:
+        CustomLogger: Configured logger instance with additional methods
+            log_once() and warning_once().
+
+    Example:
+        >>> # Basic usage with automatic name detection
+        >>> logger = setup_logger()
+        >>> logger.info("Hello World")
 
-    Note: If name=None, the caller's filename (similar to __file__) is used automatically.
+        >>> # With custom name and level
+        >>> logger = setup_logger(name="MyApp", level="DEBUG")
+
+        >>> # With custom format
+        >>> logger = setup_logger(
+        ...     name="API",
+        ...     fmt="{level} | {time:YYYY-MM-DD HH:mm:ss} | {message}"
+        ... )
 
-    Format example: "{level} {time:MM-DD HH:mm:ss} [{extra[channel]}] {message}"
     """
     global _DEFAULT_SINK_REMOVED
     if not _DEFAULT_SINK_REMOVED:
         _base_logger.remove()
         _DEFAULT_SINK_REMOVED = True
 
-    if name is None:
-        base = _caller_filename()
-    else:
-        base = name
+    if level is None:
+        if KOSTYL_LOG_LEVEL not in {
+            "TRACE",
+            "DEBUG",
+            "INFO",
+            "SUCCESS",
+            "WARNING",
+            "ERROR",
+            "CRITICAL",
+        }:
+            level = "INFO"
+        else:
+            level = KOSTYL_LOG_LEVEL
 
-    if (add_rank is None) or add_rank:
-        try:
-            add_rank = dist.is_available() and dist.is_initialized()
-        except Exception:
-            add_rank = False
-
-    if add_rank:
-        rank = dist.get_rank()
-        channel = f"rank:{rank} - {base}"
+    if name is None:
+        channel = _caller_filename()
     else:
-        channel = base
+        channel = name
 
     if fmt in _PRESETS:
         fmt = _PRESETS[fmt]
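The removed `add_rank` flag and the new `level=None` default change how callers configure verbosity: the level now falls back to the `KOSTYL_LOG_LEVEL` environment variable, with unrecognized values falling back to "INFO". A small sketch of the new behaviour; the channel names and messages are illustrative:

```python
import os

# KOSTYL_LOG_LEVEL is read once when kostyl.utils.logging is imported,
# so it must be set before that import happens.
os.environ["KOSTYL_LOG_LEVEL"] = "DEBUG"

from kostyl.utils.logging import setup_logger

logger = setup_logger(name="train")  # level=None -> falls back to KOSTYL_LOG_LEVEL
logger.debug("visible because KOSTYL_LOG_LEVEL=DEBUG")

quiet = setup_logger(name="data", level="WARNING")  # explicit level still wins
quiet.info("suppressed")
quiet.warning("emitted")
```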
@@ -146,7 +171,7 @@ def setup_logger(
         filter=lambda r: r["extra"].get("logger_id") == logger_id,
     )
     logger = _base_logger.bind(logger_id=logger_id, channel=channel)
-    return cast(CustomLogger, logger)
+    return cast(KostylLogger, logger)
 
 
 def log_incompatible_keys(
@@ -154,22 +179,12 @@
     incompatible_keys: _IncompatibleKeys
     | tuple[list[str], list[str]]
     | dict[str, list[str]],
-    model_specific_msg: str = "",
+    postfix_msg: str = "",
 ) -> None:
     """
     Logs warnings for incompatible keys encountered during model loading or state dict operations.
 
     Note: If incompatible_keys is of an unsupported type, an error message is logged and the function returns early.
-
-    Args:
-        logger (Logger): The logger instance used to output warning messages.
-        incompatible_keys (_IncompatibleKeys | tuple[list[str], list[str]] | dict[str, list[str]]): An object containing lists of missing and unexpected keys.
-        model_specific_msg (str, optional): A custom message to append to the log output, typically
-            indicating the model or context. Defaults to an empty string.
-
-    Returns:
-        None
-
     """
     incompatible_keys_: dict[str, list[str]] = {}
     match incompatible_keys:
@@ -192,5 +207,5 @@
             return
 
     for name, keys in incompatible_keys_.items():
-        logger.warning(f"{name} {model_specific_msg}: {', '.join(keys)}")
+        logger.warning(f"{name} {postfix_msg}: {', '.join(keys)}")
     return
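The only behavioural change here is the keyword rename from `model_specific_msg` to `postfix_msg`. An illustrative call site (the key names are invented):

```python
from kostyl.utils.logging import log_incompatible_keys, setup_logger

logger = setup_logger(name="loader")
incompatible = (["encoder.bias"], ["legacy.head.weight"])  # (missing, unexpected) - invented keys

# 0.1.36: log_incompatible_keys(logger, incompatible, model_specific_msg="for backbone")
log_incompatible_keys(logger, incompatible, postfix_msg="for backbone")
```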
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "kostyl-toolkit"
-version = "0.1.36"
+version = "0.1.37"
 description = "Kickass Orchestration System for Training, Yielding & Logging "
 readme = "README.md"
 requires-python = ">=3.12"
kostyl_toolkit-0.1.36/kostyl/ml/dist_utils.py (deleted)
@@ -1,107 +0,0 @@
-import math
-import os
-from typing import Literal
-
-import torch.distributed as dist
-
-from kostyl.utils.logging import setup_logger
-
-
-logger = setup_logger(add_rank=True)
-
-
-def log_dist(msg: str, how: Literal["only-zero-rank", "world"]) -> None:
-    """
-    Log a message in a distributed environment based on the specified verbosity level.
-
-    Args:
-        msg (str): The message to log.
-        how (Literal["only-zero-rank", "world"]): The verbosity level for logging.
-            - "only-zero-rank": Log only from the main process (rank 0).
-            - "world": Log from all processes in the distributed environment.
-
-    """
-    match how:
-        case _ if not dist.is_initialized():
-            logger.warning_once(
-                "Distributed logging requested but torch.distributed is not initialized."
-            )
-            logger.info(msg)
-        case "only-zero-rank":
-            if is_main_process():
-                logger.info(msg)
-        case "world":
-            logger.info(msg)
-        case _:
-            logger.warning_once(
-                f"Invalid logging verbosity level requested: {how}. Message not logged."
-            )
-    return
-
-
-def scale_lrs_by_world_size(
-    lrs: dict[str, float],
-    group: dist.ProcessGroup | None = None,
-    config_name: str = "",
-    inv_scale: bool = False,
-    verbose: Literal["only-zero-rank", "world"] | None = None,
-) -> dict[str, float]:
-    """
-    Scale learning-rate configuration values to match the active distributed world size.
-
-    Note:
-        The value in the `lrs` will be modified in place.
-
-    Args:
-        lrs (dict[str, float]): A dictionary of learning rate names and their corresponding values to be scaled.
-        group (dist.ProcessGroup | None): Optional process group used to determine
-            the target world size. Defaults to the global process group.
-        config_name (str): Human-readable identifier included in log messages.
-        inv_scale (bool): If True, use the inverse square-root scale factor.
-        verbose (Literal["only-zero-rank", "world"] | None): Verbosity level for logging scaled values.
-            - "only-zero-rank": Log only from the main process (rank 0).
-            - "world": Log from all processes in the distributed environment.
-            - None: No logging.
-
-    Returns:
-        dict[str, float]: The learning-rate configuration with scaled values.
-
-    """
-    world_size = dist.get_world_size(group=group)
-
-    if inv_scale:
-        scale = 1 / math.sqrt(world_size)
-    else:
-        scale = math.sqrt(world_size)
-
-    for name, value in lrs.items():
-        old_value = value
-        new_value = value * scale
-        if verbose is not None:
-            log_dist(
-                f"New {config_name} lr {name.upper()}: {new_value}; OLD: {old_value}",
-                verbose,
-            )
-        lrs[name] = new_value
-    return lrs
-
-
-def get_rank() -> int:
-    """Gets the rank of the current process in a distributed setting."""
-    if dist.is_initialized():
-        return dist.get_rank()
-    if "RANK" in os.environ:
-        return int(os.environ["RANK"])
-    if "SLURM_PROCID" in os.environ:
-        return int(os.environ["SLURM_PROCID"])
-    if "LOCAL_RANK" in os.environ:
-        return int(os.environ["LOCAL_RANK"])
-    return 0
-
-
-def is_main_process() -> bool:
-    """Checks if the current process is the main process (rank 0) in a distributed setting."""
-    rank = get_rank()
-    if rank != 0:
-        return False
-    return True
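Taken together with the new module earlier in this diff, the migration surface is small: `how` becomes `log_scope`, `verbose` becomes `verbose_level`, and `get_rank`/`is_main_process` give way to `get_local_rank`/`is_local_zero_rank`. A hedged summary sketch, assuming a process group is already initialized by the launcher and with illustrative values:

```python
from kostyl.ml.dist_utils import is_local_zero_rank, log_dist, scale_lrs_by_world_size

# log_dist("epoch finished", "world")                  # 0.1.36: positional `how`
log_dist("epoch finished", log_scope="world")           # 0.1.37: keyword `log_scope`

# scale_lrs_by_world_size(lrs, verbose="world")         # 0.1.36: `verbose`
scale_lrs_by_world_size({"base": 3e-4}, verbose_level="world")  # 0.1.37: `verbose_level`

# if is_main_process(): ...                             # 0.1.36: global rank 0
if is_local_zero_rank():                                # 0.1.37: node-local rank 0
    pass
```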