kostyl-toolkit 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kostyl/ml_core/configs/config_base.py +4 -4
- kostyl/ml_core/configs/training_params.py +1 -1
- kostyl/ml_core/dist_utils.py +53 -26
- kostyl/ml_core/params_groups.py +55 -12
- kostyl/utils/logging.py +37 -7
- {kostyl_toolkit-0.1.1.dist-info → kostyl_toolkit-0.1.2.dist-info}/METADATA +1 -1
- {kostyl_toolkit-0.1.1.dist-info → kostyl_toolkit-0.1.2.dist-info}/RECORD +8 -8
- {kostyl_toolkit-0.1.1.dist-info → kostyl_toolkit-0.1.2.dist-info}/WHEEL +0 -0
kostyl/ml_core/configs/config_base.py
CHANGED

@@ -35,7 +35,7 @@ class ConfigLoadingMixin:
 
     @classmethod
     def from_file(
-        cls: type[TConfig], # pyright: ignore
+        cls: type[TConfig], # pyright: ignore
         path: str | Path,
     ) -> TConfig:
         """

@@ -55,7 +55,7 @@ class ConfigLoadingMixin:
 
     @classmethod
     def from_dict(
-        cls: type[TConfig], # pyright: ignore
+        cls: type[TConfig], # pyright: ignore
         state_dict: dict,
     ) -> TConfig:
         """

@@ -83,7 +83,7 @@ class ClearMLConfigMixin(ConfigLoadingMixin):
 
     @classmethod
     def connect_as_file(
-        cls: type[TModel], # pyright: ignore
+        cls: type[TModel], # pyright: ignore
        task: clearml.Task,
        path: str | Path,
        alias: str | None = None,

@@ -122,7 +122,7 @@ class ClearMLConfigMixin(ConfigLoadingMixin):
 
     @classmethod
     def connect_as_dict(
-        cls: type[TModel], # pyright: ignore
+        cls: type[TModel], # pyright: ignore
        task: clearml.Task,
        path: str | Path,
        alias: str | None = None,
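The four hunks above only touch the `cls` annotations of the mixin classmethods, but the signatures give a picture of the intended API. A minimal usage sketch, assuming a config class that subclasses the mixin shown; `TrainConfig`, its fields, the ClearML project/task names, and the exact loading semantics are inferred from the signatures only and are not confirmed by this diff:

```python
from pathlib import Path

import clearml

from kostyl.ml_core.configs.config_base import ClearMLConfigMixin


class TrainConfig(ClearMLConfigMixin):  # hypothetical config class
    lr: float = 3e-4
    batch_size: int = 32


# from_file(path: str | Path) -> TConfig: load the config from a file on disk
cfg = TrainConfig.from_file("configs/train.yaml")

# from_dict(state_dict: dict) -> TConfig: build the config from an in-memory dict
cfg = TrainConfig.from_dict({"lr": 1e-3, "batch_size": 64})

# connect_as_file(task, path, alias=None): register the config file with a ClearML task
# (return value assumed to be the loaded config, based on the TModel type variable)
task = clearml.Task.init(project_name="demo", task_name="train")
cfg = TrainConfig.connect_as_file(task, Path("configs/train.yaml"), alias="train_config")
```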
kostyl/ml_core/dist_utils.py
CHANGED

@@ -1,33 +1,70 @@
 import math
 import os
+from typing import Literal
 
 import torch.distributed as dist
 
-from kostyl.ml_core.configs import Lr
 from kostyl.utils.logging import setup_logger
 
 
 logger = setup_logger(add_rank=True)
 
 
-def
-
+def log_dist(msg: str, how: Literal["only-zero-rank", "world"]) -> None:
+    """
+    Log a message in a distributed environment based on the specified verbosity level.
+
+    Args:
+        msg (str): The message to log.
+        how (Literal["only-zero-rank", "world"]): The verbosity level for logging.
+            - "only-zero-rank": Log only from the main process (rank 0).
+            - "world": Log from all processes in the distributed environment.
+
+    """
+    match how:
+        case _ if not dist.is_initialized():
+            logger.warning_once(
+                "Distributed logging requested but torch.distributed is not initialized."
+            )
+            logger.info(msg)
+        case "only-zero-rank":
+            if is_main_process():
+                logger.info(msg)
+        case "world":
+            logger.info(msg)
+        case _:
+            logger.warning_once(
+                f"Invalid logging verbosity level requested: {how}. Message not logged."
+            )
+    return
+
+
+def scale_lrs_by_world_size(
+    lrs: dict[str, float],
     group: dist.ProcessGroup | None = None,
     config_name: str = "",
     inv_scale: bool = False,
-
+    verbose: Literal["only-zero-rank", "world"] | None = None,
+) -> dict[str, float]:
     """
     Scale learning-rate configuration values to match the active distributed world size.
 
+    Note:
+        The value in the `lrs` will be modified in place.
+
     Args:
-
+        lrs (dict[str, float]): A dictionary of learning rate names and their corresponding values to be scaled.
         group (dist.ProcessGroup | None): Optional process group used to determine
             the target world size. Defaults to the global process group.
         config_name (str): Human-readable identifier included in log messages.
         inv_scale (bool): If True, use the inverse square-root scale factor.
+        verbose (Literal["only-zero-rank", "world"] | None): Verbosity level for logging scaled values.
+            - "only-zero-rank": Log only from the main process (rank 0).
+            - "world": Log from all processes in the distributed environment.
+            - None: No logging.
 
     Returns:
-
+        dict[str, float]: The learning-rate configuration with scaled values.
 
     """
     world_size = dist.get_world_size(group=group)

@@ -37,26 +74,16 @@ def scale_lrs_by_world_size[Tlr: Lr](
     else:
         scale = math.sqrt(world_size)
 
-
-
-
-
-
-
-
-
-
-
-            f"New {config_name} lr FINAL: {lr_config.final_value}; OLD: {old_final_value}"
-        )
-
-    if lr_config.warmup_value is not None:
-        old_warmup_value = lr_config.warmup_value
-        lr_config.warmup_value *= scale
-        logger.info(
-            f"New {config_name} lr WARMUP: {lr_config.warmup_value}; OLD: {old_warmup_value}"
-        )
-    return lr_config
+    for name, value in lrs.items():
+        old_value = value
+        new_value = value * scale
+        if verbose is not None:
+            log_dist(
+                f"New {config_name} lr {name.upper()}: {new_value}; OLD: {old_value}",
+                verbose,
+            )
+        lrs[name] = new_value
+    return lrs
 
 
 def _get_rank() -> int:
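Together, `log_dist` and the reworked `scale_lrs_by_world_size` replace the old Lr-config-specific scaling with a plain dict API. A short usage sketch based on the docstrings above; the learning-rate names are illustrative, and the process group is assumed to be initialized already (e.g. by torchrun), since the scaling queries `dist.get_world_size()`:

```python
from kostyl.ml_core.dist_utils import log_dist, scale_lrs_by_world_size

# Illustrative per-group learning rates; keys are only used for log messages.
lrs = {"base": 1e-3, "head": 5e-4}

# Multiplies every value by sqrt(world_size) in place and returns the same dict;
# inv_scale=True would use the inverse square-root factor instead.
lrs = scale_lrs_by_world_size(
    lrs,
    config_name="optimizer",
    verbose="only-zero-rank",  # log old/new values from rank 0 only
)

# "world" makes every rank log the final values.
log_dist(f"Scaled learning rates: {lrs}", how="world")
```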
kostyl/ml_core/params_groups.py
CHANGED

@@ -7,23 +7,66 @@ def create_params_groups(
     model: nn.Module,
     weight_decay: float,
     lr: float,
+    no_lr_keywords: set[str] | None = None,
+    no_decay_keywords: set[str] | None = None,
 ) -> list[dict]:
-    """
+    """
+    Create optimizer parameter groups for a PyTorch model with fine-grained weight decay control.
+
+    This function iterates through the model's named parameters and assigns them to specific
+    parameter groups based on whether they should be subject to weight decay. Certain parameter
+    types (like normalization layers, biases, embeddings) are typically excluded from weight decay
+    to improve training stability.
+
+    Args:
+        model (nn.Module): The PyTorch model containing the parameters to optimize.
+        weight_decay (float): The default weight decay value to apply to parameters that are
+            not excluded.
+        lr (float): The learning rate to assign to all parameter groups.
+        no_lr_keywords (set[str] | None, optional): A set of string keywords. If a parameter's
+            name contains any of these keywords, its learning rate is set to 0.0.
+            Defaults to None, which uses an empty set.
+        no_decay_keywords (set[str] | None, optional): A set of string keywords. If a parameter's
+            name contains any of these keywords, its weight decay is set to 0.0.
+            If additional keywords are provided, they will be added to the default set.
+            Defaults to None, which uses a standard set of exclusion keywords:
+            {"norm", "bias", "embedding", "tokenizer", "ln", "scale"}.
+
+    Returns:
+        list[dict]: A list of dictionaries, where each dictionary represents a parameter group
+            compatible with PyTorch optimizers (e.g., `torch.optim.AdamW`). Each group contains:
+            - "params": The parameter tensor.
+            - "lr": The learning rate.
+            - "weight_decay": The specific weight decay value (0.0 or the provided default).
+
+    """
+    no_decay_keywords_ = {
+        "norm",
+        "bias",
+        "embedding",
+        "tokenizer",
+        "ln",
+        "scale",
+    }
+    if no_decay_keywords is not None:
+        no_decay_keywords_ = no_decay_keywords_.union(no_decay_keywords)
+
+    no_lr_keywords_ = set()
+    if no_lr_keywords is not None:
+        no_lr_keywords_ = no_lr_keywords_.union(no_lr_keywords)
+
     param_groups = []
     for name, param in model.named_parameters():
         if param.requires_grad is False:
             continue
-
-
-
-
-
-
-
-
-            or ("ln" in name)
-            or ("scale" in name)
-        ):
+
+        if any(keyword in name for keyword in no_lr_keywords_):
+            lr_ = 0.0
+        else:
+            lr_ = lr
+        param_group = {"params": param, "lr": lr_}
+
+        if any(keyword in name for keyword in no_decay_keywords_):
             param_group["weight_decay"] = 0.0
         else:
             param_group["weight_decay"] = weight_decay
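A usage sketch for the extended signature, following the docstring above; the tiny model and the extra keyword sets are illustrative:

```python
import torch
from torch import nn

from kostyl.ml_core.params_groups import create_params_groups

model = nn.Sequential(
    nn.Linear(16, 32),
    nn.LayerNorm(32),
    nn.Linear(32, 4),
)

param_groups = create_params_groups(
    model,
    weight_decay=0.01,
    lr=3e-4,
    no_lr_keywords={"frozen_head"},    # illustrative: zero lr for params whose name contains this
    no_decay_keywords={"pos_embed"},   # merged with the default {"norm", "bias", "embedding", ...}
)

# Each group carries its own "lr" and "weight_decay", so the optimizer defaults are only fallbacks.
optimizer = torch.optim.AdamW(param_groups)
```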
kostyl/utils/logging.py
CHANGED

@@ -5,9 +5,12 @@ import os
 import sys
 import uuid
 from copy import deepcopy
+from functools import partialmethod
 from pathlib import Path
+from threading import Lock
 from typing import TYPE_CHECKING
 from typing import Literal
+from typing import cast
 
 from loguru import logger as _base_logger
 from torch.nn.modules.module import _IncompatibleKeys

@@ -16,6 +19,12 @@ from torch.nn.modules.module import _IncompatibleKeys
 if TYPE_CHECKING:
     from loguru import Logger
 
+    class CustomLogger(Logger):  # noqa: D101
+        def log_once(self, level: str, message: str, *args, **kwargs) -> None: ...  # noqa: ANN003, D102
+        def warning_once(self, message: str, *args, **kwargs) -> None: ...  # noqa: ANN003, D102
+else:
+    CustomLogger = type(_base_logger)
+
 try:
     import torch.distributed as dist
 except Exception:

@@ -31,10 +40,25 @@ except Exception:
 
     dist = _Dummy()
 
-
-
-
-
+_once_lock = Lock()
+_once_keys: set[tuple[str, str]] = set()
+
+
+def _log_once(self: CustomLogger, level: str, message: str, *args, **kwargs) -> None:  # noqa: ANN003
+    key = (message, level)
+
+    with _once_lock:
+        if key in _once_keys:
+            return
+        _once_keys.add(key)
+
+    self.log(level, message, *args, **kwargs)
+    return
+
+
+_base_logger = cast(CustomLogger, _base_logger)
+_base_logger.log_once = _log_once  # pyright: ignore[reportAttributeAccessIssue]
+_base_logger.warning_once = partialmethod(_log_once, "WARNING")  # pyright: ignore[reportAttributeAccessIssue]
 
 
 def _caller_filename() -> str:

@@ -43,6 +67,12 @@ def _caller_filename() -> str:
     return name
 
 
+_DEFAULT_SINK_REMOVED = False
+_DEFAULT_FMT = "<level>{level: <8}</level> {time:HH:mm:ss.SSS} [{extra[channel]}] <level>{message}</level>"
+_ONLY_MESSAGE_FMT = "<level>{message}</level>"
+_PRESETS = {"default": _DEFAULT_FMT, "only_message": _ONLY_MESSAGE_FMT}
+
+
 def setup_logger(
     name: str | None = None,
     fmt: Literal["default", "only_message"] | str = "default",

@@ -51,7 +81,7 @@ def setup_logger(
     sink=sys.stdout,
     colorize: bool = True,
     serialize: bool = False,
-) ->
+) -> CustomLogger:
     """
     Returns a bound logger with its own sink and formatting.
 

@@ -96,8 +126,8 @@ def setup_logger(
         serialize=serialize,
         filter=lambda r: r["extra"].get("logger_id") == logger_id,
     )
-
-    return
+    logger = _base_logger.bind(logger_id=logger_id, channel=channel)
+    return cast(CustomLogger, logger)
 
 
 def log_incompatible_keys(
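The new once-only helpers and the `CustomLogger` return type of `setup_logger` are exercised elsewhere in the toolkit (e.g. `logger.warning_once` in dist_utils.py above). A minimal sketch mirroring that pattern; the logger name is illustrative:

```python
from kostyl.utils.logging import setup_logger

logger = setup_logger(name="train", fmt="default")

logger.info("regular message, rendered with the 'default' format preset")

# A (message, level) pair is emitted only the first time it is seen.
for _ in range(3):
    logger.warning_once("deduplicated warning")

logger.log_once("INFO", "deduplicated info message")
```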
{kostyl_toolkit-0.1.1.dist-info → kostyl_toolkit-0.1.2.dist-info}/RECORD
CHANGED

@@ -4,10 +4,10 @@ kostyl/ml_core/clearml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3
 kostyl/ml_core/clearml/logging_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWaJTyq8,1210
 kostyl/ml_core/clearml/pulling_utils.py,sha256=Yf70ux8dS0_ENdvfbNQkXOrDxwd4ed2GnRCmOR2ppEk,3252
 kostyl/ml_core/configs/__init__.py,sha256=RKSHp5J8eksqMxFu5xkpSxyswSpgKhrHLjltLS3yZXc,896
-kostyl/ml_core/configs/config_base.py,sha256=
+kostyl/ml_core/configs/config_base.py,sha256=ctjedEKZbwByUr5HA-Ic0dVCPWPAIPL9kK8T0S-BOvk,5276
 kostyl/ml_core/configs/hyperparams.py,sha256=iKzuFOAL3xSVGjXlvRX_mbSBt0pqh6RQAxyHPmN-Bik,2974
-kostyl/ml_core/configs/training_params.py,sha256=
-kostyl/ml_core/dist_utils.py,sha256=
+kostyl/ml_core/configs/training_params.py,sha256=a8ewftu_xDatlbJ6qk_87WkuRpdThBGYQA2fHbjb9RU,2598
+kostyl/ml_core/dist_utils.py,sha256=G8atjzkRbXZZiZh9rdEYBmeXqX26rJdDDovft2n6xiU,3201
 kostyl/ml_core/lightning/__init__.py,sha256=-F3JAyq8KU1d-nACWryGu8d1CbvWbQ1rXFdeRwfE2X8,175
 kostyl/ml_core/lightning/callbacks/__init__.py,sha256=Vd-rozY4T9Prr3IMqbliXxj6sC6y9XsovHQqRwzc2HI,297
 kostyl/ml_core/lightning/callbacks/checkpoint.py,sha256=RgkNNmsbAz9fdMYGlEgn9Qs_DF8LiuY7Bp1Hu4ZW98s,1946

@@ -20,14 +20,14 @@ kostyl/ml_core/lightning/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRk
 kostyl/ml_core/lightning/loggers/tb_logger.py,sha256=Zh9n-lLu-bXMld-FIUO3lJfCyDf0IQFhS3JVShDJmvg,937
 kostyl/ml_core/lightning/steps_estimation.py,sha256=fTZ0IrUEZV3H6VYlx4GYn56oco56mMiB7FO9F0Z7qc4,1511
 kostyl/ml_core/metrics_formatting.py,sha256=w0rTz61z0Um_d2pomYLvcQFcZX_C-KolZcIPRsa1efE,1421
-kostyl/ml_core/params_groups.py,sha256=
+kostyl/ml_core/params_groups.py,sha256=nUyw5d06Pvy9QPiYtZzLYR87xwXqJLxbHthgQH8oSCM,3583
 kostyl/ml_core/schedulers/__init__.py,sha256=bxXbsU_WYnVbhvNNnuI7cOAh2Axz7D25TaleBTZhYfc,197
 kostyl/ml_core/schedulers/base.py,sha256=9M2iOoOVSRojR_liPX1qo3Nn4iMXSM5ZJuAFWZTulUk,1327
 kostyl/ml_core/schedulers/composite.py,sha256=ee4xlMDMMtjKPkbTF2ue9GTr9DuGCGjZWf11mHbi6aE,2387
 kostyl/ml_core/schedulers/cosine.py,sha256=jufULVHn_L_ZZEc3ZTG3QCY_pc0jlAMH5Aw496T31jo,8203
 kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
 kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
-kostyl/utils/logging.py,sha256=
-kostyl_toolkit-0.1.
-kostyl_toolkit-0.1.
-kostyl_toolkit-0.1.
+kostyl/utils/logging.py,sha256=3MvfDPArZhwakHu5nMlp_LpOsWg0E0SP26y41clsBtA,5232
+kostyl_toolkit-0.1.2.dist-info/WHEEL,sha256=YUH1mBqsx8Dh2cQG2rlcuRYUhJddG9iClegy4IgnHik,79
+kostyl_toolkit-0.1.2.dist-info/METADATA,sha256=4aZUWVa-k5qqIZJFlOqyCLSwT3S-V_znIRMR1d3_tJ0,4053
+kostyl_toolkit-0.1.2.dist-info/RECORD,,

{kostyl_toolkit-0.1.1.dist-info → kostyl_toolkit-0.1.2.dist-info}/WHEEL
File without changes