kostyl-toolkit 0.1.35__py3-none-any.whl → 0.1.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kostyl/ml/configs/hyperparams.py +21 -5
- kostyl/ml/configs/training_settings.py +17 -6
- kostyl/ml/dist_utils.py +52 -30
- kostyl/ml/lightning/callbacks/checkpoint.py +10 -10
- kostyl/ml/lightning/extensions/custom_module.py +0 -5
- kostyl/ml/lightning/extensions/pretrained_model.py +6 -4
- kostyl/ml/lightning/loggers/tb_logger.py +2 -2
- kostyl/ml/lightning/utils.py +58 -0
- kostyl/ml/registry_uploader.py +56 -29
- kostyl/ml/schedulers/__init__.py +13 -1
- kostyl/ml/schedulers/base.py +9 -7
- kostyl/ml/schedulers/cosine.py +53 -24
- kostyl/ml/schedulers/cosine_with_plateu.py +277 -0
- kostyl/ml/schedulers/linear.py +36 -11
- kostyl/utils/logging.py +68 -53
- {kostyl_toolkit-0.1.35.dist-info → kostyl_toolkit-0.1.37.dist-info}/METADATA +1 -1
- {kostyl_toolkit-0.1.35.dist-info → kostyl_toolkit-0.1.37.dist-info}/RECORD +18 -17
- {kostyl_toolkit-0.1.35.dist-info → kostyl_toolkit-0.1.37.dist-info}/WHEEL +1 -1
- kostyl/ml/lightning/training_utils.py +0 -241
kostyl/utils/logging.py
CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import inspect
+import os
 import sys
 import uuid
 from collections import namedtuple
@@ -18,32 +19,18 @@ from loguru import logger as _base_logger
 if TYPE_CHECKING:
     from loguru import Logger
 
-class CustomLogger(Logger):  # noqa: D101
+class KostylLogger(Logger):  # noqa: D101
     def log_once(self, level: str, message: str, *args, **kwargs) -> None: ...  # noqa: ANN003, D102
     def warning_once(self, message: str, *args, **kwargs) -> None: ...  # noqa: ANN003, D102
 else:
-    CustomLogger = type(_base_logger)
+    KostylLogger = type(_base_logger)
 
 try:
-    import torch.distributed as dist
     from torch.nn.modules.module import (
         _IncompatibleKeys,  # pyright: ignore[reportAssignmentType]
     )
 except Exception:
 
-    class _Dummy:
-        @staticmethod
-        def is_available() -> bool:
-            return False
-
-        @staticmethod
-        def is_initialized() -> bool:
-            return False
-
-        @staticmethod
-        def get_rank() -> int:
-            return 0
-
     class _IncompatibleKeys(
         namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
     ):
@@ -56,14 +43,13 @@ except Exception:
 
         __str__ = __repr__
 
-    dist = _Dummy()
     _IncompatibleKeys = _IncompatibleKeys
 
 _once_lock = Lock()
 _once_keys: set[tuple[str, str]] = set()
 
 
-def _log_once(self: CustomLogger, level: str, message: str, *args, **kwargs) -> None:  # noqa: ANN003
+def _log_once(self: KostylLogger, level: str, message: str, *args, **kwargs) -> None:  # noqa: ANN003
     key = (message, level)
 
     with _once_lock:
@@ -75,7 +61,7 @@ def _log_once(self: CustomLogger, level: str, message: str, *args, **kwargs) ->
             return
 
 
-_base_logger = cast(CustomLogger, _base_logger)
+_base_logger = cast(KostylLogger, _base_logger)
 _base_logger.log_once = _log_once  # pyright: ignore[reportAttributeAccessIssue]
 _base_logger.warning_once = partialmethod(_log_once, "WARNING")  # pyright: ignore[reportAttributeAccessIssue]
 
@@ -91,44 +77,83 @@ _DEFAULT_FMT = "<level>{level: <8}</level> {time:HH:mm:ss.SSS} [{extra[channel]}
 _ONLY_MESSAGE_FMT = "<level>{message}</level>"
 _PRESETS = {"default": _DEFAULT_FMT, "only_message": _ONLY_MESSAGE_FMT}
 
+KOSTYL_LOG_LEVEL = os.getenv("KOSTYL_LOG_LEVEL", "INFO")
+
 
 def setup_logger(
     name: str | None = None,
-    fmt: Literal["default", "only_message"] | str = "…
-    level: str = …
-    add_rank: bool | None = None,
+    fmt: Literal["default", "only_message"] | str = "only_message",
+    level: str | None = None,
     sink=sys.stdout,
     colorize: bool = True,
     serialize: bool = False,
-) -> CustomLogger:
+) -> KostylLogger:
     """
-    …
+    Creates and configures a logger with custom formatting and output.
+
+    The function automatically removes the default sink on first call and creates
+    an isolated logger with a unique identifier for message filtering.
+
+    Args:
+        name (str | None, optional): Logger channel name. If None, automatically
+            uses the calling function's filename. Defaults to None.
+        fmt (Literal["default", "only_message"] | str, optional): Log message format.
+            Available presets:
+            - "default": includes level, time, and channel
+            - "only_message": outputs only the message itself
+            Custom format strings are also supported. Defaults to "only_message".
+        level (str | None, optional): Logging level (TRACE, DEBUG, INFO, SUCCESS,
+            WARNING, ERROR, CRITICAL). If None, uses the KOSTYL_LOG_LEVEL environment
+            variable or "INFO" by default. Defaults to None.
+        sink: Output object for logs (file, sys.stdout, sys.stderr, etc.).
+            Defaults to sys.stdout.
+        colorize (bool, optional): Enable colored output formatting.
+            Defaults to True.
+        serialize (bool, optional): Serialize logs to JSON format.
+            Defaults to False.
+
+    Returns:
+        CustomLogger: Configured logger instance with additional methods
+            log_once() and warning_once().
+
+    Example:
+        >>> # Basic usage with automatic name detection
+        >>> logger = setup_logger()
+        >>> logger.info("Hello World")
 
-
+        >>> # With custom name and level
+        >>> logger = setup_logger(name="MyApp", level="DEBUG")
+
+        >>> # With custom format
+        >>> logger = setup_logger(
+        ...     name="API",
+        ...     fmt="{level} | {time:YYYY-MM-DD HH:mm:ss} | {message}"
+        ... )
 
-    Format example: "{level} {time:MM-DD HH:mm:ss} [{extra[channel]}] {message}"
     """
     global _DEFAULT_SINK_REMOVED
     if not _DEFAULT_SINK_REMOVED:
         _base_logger.remove()
         _DEFAULT_SINK_REMOVED = True
 
-    if …
-
-
-
+    if level is None:
+        if KOSTYL_LOG_LEVEL not in {
+            "TRACE",
+            "DEBUG",
+            "INFO",
+            "SUCCESS",
+            "WARNING",
+            "ERROR",
+            "CRITICAL",
+        }:
+            level = "INFO"
+        else:
+            level = KOSTYL_LOG_LEVEL
 
-    if …
-
-        add_rank = dist.is_available() and dist.is_initialized()
-    except Exception:
-        add_rank = False
-
-    if add_rank:
-        rank = dist.get_rank()
-        channel = f"rank:{rank} - {base}"
+    if name is None:
+        channel = _caller_filename()
     else:
-        channel = …
+        channel = name
 
     if fmt in _PRESETS:
         fmt = _PRESETS[fmt]
@@ -146,7 +171,7 @@ def setup_logger(
         filter=lambda r: r["extra"].get("logger_id") == logger_id,
     )
     logger = _base_logger.bind(logger_id=logger_id, channel=channel)
-    return cast(CustomLogger, logger)
+    return cast(KostylLogger, logger)
 
 
 def log_incompatible_keys(
@@ -154,22 +179,12 @@ def log_incompatible_keys(
     incompatible_keys: _IncompatibleKeys
     | tuple[list[str], list[str]]
    | dict[str, list[str]],
-    model_specific_msg: str = "",
+    postfix_msg: str = "",
 ) -> None:
     """
     Logs warnings for incompatible keys encountered during model loading or state dict operations.
 
     Note: If incompatible_keys is of an unsupported type, an error message is logged and the function returns early.
-
-    Args:
-        logger (Logger): The logger instance used to output warning messages.
-        incompatible_keys (_IncompatibleKeys | tuple[list[str], list[str]] | dict[str, list[str]]): An object containing lists of missing and unexpected keys.
-        model_specific_msg (str, optional): A custom message to append to the log output, typically
-            indicating the model or context. Defaults to an empty string.
-
-    Returns:
-        None
-
     """
     incompatible_keys_: dict[str, list[str]] = {}
     match incompatible_keys:
@@ -192,5 +207,5 @@ def log_incompatible_keys(
         return
 
     for name, keys in incompatible_keys_.items():
-        logger.warning(f"{name} {model_specific_msg}: {', '.join(keys)}")
+        logger.warning(f"{name} {postfix_msg}: {', '.join(keys)}")
     return
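
For context, the reworked setup_logger() API drops the old add_rank flag, types the returned logger as KostylLogger, and falls back to the KOSTYL_LOG_LEVEL environment variable when no level is passed. Below is a minimal usage sketch based only on what this diff shows; the channel name "train", the "DEBUG" level, and the log messages are illustrative, not part of the package:

    import os

    # KOSTYL_LOG_LEVEL is read when kostyl.utils.logging is imported, so set it first.
    os.environ.setdefault("KOSTYL_LOG_LEVEL", "DEBUG")

    from kostyl.utils.logging import setup_logger

    # name and level are optional; fmt now defaults to "only_message".
    logger = setup_logger(name="train", fmt="default")
    logger.info("starting run")

    # warning_once() deduplicates by (message, level), so this emits a single record.
    for _ in range(3):
        logger.warning_once("gradient clipping is disabled")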
{kostyl_toolkit-0.1.35.dist-info → kostyl_toolkit-0.1.37.dist-info}/RECORD
CHANGED
@@ -6,32 +6,33 @@ kostyl/ml/clearml/logging_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWa
 kostyl/ml/clearml/pulling_utils.py,sha256=jMlVXcYRumwWnPlELRlgEdfq5L6Wir_EcfTmOoWBLTA,4077
 kostyl/ml/configs/__init__.py,sha256=IetcivbqYGutowLqxdKp7QR4tkXKBr4m8t4Zkk9jHZU,911
 kostyl/ml/configs/base_model.py,sha256=Eofn14J9RsjpVx_J4rp6C19pDDCANU4hr3JtX-d0FpQ,4820
-kostyl/ml/configs/hyperparams.py,sha256=…
-kostyl/ml/configs/training_settings.py,sha256=…
+kostyl/ml/configs/hyperparams.py,sha256=lvtbvOFEoTBAJug7FR35xMQdPLgDQjRoP2fyDP-jD7E,3305
+kostyl/ml/configs/training_settings.py,sha256=wT9CHuLaKrLwonsc87Ee421EyFis_c9fqOgn9bSClm8,2747
 kostyl/ml/data_processing_utils.py,sha256=jjEjV0S0wREgZkzg27ip0LpI8cQqkwe2QwATmAqm9-g,3832
-kostyl/ml/dist_utils.py,sha256=…
+kostyl/ml/dist_utils.py,sha256=lK9_aAh9L1SvvXWzcWiBoFjczfDiKzEpcno5csImAYQ,4635
 kostyl/ml/lightning/__init__.py,sha256=R36PImjVvzBF9t_z9u6RYVnUFJJ-sNDUOdboWUojHmM,173
 kostyl/ml/lightning/callbacks/__init__.py,sha256=EnKkNwwNDZnEqKRlpY4FVrqP88ECPF6nlT2bSLUIKRk,194
-kostyl/ml/lightning/callbacks/checkpoint.py,sha256=…
+kostyl/ml/lightning/callbacks/checkpoint.py,sha256=HI17gu-GxnfXUchflWBTwly7cCYnlpKcshuR-TgD6s4,19066
 kostyl/ml/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
 kostyl/ml/lightning/extensions/__init__.py,sha256=OY6QGv1agYgqqKf1xJBrxgp_i8FunVfPzYezfaRrGXU,182
-kostyl/ml/lightning/extensions/custom_module.py,sha256=…
-kostyl/ml/lightning/extensions/pretrained_model.py,sha256=…
+kostyl/ml/lightning/extensions/custom_module.py,sha256=qYffgPwIB_ePwK_MIaRruuDxPKJZb42kg2yy996eGwY,6415
+kostyl/ml/lightning/extensions/pretrained_model.py,sha256=hVIsIUu6Iryrz6S7GQTqog9vNq8LQyjJd2aoJ5Ws6KU,5253
 kostyl/ml/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
-kostyl/ml/lightning/loggers/tb_logger.py,sha256=…
-kostyl/ml/lightning/…
+kostyl/ml/lightning/loggers/tb_logger.py,sha256=CpjlcEIT187cJXJgRYafqfzvcnwPgPaVZ0vLUflIr7k,899
+kostyl/ml/lightning/utils.py,sha256=DhLy_3JA5VyMQkB1v6xxRxDNHfisjXFYVjuIKPpO81M,1967
 kostyl/ml/metrics_formatting.py,sha256=U6vdNENZLvp2dT1L3HqFKtXrHwGKoDXN93hvamPGHjM,1341
 kostyl/ml/params_groups.py,sha256=nUyw5d06Pvy9QPiYtZzLYR87xwXqJLxbHthgQH8oSCM,3583
-kostyl/ml/registry_uploader.py,sha256=…
-kostyl/ml/schedulers/__init__.py,sha256=…
-kostyl/ml/schedulers/base.py,sha256=…
+kostyl/ml/registry_uploader.py,sha256=BbyLXvF8AL145k7g6MRkJ7gf_3Um53p3Pn5280vVD9U,4384
+kostyl/ml/schedulers/__init__.py,sha256=_EtZu8DwTCSv4-eR84kRstEZblHylVqda7WQUOXIKfw,534
+kostyl/ml/schedulers/base.py,sha256=bjmwgdZpnSqpCnHPnKC6MEiRO79cwxMJpZq-eQVNs2M,1353
 kostyl/ml/schedulers/composite.py,sha256=ee4xlMDMMtjKPkbTF2ue9GTr9DuGCGjZWf11mHbi6aE,2387
-kostyl/ml/schedulers/cosine.py,sha256=…
-kostyl/ml/schedulers/…
+kostyl/ml/schedulers/cosine.py,sha256=y8ylrgVOkVcr2-ExoqqNW--tdDX88TBYPQCOppIf2_M,8685
+kostyl/ml/schedulers/cosine_with_plateu.py,sha256=0-X6wl3HgsTiLIbISb9lOxIVWXHDEND7rILitMWtIiM,10195
+kostyl/ml/schedulers/linear.py,sha256=RnnnblRuRXP3LT03QVIHUaK2kNsiMP1AedrMoeyh3qk,5843
 kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
 kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
 kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
-kostyl/utils/logging.py,sha256=…
-kostyl_toolkit-0.1.35.dist-info/…
-kostyl_toolkit-0.1.35.dist-info/…
-kostyl_toolkit-0.1.35.dist-info/…
+kostyl/utils/logging.py,sha256=CgNFNogcK0hoZmygvBWlTcq5A3m2Pfv9eOAP_gwx0pM,6633
+kostyl_toolkit-0.1.37.dist-info/WHEEL,sha256=eycQt0QpYmJMLKpE3X9iDk8R04v2ZF0x82ogq-zP6bQ,79
+kostyl_toolkit-0.1.37.dist-info/METADATA,sha256=yHPgSAhPnm5tDQjvDIfs213-bsVX6vMfVsUbX9GboGU,4269
+kostyl_toolkit-0.1.37.dist-info/RECORD,,
kostyl/ml/lightning/training_utils.py
DELETED
@@ -1,241 +0,0 @@
-from dataclasses import dataclass
-from dataclasses import fields
-from pathlib import Path
-from typing import Literal
-from typing import cast
-
-import lightning as L
-import torch
-import torch.distributed as dist
-from clearml import OutputModel
-from clearml import Task
-from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.callbacks import EarlyStopping
-from lightning.pytorch.callbacks import LearningRateMonitor
-from lightning.pytorch.callbacks import ModelCheckpoint
-from lightning.pytorch.loggers import TensorBoardLogger
-from lightning.pytorch.strategies import DDPStrategy
-from lightning.pytorch.strategies import FSDPStrategy
-from torch.distributed import ProcessGroup
-from torch.distributed.fsdp import MixedPrecision
-from torch.nn import Module
-
-from kostyl.ml.configs import CheckpointConfig
-from kostyl.ml.configs import DDPStrategyConfig
-from kostyl.ml.configs import EarlyStoppingConfig
-from kostyl.ml.configs import FSDP1StrategyConfig
-from kostyl.ml.configs import SingleDeviceStrategyConfig
-from kostyl.ml.lightning.callbacks import setup_checkpoint_callback
-from kostyl.ml.lightning.callbacks import setup_early_stopping_callback
-from kostyl.ml.lightning.loggers import setup_tb_logger
-from kostyl.ml.registry_uploader import ClearMLRegistryUploaderCallback
-from kostyl.utils.logging import setup_logger
-
-
-TRAINING_STRATEGIES = (
-    FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
-)
-
-logger = setup_logger(add_rank=True)
-
-
-def estimate_total_steps(
-    trainer: L.Trainer, process_group: ProcessGroup | None = None
-) -> int:
-    """
-    Estimates the total number of training steps based on the
-    dataloader length, accumulation steps, and distributed world size.
-    """  # noqa: D205
-    if dist.is_initialized():
-        world_size = dist.get_world_size(process_group)
-    else:
-        world_size = 1
-
-    datamodule = trainer.datamodule  # type: ignore
-    if datamodule is None:
-        raise ValueError("Trainer must have a datamodule to estimate total steps.")
-    datamodule = cast(L.LightningDataModule, datamodule)
-
-    logger.info("Loading `train_dataloader` to estimate number of stepping batches.")
-    datamodule.setup("fit")
-
-    dataloader_len = len(datamodule.train_dataloader())
-    steps_per_epoch = dataloader_len // trainer.accumulate_grad_batches // world_size
-
-    if trainer.max_epochs is None:
-        raise ValueError("Trainer must have `max_epochs` set to estimate total steps.")
-    total_steps = steps_per_epoch * trainer.max_epochs
-
-    logger.info(
-        f"Total steps: {total_steps} (per-epoch: {steps_per_epoch})\n"
-        f"-> Dataloader len: {dataloader_len}\n"
-        f"-> Accumulate grad batches: {trainer.accumulate_grad_batches}\n"
-        f"-> Epochs: {trainer.max_epochs}\n "
-        f"-> World size: {world_size}"
-    )
-    return total_steps
-
-
-@dataclass
-class Callbacks:
-    """Dataclass to hold PyTorch Lightning callbacks."""
-
-    checkpoint: ModelCheckpoint
-    lr_monitor: LearningRateMonitor
-    early_stopping: EarlyStopping | None = None
-
-    def to_list(self) -> list[Callback]:
-        """Convert dataclass fields to a list of Callbacks. None values are omitted."""
-        callbacks: list[Callback] = [
-            getattr(self, field.name)
-            for field in fields(self)
-            if getattr(self, field.name) is not None
-        ]
-        return callbacks
-
-
-def setup_callbacks(
-    task: Task,
-    root_path: Path,
-    checkpoint_cfg: CheckpointConfig,
-    early_stopping_cfg: EarlyStoppingConfig | None,
-    output_model: OutputModel,
-    checkpoint_upload_strategy: Literal["only-best", "every-checkpoint"],
-    config_dict: dict[str, str] | None = None,
-    enable_tag_versioning: bool = False,
-) -> Callbacks:
-    """
-    Set up PyTorch Lightning callbacks for training.
-
-    Creates and configures a set of callbacks including checkpoint saving,
-    learning rate monitoring, model registry uploading, and optional early stopping.
-
-    Args:
-        task: ClearML task for organizing checkpoints by task name and ID.
-        root_path: Root directory for saving checkpoints.
-        checkpoint_cfg: Configuration for checkpoint saving behavior.
-        checkpoint_upload_strategy: Model upload strategy:
-            - `"only-best"`: Upload only the best checkpoint based on monitored metric.
-            - `"every-checkpoint"`: Upload every saved checkpoint.
-        output_model: ClearML OutputModel instance for model registry integration.
-        early_stopping_cfg: Configuration for early stopping. If None, early stopping
-            is disabled.
-        config_dict: Optional configuration dictionary to store with the model
-            in the registry.
-        enable_tag_versioning: Whether to auto-increment version tags (e.g., "v1.0")
-            on the uploaded model.
-
-    Returns:
-        Callbacks dataclass containing configured checkpoint, lr_monitor,
-        and optionally early_stopping callbacks.
-
-    """
-    lr_monitor = LearningRateMonitor(
-        logging_interval="step", log_weight_decay=True, log_momentum=False
-    )
-    model_uploader = ClearMLRegistryUploaderCallback(
-        output_model=output_model,
-        config_dict=config_dict,
-        verbose=True,
-        enable_tag_versioning=enable_tag_versioning,
-    )
-    checkpoint_callback = setup_checkpoint_callback(
-        root_path / "checkpoints" / task.name / task.id,
-        checkpoint_cfg,
-        registry_uploader_callback=model_uploader,
-        uploading_strategy=checkpoint_upload_strategy,
-    )
-    if early_stopping_cfg is not None:
-        early_stopping_callback = setup_early_stopping_callback(early_stopping_cfg)
-    else:
-        early_stopping_callback = None
-
-    callbacks = Callbacks(
-        checkpoint=checkpoint_callback,
-        lr_monitor=lr_monitor,
-        early_stopping=early_stopping_callback,
-    )
-    return callbacks
-
-
-def setup_loggers(task: Task, root_path: Path) -> list[TensorBoardLogger]:
-    """
-    Set up PyTorch Lightning loggers for training.
-
-    Args:
-        task: ClearML task used to organize log directories by task name and ID.
-        root_path: Root directory for storing TensorBoard logs.
-
-    Returns:
-        List of configured TensorBoard loggers.
-
-    """
-    loggers = [
-        setup_tb_logger(root_path / "runs" / task.name / task.id),
-    ]
-    return loggers
-
-
-def setup_strategy(
-    strategy_settings: TRAINING_STRATEGIES,
-    devices: list[int] | int,
-    auto_wrap_policy: set[type[Module]] | None = None,
-) -> Literal["auto"] | FSDPStrategy | DDPStrategy:
-    """
-    Configure and return a PyTorch Lightning training strategy.
-
-    Args:
-        strategy_settings: Strategy configuration object. Must be one of:
-            - `FSDP1StrategyConfig`: Fully Sharded Data Parallel strategy (requires 2+ devices).
-            - `DDPStrategyConfig`: Distributed Data Parallel strategy (requires 2+ devices).
-            - `SingleDeviceStrategyConfig`: Single device training (requires exactly 1 device).
-        devices: Device(s) to use for training. Either a list of device IDs or
-            a single integer representing the number of devices.
-        auto_wrap_policy: Set of module types that should be wrapped for FSDP.
-            Required when using `FSDP1StrategyConfig`, ignored otherwise.
-
-    Returns:
-        Configured strategy: `FSDPStrategy`, `DDPStrategy`, or `"auto"` for single device.
-
-    Raises:
-        ValueError: If device count doesn't match strategy requirements or
-            if `auto_wrap_policy` is missing for FSDP.
-
-    """
-    if isinstance(devices, list):
-        num_devices = len(devices)
-    else:
-        num_devices = devices
-
-    match strategy_settings:
-        case FSDP1StrategyConfig():
-            if num_devices < 2:
-                raise ValueError("FSDP strategy requires multiple devices.")
-
-            if auto_wrap_policy is None:
-                raise ValueError("auto_wrap_policy must be provided for FSDP strategy.")
-
-            mixed_precision_config = MixedPrecision(
-                param_dtype=getattr(torch, strategy_settings.param_dtype),
-                reduce_dtype=getattr(torch, strategy_settings.reduce_dtype),
-                buffer_dtype=getattr(torch, strategy_settings.buffer_dtype),
-            )
-            strategy = FSDPStrategy(
-                auto_wrap_policy=auto_wrap_policy,
-                mixed_precision=mixed_precision_config,
-            )
-        case DDPStrategyConfig():
-            if num_devices < 2:
-                raise ValueError("DDP strategy requires at least two devices.")
-            strategy = DDPStrategy(
-                find_unused_parameters=strategy_settings.find_unused_parameters
-            )
-        case SingleDeviceStrategyConfig():
-            if num_devices != 1:
-                raise ValueError("SingleDevice strategy requires exactly one device.")
-            strategy = "auto"
-        case _:
-            raise ValueError(
-                f"Unsupported strategy type: {type(strategy_settings.trainer.strategy)}"
-            )
-    return strategy
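
The deleted estimate_total_steps() helper reduces to integer arithmetic over the dataloader length, gradient-accumulation factor, world size, and epoch count. A standalone sketch of that same calculation, with no Lightning or ClearML dependencies (the function and argument names here are illustrative, not part of the package):

    def estimate_total_steps(
        dataloader_len: int,
        accumulate_grad_batches: int,
        world_size: int,
        max_epochs: int,
    ) -> int:
        # Optimizer steps per epoch, then scaled by the number of epochs,
        # mirroring the arithmetic of the removed helper.
        steps_per_epoch = dataloader_len // accumulate_grad_batches // world_size
        return steps_per_epoch * max_epochs


    # 10_000 batches, accumulation of 4, 2 ranks, 3 epochs -> 1250 * 3 = 3750 steps.
    assert estimate_total_steps(10_000, 4, 2, 3) == 3750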