kostyl-toolkit 0.1.36__py3-none-any.whl → 0.1.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kostyl/ml/base_uploader.py +17 -0
- kostyl/ml/configs/__init__.py +2 -2
- kostyl/ml/configs/mixins.py +50 -0
- kostyl/ml/{data_processing_utils.py → data_collator.py} +6 -3
- kostyl/ml/dist_utils.py +53 -33
- kostyl/ml/integrations/clearml/__init__.py +7 -0
- kostyl/ml/{registry_uploader.py → integrations/clearml/checkpoint_uploader.py} +3 -13
- kostyl/ml/{configs/base_model.py → integrations/clearml/config_mixin.py} +7 -63
- kostyl/ml/{clearml/pulling_utils.py → integrations/clearml/loading_utils.py} +32 -5
- kostyl/ml/integrations/lightning/__init__.py +14 -0
- kostyl/ml/{lightning → integrations/lightning}/callbacks/checkpoint.py +27 -42
- kostyl/ml/{lightning → integrations/lightning}/loggers/tb_logger.py +2 -2
- kostyl/ml/{lightning/extensions/pretrained_model.py → integrations/lightning/mixins.py} +6 -4
- kostyl/ml/{lightning/extensions/custom_module.py → integrations/lightning/module.py} +2 -38
- kostyl/ml/{lightning → integrations/lightning}/utils.py +1 -1
- kostyl/ml/schedulers/__init__.py +4 -4
- kostyl/ml/schedulers/{cosine_with_plateu.py → plateau.py} +59 -36
- kostyl/utils/logging.py +67 -52
- {kostyl_toolkit-0.1.36.dist-info → kostyl_toolkit-0.1.38.dist-info}/METADATA +1 -1
- kostyl_toolkit-0.1.38.dist-info/RECORD +40 -0
- {kostyl_toolkit-0.1.36.dist-info → kostyl_toolkit-0.1.38.dist-info}/WHEEL +2 -2
- kostyl/ml/lightning/__init__.py +0 -5
- kostyl/ml/lightning/extensions/__init__.py +0 -5
- kostyl_toolkit-0.1.36.dist-info/RECORD +0 -38
- /kostyl/ml/{clearml → integrations}/__init__.py +0 -0
- /kostyl/ml/{clearml → integrations/clearml}/dataset_utils.py +0 -0
- /kostyl/ml/{clearml/logging_utils.py → integrations/clearml/version_utils.py} +0 -0
- /kostyl/ml/{lightning → integrations/lightning}/callbacks/__init__.py +0 -0
- /kostyl/ml/{lightning → integrations/lightning}/callbacks/early_stopping.py +0 -0
- /kostyl/ml/{lightning → integrations/lightning}/loggers/__init__.py +0 -0
- /kostyl/ml/{metrics_formatting.py → integrations/lightning/metrics_formatting.py} +0 -0
|
@@ -12,12 +12,12 @@ from kostyl.utils.logging import setup_logger
|
|
|
12
12
|
logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
class LightningCheckpointLoaderMixin
|
|
15
|
+
class LightningCheckpointLoaderMixin:
|
|
16
16
|
"""A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""
|
|
17
17
|
|
|
18
18
|
@classmethod
|
|
19
|
-
def from_lightning_checkpoint[TModelInstance:
|
|
20
|
-
cls: type[TModelInstance],
|
|
19
|
+
def from_lightning_checkpoint[TModelInstance: PreTrainedModel]( # noqa: C901
|
|
20
|
+
cls: type[TModelInstance], # pyright: ignore[reportGeneralTypeIssues]
|
|
21
21
|
checkpoint_path: str | Path,
|
|
22
22
|
config_key: str = "config",
|
|
23
23
|
weights_prefix: str | None = "model.",
|
|
@@ -78,7 +78,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
|
|
|
78
78
|
mmap=True,
|
|
79
79
|
)
|
|
80
80
|
|
|
81
|
-
#
|
|
81
|
+
# Load config
|
|
82
82
|
config_cls = cast(type[PretrainedConfig], cls.config_class)
|
|
83
83
|
config_dict = checkpoint_dict[config_key]
|
|
84
84
|
config_dict.update(kwargs)
|
|
@@ -91,6 +91,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
|
|
|
91
91
|
|
|
92
92
|
raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
|
|
93
93
|
|
|
94
|
+
# Handle weights prefix
|
|
94
95
|
if weights_prefix:
|
|
95
96
|
if not weights_prefix.endswith("."):
|
|
96
97
|
weights_prefix = weights_prefix + "."
|
|
@@ -117,6 +118,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
|
|
|
117
118
|
else:
|
|
118
119
|
state_dict = raw_state_dict
|
|
119
120
|
|
|
121
|
+
# Instantiate model and load state dict
|
|
120
122
|
model = cls.from_pretrained(
|
|
121
123
|
pretrained_model_name_or_path=None,
|
|
122
124
|
config=config,
|
|
@@ -5,17 +5,15 @@ from typing import override
|
|
|
5
5
|
|
|
6
6
|
import lightning as L
|
|
7
7
|
import torch
|
|
8
|
-
import torch.distributed as dist
|
|
9
8
|
from lightning.pytorch.strategies import FSDPStrategy
|
|
10
9
|
from torch import nn
|
|
11
|
-
from torch.distributed import ProcessGroup
|
|
12
10
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
13
11
|
from torchmetrics import Metric
|
|
14
12
|
from torchmetrics import MetricCollection
|
|
15
13
|
from transformers import PretrainedConfig
|
|
16
14
|
from transformers import PreTrainedModel
|
|
17
15
|
|
|
18
|
-
from kostyl.ml.metrics_formatting import apply_suffix
|
|
16
|
+
from kostyl.ml.integrations.lightning.metrics_formatting import apply_suffix
|
|
19
17
|
from kostyl.ml.schedulers.base import BaseScheduler
|
|
20
18
|
from kostyl.utils import setup_logger
|
|
21
19
|
|
|
@@ -26,37 +24,6 @@ module_logger = setup_logger(fmt="only_message")
|
|
|
26
24
|
class KostylLightningModule(L.LightningModule):
|
|
27
25
|
"""Custom PyTorch Lightning Module with logging, checkpointing, and distributed training utilities."""
|
|
28
26
|
|
|
29
|
-
@property
|
|
30
|
-
def process_group(self) -> ProcessGroup | None:
|
|
31
|
-
"""Returns the data parallel process group for distributed training."""
|
|
32
|
-
return self.get_process_group()
|
|
33
|
-
|
|
34
|
-
def get_process_group(self) -> ProcessGroup | None:
|
|
35
|
-
"""
|
|
36
|
-
Retrieves the data parallel process group for distributed training.
|
|
37
|
-
|
|
38
|
-
This method checks if distributed processing is initialized. If a device mesh is provided,
|
|
39
|
-
it extracts the data parallel mesh and returns its process group, unless the mesh size is 1,
|
|
40
|
-
in which case it logs a warning and returns None. If no device mesh is provided, it returns
|
|
41
|
-
the world process group.
|
|
42
|
-
|
|
43
|
-
Returns:
|
|
44
|
-
ProcessGroup | None: The data parallel process group if available and valid, otherwise None.
|
|
45
|
-
|
|
46
|
-
"""
|
|
47
|
-
if not dist.is_initialized():
|
|
48
|
-
return None
|
|
49
|
-
|
|
50
|
-
if self.device_mesh is not None:
|
|
51
|
-
dp_mesh = self.device_mesh["data_parallel"]
|
|
52
|
-
if dp_mesh.size() == 1:
|
|
53
|
-
module_logger.warning("Data parallel mesh size is 1, returning None")
|
|
54
|
-
return None
|
|
55
|
-
dp_pg = dp_mesh.get_group()
|
|
56
|
-
else:
|
|
57
|
-
dp_pg = dist.group.WORLD
|
|
58
|
-
return dp_pg
|
|
59
|
-
|
|
60
27
|
@property
|
|
61
28
|
def model_instance(self) -> PreTrainedModel | nn.Module:
|
|
62
29
|
"""Returns the underlying model."""
|
|
@@ -65,10 +32,7 @@ class KostylLightningModule(L.LightningModule):
|
|
|
65
32
|
@property
|
|
66
33
|
def model_config(self) -> PretrainedConfig | None:
|
|
67
34
|
"""Returns the model configuration if available."""
|
|
68
|
-
|
|
69
|
-
if hasattr(model, "config"):
|
|
70
|
-
return model.config # type: ignore
|
|
71
|
-
return None
|
|
35
|
+
raise NotImplementedError
|
|
72
36
|
|
|
73
37
|
@property
|
|
74
38
|
def grad_clip_val(self) -> float | None:
|
kostyl/ml/schedulers/__init__.py
CHANGED
|
@@ -1,18 +1,18 @@
|
|
|
1
1
|
from .composite import CompositeScheduler
|
|
2
2
|
from .cosine import CosineParamScheduler
|
|
3
3
|
from .cosine import CosineScheduler
|
|
4
|
-
from .cosine_with_plateu import CosineWithPlateauParamScheduler
|
|
5
|
-
from .cosine_with_plateu import CosineWithPlateuScheduler
|
|
6
4
|
from .linear import LinearParamScheduler
|
|
7
5
|
from .linear import LinearScheduler
|
|
6
|
+
from .plateau import PlateauWithAnnealingParamScheduler
|
|
7
|
+
from .plateau import PlateauWithAnnealingScheduler
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
__all__ = [
|
|
11
11
|
"CompositeScheduler",
|
|
12
12
|
"CosineParamScheduler",
|
|
13
13
|
"CosineScheduler",
|
|
14
|
-
"CosineWithPlateauParamScheduler",
|
|
15
|
-
"CosineWithPlateuScheduler",
|
|
16
14
|
"LinearParamScheduler",
|
|
17
15
|
"LinearScheduler",
|
|
16
|
+
"PlateauWithAnnealingParamScheduler",
|
|
17
|
+
"PlateauWithAnnealingScheduler",
|
|
18
18
|
]
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from typing import Any
|
|
2
|
+
from typing import Literal
|
|
2
3
|
from typing import override
|
|
3
4
|
|
|
4
5
|
import numpy as np
|
|
@@ -7,20 +8,25 @@ import torch
|
|
|
7
8
|
from .base import BaseScheduler
|
|
8
9
|
|
|
9
10
|
|
|
10
|
-
class
|
|
11
|
-
"""Core
|
|
11
|
+
class _PlateauWithAnnealingCore(BaseScheduler):
|
|
12
|
+
"""Core logic shared by the plateau-with-annealing schedulers."""
|
|
12
13
|
|
|
13
14
|
def __init__(
|
|
14
15
|
self,
|
|
15
16
|
param_name: str,
|
|
16
17
|
num_iters: int,
|
|
17
|
-
|
|
18
|
+
plateau_value: float,
|
|
18
19
|
final_value: float,
|
|
19
20
|
plateau_ratio: float,
|
|
20
21
|
warmup_value: float | None = None,
|
|
21
22
|
warmup_ratio: float | None = None,
|
|
22
23
|
freeze_ratio: float | None = None,
|
|
24
|
+
annealing_type: Literal["cosine", "linear"] = "cosine",
|
|
23
25
|
) -> None:
|
|
26
|
+
if annealing_type not in ("cosine", "linear"):
|
|
27
|
+
raise ValueError(
|
|
28
|
+
f"Annealing type must be 'cosine' or 'linear', got {annealing_type}."
|
|
29
|
+
)
|
|
24
30
|
if warmup_ratio is not None:
|
|
25
31
|
if not (0 < warmup_ratio < 1):
|
|
26
32
|
raise ValueError(f"Warmup ratio must be in (0, 1), got {warmup_ratio}.")
|
|
@@ -47,16 +53,17 @@ class _CosineWithPlateauSchedulerCore(BaseScheduler):
|
|
|
47
53
|
|
|
48
54
|
self.param_name = param_name
|
|
49
55
|
self.num_iters = num_iters
|
|
50
|
-
self.
|
|
56
|
+
self.plateau_value = plateau_value
|
|
51
57
|
self.final_value = final_value
|
|
52
|
-
self.
|
|
58
|
+
self.annealing_ratio = 1 - pre_annealing_ratio
|
|
53
59
|
self.plateau_ratio = plateau_ratio
|
|
54
60
|
self.warmup_ratio = warmup_ratio
|
|
55
61
|
self.warmup_value = warmup_value
|
|
56
62
|
self.freeze_ratio = freeze_ratio
|
|
63
|
+
self.annealing_type = annealing_type
|
|
57
64
|
|
|
58
65
|
self.scheduled_values: np.ndarray = np.array([], dtype=np.float64)
|
|
59
|
-
self.current_value_ = self.
|
|
66
|
+
self.current_value_ = self.plateau_value
|
|
60
67
|
return
|
|
61
68
|
|
|
62
69
|
def _create_scheduler(self) -> None:
|
|
@@ -72,28 +79,41 @@ class _CosineWithPlateauSchedulerCore(BaseScheduler):
|
|
|
72
79
|
if self.warmup_ratio is not None and self.warmup_value is not None:
|
|
73
80
|
warmup_iters = int(self.num_iters * self.warmup_ratio)
|
|
74
81
|
warmup_schedule = np.linspace(
|
|
75
|
-
self.warmup_value, self.
|
|
82
|
+
self.warmup_value, self.plateau_value, warmup_iters, dtype=np.float64
|
|
76
83
|
)
|
|
77
84
|
else:
|
|
78
85
|
warmup_iters = 0
|
|
79
86
|
warmup_schedule = np.array([], dtype=np.float64)
|
|
80
87
|
|
|
81
|
-
# Create
|
|
82
|
-
if self.
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
+
# Create annealing schedule
|
|
89
|
+
if self.annealing_ratio > 0:
|
|
90
|
+
annealing_iters = int(self.num_iters * self.annealing_ratio)
|
|
91
|
+
match self.annealing_type:
|
|
92
|
+
case "cosine":
|
|
93
|
+
iters = np.arange(annealing_iters)
|
|
94
|
+
annealing_schedule = self.final_value + 0.5 * (
|
|
95
|
+
self.plateau_value - self.final_value
|
|
96
|
+
) * (1 + np.cos(np.pi * iters / len(iters)))
|
|
97
|
+
case "linear":
|
|
98
|
+
annealing_schedule = np.linspace(
|
|
99
|
+
self.plateau_value,
|
|
100
|
+
self.final_value,
|
|
101
|
+
annealing_iters,
|
|
102
|
+
dtype=np.float64,
|
|
103
|
+
)
|
|
104
|
+
case _:
|
|
105
|
+
raise ValueError(
|
|
106
|
+
f"Unsupported annealing type: {self.annealing_type}"
|
|
107
|
+
)
|
|
88
108
|
else:
|
|
89
|
-
|
|
90
|
-
|
|
109
|
+
annealing_iters = 0
|
|
110
|
+
annealing_schedule = np.array([], dtype=np.float64)
|
|
91
111
|
|
|
92
|
-
plateau_iters =
|
|
93
|
-
self.num_iters - warmup_iters - freeze_iters - cosine_annealing_iters
|
|
94
|
-
)
|
|
112
|
+
plateau_iters = self.num_iters - warmup_iters - freeze_iters - annealing_iters
|
|
95
113
|
if plateau_iters > 0:
|
|
96
|
-
plateau_schedule = np.full(
|
|
114
|
+
plateau_schedule = np.full(
|
|
115
|
+
plateau_iters, self.plateau_value, dtype=np.float64
|
|
116
|
+
)
|
|
97
117
|
else:
|
|
98
118
|
plateau_schedule = np.array([], dtype=np.float64)
|
|
99
119
|
|
|
@@ -103,7 +123,7 @@ class _CosineWithPlateauSchedulerCore(BaseScheduler):
|
|
|
103
123
|
freeze_schedule,
|
|
104
124
|
warmup_schedule,
|
|
105
125
|
plateau_schedule,
|
|
106
|
-
|
|
126
|
+
annealing_schedule,
|
|
107
127
|
)
|
|
108
128
|
)
|
|
109
129
|
self._verify()
|
|
@@ -137,12 +157,12 @@ class _CosineWithPlateauSchedulerCore(BaseScheduler):
|
|
|
137
157
|
return {self.param_name: self.current_value_}
|
|
138
158
|
|
|
139
159
|
|
|
140
|
-
class
|
|
160
|
+
class PlateauWithAnnealingScheduler(_PlateauWithAnnealingCore):
|
|
141
161
|
"""
|
|
142
|
-
Applies
|
|
162
|
+
Applies an annealing schedule with plateau to an optimizer param-group field.
|
|
143
163
|
|
|
144
|
-
Schedule phases: freeze (0) → warmup → plateau (
|
|
145
|
-
The plateau phase maintains the
|
|
164
|
+
Schedule phases: freeze (0) → warmup → plateau (plateau_value) → annealing (cosine/linear) to final_value.
|
|
165
|
+
The plateau phase maintains the plateau_value before annealing begins.
|
|
146
166
|
"""
|
|
147
167
|
|
|
148
168
|
def __init__(
|
|
@@ -150,30 +170,32 @@ class CosineWithPlateuScheduler(_CosineWithPlateauSchedulerCore):
|
|
|
150
170
|
optimizer: torch.optim.Optimizer,
|
|
151
171
|
param_group_field: str,
|
|
152
172
|
num_iters: int,
|
|
153
|
-
|
|
173
|
+
plateau_value: float,
|
|
154
174
|
final_value: float,
|
|
155
175
|
plateau_ratio: float,
|
|
156
176
|
warmup_value: float | None = None,
|
|
157
177
|
warmup_ratio: float | None = None,
|
|
158
178
|
freeze_ratio: float | None = None,
|
|
179
|
+
annealing_type: Literal["cosine", "linear"] = "cosine",
|
|
159
180
|
multiplier_field: str | None = None,
|
|
160
181
|
skip_if_zero: bool = False,
|
|
161
182
|
apply_if_field: str | None = None,
|
|
162
183
|
ignore_if_field: str | None = None,
|
|
163
184
|
) -> None:
|
|
164
185
|
"""
|
|
165
|
-
Configure
|
|
186
|
+
Configure annealing scheduling for matching optimizer groups.
|
|
166
187
|
|
|
167
188
|
Args:
|
|
168
189
|
optimizer: Optimizer whose param groups are updated in-place.
|
|
169
190
|
param_group_field: Name of the field that receives the scheduled value.
|
|
170
191
|
num_iters: Number of scheduler iterations before clamping at ``final_value``.
|
|
171
|
-
|
|
172
|
-
final_value: Value approached as iterations progress during
|
|
173
|
-
plateau_ratio: Fraction of iterations to maintain ``
|
|
174
|
-
warmup_ratio: Optional fraction of iterations to linearly ramp from ``warmup_value`` to ``
|
|
192
|
+
plateau_value: Value maintained during plateau phase and used as annealing start.
|
|
193
|
+
final_value: Value approached as iterations progress during annealing.
|
|
194
|
+
plateau_ratio: Fraction of iterations to maintain ``plateau_value`` before annealing.
|
|
195
|
+
warmup_ratio: Optional fraction of iterations to linearly ramp from ``warmup_value`` to ``plateau_value``.
|
|
175
196
|
warmup_value: Starting value for the warmup ramp.
|
|
176
197
|
freeze_ratio: Optional fraction of iterations to keep the value frozen at zero at the beginning.
|
|
198
|
+
annealing_type: Type of annealing from plateau to final value ("cosine" or "linear").
|
|
177
199
|
multiplier_field: Optional per-group multiplier applied to the scheduled value.
|
|
178
200
|
skip_if_zero: Leave groups untouched when their target field equals zero.
|
|
179
201
|
apply_if_field: Require this flag to be present in a param group before updating.
|
|
@@ -188,12 +210,13 @@ class CosineWithPlateuScheduler(_CosineWithPlateauSchedulerCore):
|
|
|
188
210
|
super().__init__(
|
|
189
211
|
param_name=param_group_field,
|
|
190
212
|
num_iters=num_iters,
|
|
191
|
-
|
|
213
|
+
plateau_value=plateau_value,
|
|
192
214
|
final_value=final_value,
|
|
193
215
|
plateau_ratio=plateau_ratio,
|
|
194
216
|
warmup_ratio=warmup_ratio,
|
|
195
217
|
warmup_value=warmup_value,
|
|
196
218
|
freeze_ratio=freeze_ratio,
|
|
219
|
+
annealing_type=annealing_type,
|
|
197
220
|
)
|
|
198
221
|
self.param_group_field = param_group_field
|
|
199
222
|
return
|
|
@@ -242,12 +265,12 @@ class CosineWithPlateuScheduler(_CosineWithPlateauSchedulerCore):
|
|
|
242
265
|
return
|
|
243
266
|
|
|
244
267
|
|
|
245
|
-
class
|
|
268
|
+
class PlateauWithAnnealingParamScheduler(_PlateauWithAnnealingCore):
|
|
246
269
|
"""
|
|
247
|
-
Standalone
|
|
270
|
+
Standalone annealing scheduler with plateau for non-optimizer parameters.
|
|
248
271
|
|
|
249
|
-
Schedule phases: freeze (0) → warmup → plateau (
|
|
250
|
-
The plateau phase maintains the
|
|
272
|
+
Schedule phases: freeze (0) → warmup → plateau (plateau_value) → annealing (cosine/linear) to final_value.
|
|
273
|
+
The plateau phase maintains the plateau_value before annealing begins.
|
|
251
274
|
"""
|
|
252
275
|
|
|
253
276
|
@override
|
kostyl/utils/logging.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
+
import os
|
|
4
5
|
import sys
|
|
5
6
|
import uuid
|
|
6
7
|
from collections import namedtuple
|
|
@@ -18,32 +19,18 @@ from loguru import logger as _base_logger
|
|
|
18
19
|
if TYPE_CHECKING:
|
|
19
20
|
from loguru import Logger
|
|
20
21
|
|
|
21
|
-
class
|
|
22
|
+
class KostylLogger(Logger): # noqa: D101
|
|
22
23
|
def log_once(self, level: str, message: str, *args, **kwargs) -> None: ... # noqa: ANN003, D102
|
|
23
24
|
def warning_once(self, message: str, *args, **kwargs) -> None: ... # noqa: ANN003, D102
|
|
24
25
|
else:
|
|
25
|
-
|
|
26
|
+
KostylLogger = type(_base_logger)
|
|
26
27
|
|
|
27
28
|
try:
|
|
28
|
-
import torch.distributed as dist
|
|
29
29
|
from torch.nn.modules.module import (
|
|
30
30
|
_IncompatibleKeys, # pyright: ignore[reportAssignmentType]
|
|
31
31
|
)
|
|
32
32
|
except Exception:
|
|
33
33
|
|
|
34
|
-
class _Dummy:
|
|
35
|
-
@staticmethod
|
|
36
|
-
def is_available() -> bool:
|
|
37
|
-
return False
|
|
38
|
-
|
|
39
|
-
@staticmethod
|
|
40
|
-
def is_initialized() -> bool:
|
|
41
|
-
return False
|
|
42
|
-
|
|
43
|
-
@staticmethod
|
|
44
|
-
def get_rank() -> int:
|
|
45
|
-
return 0
|
|
46
|
-
|
|
47
34
|
class _IncompatibleKeys(
|
|
48
35
|
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
|
|
49
36
|
):
|
|
@@ -56,14 +43,13 @@ except Exception:
|
|
|
56
43
|
|
|
57
44
|
__str__ = __repr__
|
|
58
45
|
|
|
59
|
-
dist = _Dummy()
|
|
60
46
|
_IncompatibleKeys = _IncompatibleKeys
|
|
61
47
|
|
|
62
48
|
_once_lock = Lock()
|
|
63
49
|
_once_keys: set[tuple[str, str]] = set()
|
|
64
50
|
|
|
65
51
|
|
|
66
|
-
def _log_once(self:
|
|
52
|
+
def _log_once(self: KostylLogger, level: str, message: str, *args, **kwargs) -> None: # noqa: ANN003
|
|
67
53
|
key = (message, level)
|
|
68
54
|
|
|
69
55
|
with _once_lock:
|
|
@@ -75,7 +61,7 @@ def _log_once(self: CustomLogger, level: str, message: str, *args, **kwargs) ->
|
|
|
75
61
|
return
|
|
76
62
|
|
|
77
63
|
|
|
78
|
-
_base_logger = cast(
|
|
64
|
+
_base_logger = cast(KostylLogger, _base_logger)
|
|
79
65
|
_base_logger.log_once = _log_once # pyright: ignore[reportAttributeAccessIssue]
|
|
80
66
|
_base_logger.warning_once = partialmethod(_log_once, "WARNING") # pyright: ignore[reportAttributeAccessIssue]
|
|
81
67
|
|
|
@@ -91,44 +77,83 @@ _DEFAULT_FMT = "<level>{level: <8}</level> {time:HH:mm:ss.SSS} [{extra[channel]}
|
|
|
91
77
|
_ONLY_MESSAGE_FMT = "<level>{message}</level>"
|
|
92
78
|
_PRESETS = {"default": _DEFAULT_FMT, "only_message": _ONLY_MESSAGE_FMT}
|
|
93
79
|
|
|
80
|
+
KOSTYL_LOG_LEVEL = os.getenv("KOSTYL_LOG_LEVEL", "INFO")
|
|
81
|
+
|
|
94
82
|
|
|
95
83
|
def setup_logger(
|
|
96
84
|
name: str | None = None,
|
|
97
85
|
fmt: Literal["default", "only_message"] | str = "only_message",
|
|
98
|
-
level: str =
|
|
99
|
-
add_rank: bool | None = None,
|
|
86
|
+
level: str | None = None,
|
|
100
87
|
sink=sys.stdout,
|
|
101
88
|
colorize: bool = True,
|
|
102
89
|
serialize: bool = False,
|
|
103
|
-
) ->
|
|
90
|
+
) -> KostylLogger:
|
|
104
91
|
"""
|
|
105
|
-
|
|
92
|
+
Creates and configures a logger with custom formatting and output.
|
|
93
|
+
|
|
94
|
+
The function automatically removes the default sink on first call and creates
|
|
95
|
+
an isolated logger with a unique identifier for message filtering.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
name (str | None, optional): Logger channel name. If None, automatically
|
|
99
|
+
uses the calling function's filename. Defaults to None.
|
|
100
|
+
fmt (Literal["default", "only_message"] | str, optional): Log message format.
|
|
101
|
+
Available presets:
|
|
102
|
+
- "default": includes level, time, and channel
|
|
103
|
+
- "only_message": outputs only the message itself
|
|
104
|
+
Custom format strings are also supported. Defaults to "only_message".
|
|
105
|
+
level (str | None, optional): Logging level (TRACE, DEBUG, INFO, SUCCESS,
|
|
106
|
+
WARNING, ERROR, CRITICAL). If None, uses the KOSTYL_LOG_LEVEL environment
|
|
107
|
+
variable or "INFO" by default. Defaults to None.
|
|
108
|
+
sink: Output object for logs (file, sys.stdout, sys.stderr, etc.).
|
|
109
|
+
Defaults to sys.stdout.
|
|
110
|
+
colorize (bool, optional): Enable colored output formatting.
|
|
111
|
+
Defaults to True.
|
|
112
|
+
serialize (bool, optional): Serialize logs to JSON format.
|
|
113
|
+
Defaults to False.
|
|
114
|
+
|
|
115
|
+
Returns:
|
|
116
|
+
KostylLogger: Configured logger instance with additional methods
|
|
117
|
+
log_once() and warning_once().
|
|
118
|
+
|
|
119
|
+
Example:
|
|
120
|
+
>>> # Basic usage with automatic name detection
|
|
121
|
+
>>> logger = setup_logger()
|
|
122
|
+
>>> logger.info("Hello World")
|
|
106
123
|
|
|
107
|
-
|
|
124
|
+
>>> # With custom name and level
|
|
125
|
+
>>> logger = setup_logger(name="MyApp", level="DEBUG")
|
|
126
|
+
|
|
127
|
+
>>> # With custom format
|
|
128
|
+
>>> logger = setup_logger(
|
|
129
|
+
... name="API",
|
|
130
|
+
... fmt="{level} | {time:YYYY-MM-DD HH:mm:ss} | {message}"
|
|
131
|
+
... )
|
|
108
132
|
|
|
109
|
-
Format example: "{level} {time:MM-DD HH:mm:ss} [{extra[channel]}] {message}"
|
|
110
133
|
"""
|
|
111
134
|
global _DEFAULT_SINK_REMOVED
|
|
112
135
|
if not _DEFAULT_SINK_REMOVED:
|
|
113
136
|
_base_logger.remove()
|
|
114
137
|
_DEFAULT_SINK_REMOVED = True
|
|
115
138
|
|
|
116
|
-
if
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
139
|
+
if level is None:
|
|
140
|
+
if KOSTYL_LOG_LEVEL not in {
|
|
141
|
+
"TRACE",
|
|
142
|
+
"DEBUG",
|
|
143
|
+
"INFO",
|
|
144
|
+
"SUCCESS",
|
|
145
|
+
"WARNING",
|
|
146
|
+
"ERROR",
|
|
147
|
+
"CRITICAL",
|
|
148
|
+
}:
|
|
149
|
+
level = "INFO"
|
|
150
|
+
else:
|
|
151
|
+
level = KOSTYL_LOG_LEVEL
|
|
120
152
|
|
|
121
|
-
if
|
|
122
|
-
|
|
123
|
-
add_rank = dist.is_available() and dist.is_initialized()
|
|
124
|
-
except Exception:
|
|
125
|
-
add_rank = False
|
|
126
|
-
|
|
127
|
-
if add_rank:
|
|
128
|
-
rank = dist.get_rank()
|
|
129
|
-
channel = f"rank:{rank} - {base}"
|
|
153
|
+
if name is None:
|
|
154
|
+
channel = _caller_filename()
|
|
130
155
|
else:
|
|
131
|
-
channel =
|
|
156
|
+
channel = name
|
|
132
157
|
|
|
133
158
|
if fmt in _PRESETS:
|
|
134
159
|
fmt = _PRESETS[fmt]
|
|
@@ -146,7 +171,7 @@ def setup_logger(
|
|
|
146
171
|
filter=lambda r: r["extra"].get("logger_id") == logger_id,
|
|
147
172
|
)
|
|
148
173
|
logger = _base_logger.bind(logger_id=logger_id, channel=channel)
|
|
149
|
-
return cast(
|
|
174
|
+
return cast(KostylLogger, logger)
|
|
150
175
|
|
|
151
176
|
|
|
152
177
|
def log_incompatible_keys(
|
|
@@ -154,22 +179,12 @@ def log_incompatible_keys(
|
|
|
154
179
|
incompatible_keys: _IncompatibleKeys
|
|
155
180
|
| tuple[list[str], list[str]]
|
|
156
181
|
| dict[str, list[str]],
|
|
157
|
-
|
|
182
|
+
postfix_msg: str = "",
|
|
158
183
|
) -> None:
|
|
159
184
|
"""
|
|
160
185
|
Logs warnings for incompatible keys encountered during model loading or state dict operations.
|
|
161
186
|
|
|
162
187
|
Note: If incompatible_keys is of an unsupported type, an error message is logged and the function returns early.
|
|
163
|
-
|
|
164
|
-
Args:
|
|
165
|
-
logger (Logger): The logger instance used to output warning messages.
|
|
166
|
-
incompatible_keys (_IncompatibleKeys | tuple[list[str], list[str]] | dict[str, list[str]]): An object containing lists of missing and unexpected keys.
|
|
167
|
-
model_specific_msg (str, optional): A custom message to append to the log output, typically
|
|
168
|
-
indicating the model or context. Defaults to an empty string.
|
|
169
|
-
|
|
170
|
-
Returns:
|
|
171
|
-
None
|
|
172
|
-
|
|
173
188
|
"""
|
|
174
189
|
incompatible_keys_: dict[str, list[str]] = {}
|
|
175
190
|
match incompatible_keys:
|
|
@@ -192,5 +207,5 @@ def log_incompatible_keys(
|
|
|
192
207
|
return
|
|
193
208
|
|
|
194
209
|
for name, keys in incompatible_keys_.items():
|
|
195
|
-
logger.warning(f"{name} {
|
|
210
|
+
logger.warning(f"{name} {postfix_msg}: {', '.join(keys)}")
|
|
196
211
|
return
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
kostyl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
kostyl/ml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
kostyl/ml/base_uploader.py,sha256=KxHuohCcNK18kTVFBBqDu_IOQefluhSXOzwC56O66wc,484
|
|
4
|
+
kostyl/ml/configs/__init__.py,sha256=djYjLxA7riFcSibAKfWHns-BCESEPrqSz_ZY2rJO-cc,913
|
|
5
|
+
kostyl/ml/configs/hyperparams.py,sha256=lvtbvOFEoTBAJug7FR35xMQdPLgDQjRoP2fyDP-jD7E,3305
|
|
6
|
+
kostyl/ml/configs/mixins.py,sha256=xHHAoRoPbzP9ECFP9duzg6SzegHcoLI8Pr9NrLoWNHs,1411
|
|
7
|
+
kostyl/ml/configs/training_settings.py,sha256=wT9CHuLaKrLwonsc87Ee421EyFis_c9fqOgn9bSClm8,2747
|
|
8
|
+
kostyl/ml/data_collator.py,sha256=kxiaMDKwSKXGBtrF8yXxHcypf7t_6syU-NwO1LcX50k,4062
|
|
9
|
+
kostyl/ml/dist_utils.py,sha256=UFNMLEHc0A5F6KvTRG8GQPpRDwG4m5dvM__UvXNc2aQ,4526
|
|
10
|
+
kostyl/ml/integrations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
+
kostyl/ml/integrations/clearml/__init__.py,sha256=3TBVI-3fE9ZzuvOLEohW9TOK0BZTLD5JiYalAVDkocc,217
|
|
12
|
+
kostyl/ml/integrations/clearml/checkpoint_uploader.py,sha256=PupFi7jKROsIddOz7X5DhV7nUNdDZg5kKaaLvzdCHlY,4012
|
|
13
|
+
kostyl/ml/integrations/clearml/config_mixin.py,sha256=70QRicU7etiDzLX-MplqVX8uFm5siuPrM8KbTOriZnQ,3308
|
|
14
|
+
kostyl/ml/integrations/clearml/dataset_utils.py,sha256=eij_sr2KDhm8GxEbVbK8aBjPsuVvLl9-PIGGaKVgXLA,1729
|
|
15
|
+
kostyl/ml/integrations/clearml/loading_utils.py,sha256=NAMmB9NTGCXCHh-bR_nrQZyqImUVZqicNjExDyPM2mU,5224
|
|
16
|
+
kostyl/ml/integrations/clearml/version_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWaJTyq8,1210
|
|
17
|
+
kostyl/ml/integrations/lightning/__init__.py,sha256=r96os8kTuKIAymx3k9Td1JBrO2PH7nQAWUC54NsY5yY,392
|
|
18
|
+
kostyl/ml/integrations/lightning/callbacks/__init__.py,sha256=EnKkNwwNDZnEqKRlpY4FVrqP88ECPF6nlT2bSLUIKRk,194
|
|
19
|
+
kostyl/ml/integrations/lightning/callbacks/checkpoint.py,sha256=SfcaQRkXviMUej0UgrfXcqMDlRKYaAN3rgYCMKI97Os,18433
|
|
20
|
+
kostyl/ml/integrations/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
|
|
21
|
+
kostyl/ml/integrations/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
|
|
22
|
+
kostyl/ml/integrations/lightning/loggers/tb_logger.py,sha256=CpjlcEIT187cJXJgRYafqfzvcnwPgPaVZ0vLUflIr7k,899
|
|
23
|
+
kostyl/ml/integrations/lightning/metrics_formatting.py,sha256=U6vdNENZLvp2dT1L3HqFKtXrHwGKoDXN93hvamPGHjM,1341
|
|
24
|
+
kostyl/ml/integrations/lightning/mixins.py,sha256=hVIsIUu6Iryrz6S7GQTqog9vNq8LQyjJd2aoJ5Ws6KU,5253
|
|
25
|
+
kostyl/ml/integrations/lightning/module.py,sha256=39hcVNZSGyj5tLpXyX8IoqMGWt5vf6-Bx5JnNJ2-Wag,5218
|
|
26
|
+
kostyl/ml/integrations/lightning/utils.py,sha256=DhLy_3JA5VyMQkB1v6xxRxDNHfisjXFYVjuIKPpO81M,1967
|
|
27
|
+
kostyl/ml/params_groups.py,sha256=nUyw5d06Pvy9QPiYtZzLYR87xwXqJLxbHthgQH8oSCM,3583
|
|
28
|
+
kostyl/ml/schedulers/__init__.py,sha256=VIo8MOP4w5Ll24XqFb3QGi2rKvys6c0dEFYPIdDoPlw,526
|
|
29
|
+
kostyl/ml/schedulers/base.py,sha256=bjmwgdZpnSqpCnHPnKC6MEiRO79cwxMJpZq-eQVNs2M,1353
|
|
30
|
+
kostyl/ml/schedulers/composite.py,sha256=ee4xlMDMMtjKPkbTF2ue9GTr9DuGCGjZWf11mHbi6aE,2387
|
|
31
|
+
kostyl/ml/schedulers/cosine.py,sha256=y8ylrgVOkVcr2-ExoqqNW--tdDX88TBYPQCOppIf2_M,8685
|
|
32
|
+
kostyl/ml/schedulers/linear.py,sha256=RnnnblRuRXP3LT03QVIHUaK2kNsiMP1AedrMoeyh3qk,5843
|
|
33
|
+
kostyl/ml/schedulers/plateau.py,sha256=N-hiostPtTR0W4xnEJYB_1dv0DRx39iufLkGUrSIoWE,11235
|
|
34
|
+
kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
|
|
35
|
+
kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
|
|
36
|
+
kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
|
|
37
|
+
kostyl/utils/logging.py,sha256=CgNFNogcK0hoZmygvBWlTcq5A3m2Pfv9eOAP_gwx0pM,6633
|
|
38
|
+
kostyl_toolkit-0.1.38.dist-info/WHEEL,sha256=e_m4S054HL0hyR3CpOk-b7Q7fDX6BuFkgL5OjAExXas,80
|
|
39
|
+
kostyl_toolkit-0.1.38.dist-info/METADATA,sha256=nz5AzlWjKBqh7OZCklk-efWZ1jVDihw3YrrpLyoII3k,4269
|
|
40
|
+
kostyl_toolkit-0.1.38.dist-info/RECORD,,
|
kostyl/ml/lightning/__init__.py
DELETED