kostyl-toolkit 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -35,7 +35,7 @@ class ConfigLoadingMixin:
 
     @classmethod
     def from_file(
-        cls: type[TConfig],  # pyright: ignore[reportGeneralTypeIssues]
+        cls: type[TConfig],  # pyright: ignore
         path: str | Path,
     ) -> TConfig:
         """
@@ -55,7 +55,7 @@ class ConfigLoadingMixin:
 
     @classmethod
    def from_dict(
-        cls: type[TConfig],  # pyright: ignore[reportGeneralTypeIssues]
+        cls: type[TConfig],  # pyright: ignore
         state_dict: dict,
     ) -> TConfig:
         """
@@ -83,7 +83,7 @@ class ClearMLConfigMixin(ConfigLoadingMixin):
 
     @classmethod
     def connect_as_file(
-        cls: type[TModel],  # pyright: ignore[reportGeneralTypeIssues]
+        cls: type[TModel],  # pyright: ignore
         task: clearml.Task,
         path: str | Path,
         alias: str | None = None,
@@ -122,7 +122,7 @@ class ClearMLConfigMixin(ConfigLoadingMixin):
 
     @classmethod
     def connect_as_dict(
-        cls: type[TModel],  # pyright: ignore[reportGeneralTypeIssues]
+        cls: type[TModel],  # pyright: ignore
         task: clearml.Task,
         path: str | Path,
         alias: str | None = None,
@@ -91,7 +91,7 @@ class DataConfig(BaseModel):
     data_columns: list[str]
 
 
-class TrainingParams(ConfigLoadingMixin):
+class TrainingParams(BaseModel, ConfigLoadingMixin):
     """Training parameters configuration."""
 
     trainer: LightningTrainerParameters
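
Usage sketch (not part of the diff): with TrainingParams now deriving from both BaseModel and ConfigLoadingMixin, it is validated by pydantic and loaded through the mixin's classmethods. The import path, file name, and payload below are illustrative assumptions, not taken from the package.

# Illustrative only: import path, file name, and payload are assumptions.
from kostyl.ml_core.configs.training_params import TrainingParams

# Parse and validate a config file from disk (file format per the mixin's loader).
params = TrainingParams.from_file("configs/training.yaml")

# Or build the same object from an already-parsed mapping; the real schema
# includes at least a `trainer` section (LightningTrainerParameters).
params = TrainingParams.from_dict({"trainer": {"max_epochs": 10}})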
@@ -1,33 +1,70 @@
 import math
 import os
+from typing import Literal
 
 import torch.distributed as dist
 
-from kostyl.ml_core.configs import Lr
 from kostyl.utils.logging import setup_logger
 
 
 logger = setup_logger(add_rank=True)
 
 
-def scale_lrs_by_world_size[Tlr: Lr](
-    lr_config: Tlr,
+def log_dist(msg: str, how: Literal["only-zero-rank", "world"]) -> None:
+    """
+    Log a message in a distributed environment based on the specified verbosity level.
+
+    Args:
+        msg (str): The message to log.
+        how (Literal["only-zero-rank", "world"]): The verbosity level for logging.
+            - "only-zero-rank": Log only from the main process (rank 0).
+            - "world": Log from all processes in the distributed environment.
+
+    """
+    match how:
+        case _ if not dist.is_initialized():
+            logger.warning_once(
+                "Distributed logging requested but torch.distributed is not initialized."
+            )
+            logger.info(msg)
+        case "only-zero-rank":
+            if is_main_process():
+                logger.info(msg)
+        case "world":
+            logger.info(msg)
+        case _:
+            logger.warning_once(
+                f"Invalid logging verbosity level requested: {how}. Message not logged."
+            )
+    return
+
+
+def scale_lrs_by_world_size(
+    lrs: dict[str, float],
     group: dist.ProcessGroup | None = None,
     config_name: str = "",
     inv_scale: bool = False,
-) -> Tlr:
+    verbose: Literal["only-zero-rank", "world"] | None = None,
+) -> dict[str, float]:
     """
     Scale learning-rate configuration values to match the active distributed world size.
 
+    Note:
+        The value in the `lrs` will be modified in place.
+
     Args:
-        lr_config (Lr): Learning-rate configuration whose values will be scaled.
+        lrs (dict[str, float]): A dictionary of learning rate names and their corresponding values to be scaled.
         group (dist.ProcessGroup | None): Optional process group used to determine
             the target world size. Defaults to the global process group.
         config_name (str): Human-readable identifier included in log messages.
         inv_scale (bool): If True, use the inverse square-root scale factor.
+        verbose (Literal["only-zero-rank", "world"] | None): Verbosity level for logging scaled values.
+            - "only-zero-rank": Log only from the main process (rank 0).
+            - "world": Log from all processes in the distributed environment.
+            - None: No logging.
 
     Returns:
-        Tlr: The learning-rate configuration with scaled values.
+        dict[str, float]: The learning-rate configuration with scaled values.
 
     """
     world_size = dist.get_world_size(group=group)
@@ -37,26 +74,16 @@ def scale_lrs_by_world_size[Tlr: Lr](
     else:
         scale = math.sqrt(world_size)
 
-    logger.info(f"Scaling learning rates for world size: {world_size}")
-    logger.info(f"Scale factor: {scale:.4f}")
-    old_base = lr_config.base_value
-    lr_config.base_value *= scale
-    logger.info(f"New {config_name} lr BASE: {lr_config.base_value}; OLD: {old_base}")
-
-    if lr_config.final_value is not None:
-        old_final_value = lr_config.final_value
-        lr_config.final_value *= scale
-        logger.info(
-            f"New {config_name} lr FINAL: {lr_config.final_value}; OLD: {old_final_value}"
-        )
-
-    if lr_config.warmup_value is not None:
-        old_warmup_value = lr_config.warmup_value
-        lr_config.warmup_value *= scale
-        logger.info(
-            f"New {config_name} lr WARMUP: {lr_config.warmup_value}; OLD: {old_warmup_value}"
-        )
-    return lr_config
+    for name, value in lrs.items():
+        old_value = value
+        new_value = value * scale
+        if verbose is not None:
+            log_dist(
+                f"New {config_name} lr {name.upper()}: {new_value}; OLD: {old_value}",
+                verbose,
+            )
+        lrs[name] = new_value
+    return lrs
 
 
 def _get_rank() -> int:
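
Usage sketch (not part of the diff): the reworked helper now takes a plain dict of learning rates instead of an Lr config object. A call under an already-initialized torch.distributed process group might look like the following; the names and values are made up.

# Illustrative only: names and values are made up; assumes torch.distributed
# was already initialized (e.g. by torchrun).
from kostyl.ml_core.dist_utils import scale_lrs_by_world_size

lrs = {"base": 1e-4, "warmup": 1e-6, "final": 1e-5}

# Multiplies every value by sqrt(world_size) (inverse square root with
# inv_scale=True), mutates the dict in place, and logs one line per entry
# from rank 0 only.
scaled = scale_lrs_by_world_size(
    lrs,
    config_name="encoder",
    verbose="only-zero-rank",
)
assert scaled is lrs  # the same dict is mutated and returned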
@@ -7,23 +7,66 @@ def create_params_groups(
     model: nn.Module,
     weight_decay: float,
     lr: float,
+    no_lr_keywords: set[str] | None = None,
+    no_decay_keywords: set[str] | None = None,
 ) -> list[dict]:
-    """Create optimizer parameter groups for a PyTorch model with fine-grained weight decay control."""
+    """
+    Create optimizer parameter groups for a PyTorch model with fine-grained weight decay control.
+
+    This function iterates through the model's named parameters and assigns them to specific
+    parameter groups based on whether they should be subject to weight decay. Certain parameter
+    types (like normalization layers, biases, embeddings) are typically excluded from weight decay
+    to improve training stability.
+
+    Args:
+        model (nn.Module): The PyTorch model containing the parameters to optimize.
+        weight_decay (float): The default weight decay value to apply to parameters that are
+            not excluded.
+        lr (float): The learning rate to assign to all parameter groups.
+        no_lr_keywords (set[str] | None, optional): A set of string keywords. If a parameter's
+            name contains any of these keywords, its learning rate is set to 0.0.
+            Defaults to None, which uses an empty set.
+        no_decay_keywords (set[str] | None, optional): A set of string keywords. If a parameter's
+            name contains any of these keywords, its weight decay is set to 0.0.
+            If additional keywords are provided, they will be added to the default set.
+            Defaults to None, which uses a standard set of exclusion keywords:
+            {"norm", "bias", "embedding", "tokenizer", "ln", "scale"}.
+
+    Returns:
+        list[dict]: A list of dictionaries, where each dictionary represents a parameter group
+            compatible with PyTorch optimizers (e.g., `torch.optim.AdamW`). Each group contains:
+            - "params": The parameter tensor.
+            - "lr": The learning rate.
+            - "weight_decay": The specific weight decay value (0.0 or the provided default).
+
+    """
+    no_decay_keywords_ = {
+        "norm",
+        "bias",
+        "embedding",
+        "tokenizer",
+        "ln",
+        "scale",
+    }
+    if no_decay_keywords is not None:
+        no_decay_keywords_ = no_decay_keywords_.union(no_decay_keywords)
+
+    no_lr_keywords_ = set()
+    if no_lr_keywords is not None:
+        no_lr_keywords_ = no_lr_keywords_.union(no_lr_keywords)
+
     param_groups = []
     for name, param in model.named_parameters():
         if param.requires_grad is False:
             continue
-        param_group = {"params": param, "lr": lr}
-
-        if (
-            ("norm" in name)
-            or ("bias" in name)
-            or ("embedding" in name)
-            or ("tokenizer" in name)
-            or ("output_projection_point" in name)
-            or ("ln" in name)
-            or ("scale" in name)
-        ):
+
+        if any(keyword in name for keyword in no_lr_keywords_):
+            lr_ = 0.0
+        else:
+            lr_ = lr
+        param_group = {"params": param, "lr": lr_}
+
+        if any(keyword in name for keyword in no_decay_keywords_):
             param_group["weight_decay"] = 0.0
         else:
             param_group["weight_decay"] = weight_decay
kostyl/utils/logging.py CHANGED
@@ -5,9 +5,12 @@ import os
 import sys
 import uuid
 from copy import deepcopy
+from functools import partialmethod
 from pathlib import Path
+from threading import Lock
 from typing import TYPE_CHECKING
 from typing import Literal
+from typing import cast
 
 from loguru import logger as _base_logger
 from torch.nn.modules.module import _IncompatibleKeys
@@ -16,6 +19,12 @@ from torch.nn.modules.module import _IncompatibleKeys
 if TYPE_CHECKING:
     from loguru import Logger
 
+    class CustomLogger(Logger):  # noqa: D101
+        def log_once(self, level: str, message: str, *args, **kwargs) -> None: ...  # noqa: ANN003, D102
+        def warning_once(self, message: str, *args, **kwargs) -> None: ...  # noqa: ANN003, D102
+else:
+    CustomLogger = type(_base_logger)
+
 try:
     import torch.distributed as dist
 except Exception:
@@ -31,10 +40,25 @@ except Exception:
 
     dist = _Dummy()
 
-_DEFAULT_SINK_REMOVED = False
-_DEFAULT_FMT = "<level>{level: <8}</level> {time:HH:mm:ss.SSS} [{extra[channel]}] <level>{message}</level>"
-_ONLY_MESSAGE_FMT = "<level>{message}</level>"
-_PRESETS = {"default": _DEFAULT_FMT, "only_message": _ONLY_MESSAGE_FMT}
+_once_lock = Lock()
+_once_keys: set[tuple[str, str]] = set()
+
+
+def _log_once(self: CustomLogger, level: str, message: str, *args, **kwargs) -> None:  # noqa: ANN003
+    key = (message, level)
+
+    with _once_lock:
+        if key in _once_keys:
+            return
+        _once_keys.add(key)
+
+    self.log(level, message, *args, **kwargs)
+    return
+
+
+_base_logger = cast(CustomLogger, _base_logger)
+_base_logger.log_once = _log_once  # pyright: ignore[reportAttributeAccessIssue]
+_base_logger.warning_once = partialmethod(_log_once, "WARNING")  # pyright: ignore[reportAttributeAccessIssue]
 
 
 def _caller_filename() -> str:
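
Usage sketch (not part of the diff): the helpers patched onto the loguru logger deduplicate by (level, message), so a repeated warning inside a loop is emitted only once. The message text below is made up.

# Illustrative only: the message text is made up.
from kostyl.utils.logging import setup_logger

logger = setup_logger("demo")

for _ in range(3):
    # Emitted on the first pass only; later identical (level, message) pairs are dropped.
    logger.warning_once("Gradient accumulation not configured; defaulting to 1.")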
@@ -43,6 +67,12 @@ def _caller_filename() -> str:
     return name
 
 
+_DEFAULT_SINK_REMOVED = False
+_DEFAULT_FMT = "<level>{level: <8}</level> {time:HH:mm:ss.SSS} [{extra[channel]}] <level>{message}</level>"
+_ONLY_MESSAGE_FMT = "<level>{message}</level>"
+_PRESETS = {"default": _DEFAULT_FMT, "only_message": _ONLY_MESSAGE_FMT}
+
+
 def setup_logger(
     name: str | None = None,
     fmt: Literal["default", "only_message"] | str = "default",
@@ -51,7 +81,7 @@
     sink=sys.stdout,
     colorize: bool = True,
     serialize: bool = False,
-) -> Logger:
+) -> CustomLogger:
     """
     Returns a bound logger with its own sink and formatting.
 
@@ -96,8 +126,8 @@
         serialize=serialize,
         filter=lambda r: r["extra"].get("logger_id") == logger_id,
     )
-
-    return _base_logger.bind(logger_id=logger_id, channel=channel)
+    logger = _base_logger.bind(logger_id=logger_id, channel=channel)
+    return cast(CustomLogger, logger)
 
 
 def log_incompatible_keys(
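
Usage sketch (not part of the diff): with the return type narrowed to CustomLogger, static checkers see the once-only helpers on the object returned by setup_logger. The call below uses only parameters visible in the hunks above.

# Illustrative only: a minimal call using parameters shown in this diff.
from kostyl.utils.logging import setup_logger

logger = setup_logger(name="train", fmt="only_message", colorize=False)
logger.info("regular message")
logger.warning_once("deduplicated warning")  # now visible to pyright via CustomLogger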
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: kostyl-toolkit
-Version: 0.1.1
+Version: 0.1.2
 Summary: Kickass Orchestration System for Training, Yielding & Logging
 Requires-Dist: case-converter>=1.2.0
 Requires-Dist: clearml[s3]>=2.0.2
@@ -4,10 +4,10 @@ kostyl/ml_core/clearml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3
 kostyl/ml_core/clearml/logging_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWaJTyq8,1210
 kostyl/ml_core/clearml/pulling_utils.py,sha256=Yf70ux8dS0_ENdvfbNQkXOrDxwd4ed2GnRCmOR2ppEk,3252
 kostyl/ml_core/configs/__init__.py,sha256=RKSHp5J8eksqMxFu5xkpSxyswSpgKhrHLjltLS3yZXc,896
-kostyl/ml_core/configs/config_base.py,sha256=zla6qwIzIIg4i0ETG7Er2qYfc48hoGOPIbLRq1xqJPs,5376
+kostyl/ml_core/configs/config_base.py,sha256=ctjedEKZbwByUr5HA-Ic0dVCPWPAIPL9kK8T0S-BOvk,5276
 kostyl/ml_core/configs/hyperparams.py,sha256=iKzuFOAL3xSVGjXlvRX_mbSBt0pqh6RQAxyHPmN-Bik,2974
-kostyl/ml_core/configs/training_params.py,sha256=ocPC2dAUFpxu2jgWvPFDdVFcgAsQEonJM4yPzGSpx20,2587
-kostyl/ml_core/dist_utils.py,sha256=C9lzT37jl7C2igQzqtvXNTdz3NJ6ORzrBRjIDl7PC7o,2221
+kostyl/ml_core/configs/training_params.py,sha256=a8ewftu_xDatlbJ6qk_87WkuRpdThBGYQA2fHbjb9RU,2598
+kostyl/ml_core/dist_utils.py,sha256=G8atjzkRbXZZiZh9rdEYBmeXqX26rJdDDovft2n6xiU,3201
 kostyl/ml_core/lightning/__init__.py,sha256=-F3JAyq8KU1d-nACWryGu8d1CbvWbQ1rXFdeRwfE2X8,175
 kostyl/ml_core/lightning/callbacks/__init__.py,sha256=Vd-rozY4T9Prr3IMqbliXxj6sC6y9XsovHQqRwzc2HI,297
 kostyl/ml_core/lightning/callbacks/checkpoint.py,sha256=RgkNNmsbAz9fdMYGlEgn9Qs_DF8LiuY7Bp1Hu4ZW98s,1946
@@ -20,14 +20,14 @@ kostyl/ml_core/lightning/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRk
 kostyl/ml_core/lightning/loggers/tb_logger.py,sha256=Zh9n-lLu-bXMld-FIUO3lJfCyDf0IQFhS3JVShDJmvg,937
 kostyl/ml_core/lightning/steps_estimation.py,sha256=fTZ0IrUEZV3H6VYlx4GYn56oco56mMiB7FO9F0Z7qc4,1511
 kostyl/ml_core/metrics_formatting.py,sha256=w0rTz61z0Um_d2pomYLvcQFcZX_C-KolZcIPRsa1efE,1421
-kostyl/ml_core/params_groups.py,sha256=AKQABbor3eOsNihzm0C3MvzbHRgwFxb5XTXUF3wdRbw,1542
+kostyl/ml_core/params_groups.py,sha256=nUyw5d06Pvy9QPiYtZzLYR87xwXqJLxbHthgQH8oSCM,3583
 kostyl/ml_core/schedulers/__init__.py,sha256=bxXbsU_WYnVbhvNNnuI7cOAh2Axz7D25TaleBTZhYfc,197
 kostyl/ml_core/schedulers/base.py,sha256=9M2iOoOVSRojR_liPX1qo3Nn4iMXSM5ZJuAFWZTulUk,1327
 kostyl/ml_core/schedulers/composite.py,sha256=ee4xlMDMMtjKPkbTF2ue9GTr9DuGCGjZWf11mHbi6aE,2387
 kostyl/ml_core/schedulers/cosine.py,sha256=jufULVHn_L_ZZEc3ZTG3QCY_pc0jlAMH5Aw496T31jo,8203
 kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
 kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
-kostyl/utils/logging.py,sha256=126Zs0ym9w8IgM8wdUVgVp2kLqgRZM-cWtG6bQ--InI,4214
-kostyl_toolkit-0.1.1.dist-info/WHEEL,sha256=YUH1mBqsx8Dh2cQG2rlcuRYUhJddG9iClegy4IgnHik,79
-kostyl_toolkit-0.1.1.dist-info/METADATA,sha256=nhYquV3AKueHR_DVVvSw0jWgZTQJSIkmvXL6mVGcqeQ,4053
-kostyl_toolkit-0.1.1.dist-info/RECORD,,
+kostyl/utils/logging.py,sha256=3MvfDPArZhwakHu5nMlp_LpOsWg0E0SP26y41clsBtA,5232
+kostyl_toolkit-0.1.2.dist-info/WHEEL,sha256=YUH1mBqsx8Dh2cQG2rlcuRYUhJddG9iClegy4IgnHik,79
+kostyl_toolkit-0.1.2.dist-info/METADATA,sha256=4aZUWVa-k5qqIZJFlOqyCLSwT3S-V_znIRMR1d3_tJ0,4053
+kostyl_toolkit-0.1.2.dist-info/RECORD,,