kostyl-toolkit 0.1.35-py3-none-any.whl → 0.1.37-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kostyl/utils/logging.py CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import inspect
+import os
 import sys
 import uuid
 from collections import namedtuple
@@ -18,32 +19,18 @@ from loguru import logger as _base_logger
 if TYPE_CHECKING:
     from loguru import Logger
 
-    class CustomLogger(Logger):  # noqa: D101
+    class KostylLogger(Logger):  # noqa: D101
         def log_once(self, level: str, message: str, *args, **kwargs) -> None: ...  # noqa: ANN003, D102
         def warning_once(self, message: str, *args, **kwargs) -> None: ...  # noqa: ANN003, D102
 else:
-    CustomLogger = type(_base_logger)
+    KostylLogger = type(_base_logger)
 
 try:
-    import torch.distributed as dist
     from torch.nn.modules.module import (
         _IncompatibleKeys,  # pyright: ignore[reportAssignmentType]
     )
 except Exception:
 
-    class _Dummy:
-        @staticmethod
-        def is_available() -> bool:
-            return False
-
-        @staticmethod
-        def is_initialized() -> bool:
-            return False
-
-        @staticmethod
-        def get_rank() -> int:
-            return 0
-
     class _IncompatibleKeys(
         namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
     ):
@@ -56,14 +43,13 @@ except Exception:
 
         __str__ = __repr__
 
-    dist = _Dummy()
     _IncompatibleKeys = _IncompatibleKeys
 
 _once_lock = Lock()
 _once_keys: set[tuple[str, str]] = set()
 
 
-def _log_once(self: CustomLogger, level: str, message: str, *args, **kwargs) -> None:  # noqa: ANN003
+def _log_once(self: KostylLogger, level: str, message: str, *args, **kwargs) -> None:  # noqa: ANN003
     key = (message, level)
 
     with _once_lock:
@@ -75,7 +61,7 @@ def _log_once(self: CustomLogger, level: str, message: str, *args, **kwargs) ->
     return
 
 
-_base_logger = cast(CustomLogger, _base_logger)
+_base_logger = cast(KostylLogger, _base_logger)
 _base_logger.log_once = _log_once  # pyright: ignore[reportAttributeAccessIssue]
 _base_logger.warning_once = partialmethod(_log_once, "WARNING")  # pyright: ignore[reportAttributeAccessIssue]
 
@@ -91,44 +77,83 @@ _DEFAULT_FMT = "<level>{level: <8}</level> {time:HH:mm:ss.SSS} [{extra[channel]}
 _ONLY_MESSAGE_FMT = "<level>{message}</level>"
 _PRESETS = {"default": _DEFAULT_FMT, "only_message": _ONLY_MESSAGE_FMT}
 
+KOSTYL_LOG_LEVEL = os.getenv("KOSTYL_LOG_LEVEL", "INFO")
+
 
 def setup_logger(
     name: str | None = None,
-    fmt: Literal["default", "only_message"] | str = "default",
-    level: str = "INFO",
-    add_rank: bool | None = None,
+    fmt: Literal["default", "only_message"] | str = "only_message",
+    level: str | None = None,
     sink=sys.stdout,
     colorize: bool = True,
     serialize: bool = False,
-) -> CustomLogger:
+) -> KostylLogger:
     """
-    Returns a bound logger with its own sink and formatting.
+    Creates and configures a logger with custom formatting and output.
+
+    The function automatically removes the default sink on first call and creates
+    an isolated logger with a unique identifier for message filtering.
+
+    Args:
+        name (str | None, optional): Logger channel name. If None, automatically
+            uses the calling function's filename. Defaults to None.
+        fmt (Literal["default", "only_message"] | str, optional): Log message format.
+            Available presets:
+            - "default": includes level, time, and channel
+            - "only_message": outputs only the message itself
+            Custom format strings are also supported. Defaults to "only_message".
+        level (str | None, optional): Logging level (TRACE, DEBUG, INFO, SUCCESS,
+            WARNING, ERROR, CRITICAL). If None, uses the KOSTYL_LOG_LEVEL environment
+            variable or "INFO" by default. Defaults to None.
+        sink: Output object for logs (file, sys.stdout, sys.stderr, etc.).
+            Defaults to sys.stdout.
+        colorize (bool, optional): Enable colored output formatting.
+            Defaults to True.
+        serialize (bool, optional): Serialize logs to JSON format.
+            Defaults to False.
+
+    Returns:
+        CustomLogger: Configured logger instance with additional methods
+            log_once() and warning_once().
+
+    Example:
+        >>> # Basic usage with automatic name detection
+        >>> logger = setup_logger()
+        >>> logger.info("Hello World")
 
-    Note: If name=None, the caller's filename (similar to __file__) is used automatically.
+        >>> # With custom name and level
+        >>> logger = setup_logger(name="MyApp", level="DEBUG")
+
+        >>> # With custom format
+        >>> logger = setup_logger(
+        ...     name="API",
+        ...     fmt="{level} | {time:YYYY-MM-DD HH:mm:ss} | {message}"
+        ... )
 
-    Format example: "{level} {time:MM-DD HH:mm:ss} [{extra[channel]}] {message}"
     """
     global _DEFAULT_SINK_REMOVED
     if not _DEFAULT_SINK_REMOVED:
         _base_logger.remove()
         _DEFAULT_SINK_REMOVED = True
 
-    if name is None:
-        base = _caller_filename()
-    else:
-        base = name
+    if level is None:
+        if KOSTYL_LOG_LEVEL not in {
+            "TRACE",
+            "DEBUG",
+            "INFO",
+            "SUCCESS",
+            "WARNING",
+            "ERROR",
+            "CRITICAL",
+        }:
+            level = "INFO"
+        else:
+            level = KOSTYL_LOG_LEVEL
 
-    if (add_rank is None) or add_rank:
-        try:
-            add_rank = dist.is_available() and dist.is_initialized()
-        except Exception:
-            add_rank = False
-
-    if add_rank:
-        rank = dist.get_rank()
-        channel = f"rank:{rank} - {base}"
+    if name is None:
+        channel = _caller_filename()
     else:
-        channel = base
+        channel = name
 
     if fmt in _PRESETS:
         fmt = _PRESETS[fmt]
@@ -146,7 +171,7 @@ def setup_logger(
         filter=lambda r: r["extra"].get("logger_id") == logger_id,
     )
     logger = _base_logger.bind(logger_id=logger_id, channel=channel)
-    return cast(CustomLogger, logger)
+    return cast(KostylLogger, logger)
 
 
 def log_incompatible_keys(
@@ -154,22 +179,12 @@ def log_incompatible_keys(
     incompatible_keys: _IncompatibleKeys
     | tuple[list[str], list[str]]
     | dict[str, list[str]],
-    model_specific_msg: str = "",
+    postfix_msg: str = "",
 ) -> None:
     """
     Logs warnings for incompatible keys encountered during model loading or state dict operations.
 
     Note: If incompatible_keys is of an unsupported type, an error message is logged and the function returns early.
-
-    Args:
-        logger (Logger): The logger instance used to output warning messages.
-        incompatible_keys (_IncompatibleKeys | tuple[list[str], list[str]] | dict[str, list[str]]): An object containing lists of missing and unexpected keys.
-        model_specific_msg (str, optional): A custom message to append to the log output, typically
-            indicating the model or context. Defaults to an empty string.
-
-    Returns:
-        None
-
     """
     incompatible_keys_: dict[str, list[str]] = {}
     match incompatible_keys:
@@ -192,5 +207,5 @@
             return
 
     for name, keys in incompatible_keys_.items():
-        logger.warning(f"{name} {model_specific_msg}: {', '.join(keys)}")
+        logger.warning(f"{name} {postfix_msg}: {', '.join(keys)}")
     return
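
Usage note (not part of the diff): the snippet below is a minimal sketch of the 0.1.37 surface visible above — the module-level KOSTYL_LOG_LEVEL lookup, the new setup_logger defaults, the once-only logging methods, and the renamed postfix_msg parameter of log_incompatible_keys. The logger-first argument order of log_incompatible_keys is assumed from the removed docstring, and the model/key names are hypothetical placeholders.

# Sketch only: exercises the post-0.1.37 API shown in this diff.
import os

# KOSTYL_LOG_LEVEL is read at import time by kostyl.utils.logging, so set it first.
os.environ["KOSTYL_LOG_LEVEL"] = "DEBUG"

from kostyl.utils.logging import log_incompatible_keys, setup_logger

# level=None now falls back to KOSTYL_LOG_LEVEL; fmt defaults to "only_message".
logger = setup_logger(name="train", fmt="default")

logger.warning_once("dataset cache is cold")  # emitted once per (message, level)
logger.warning_once("dataset cache is cold")  # suppressed on repeat

# postfix_msg replaces the former model_specific_msg argument.
log_incompatible_keys(
    logger,
    (["encoder.bias"], ["decoder.weight"]),  # hypothetical (missing, unexpected) keys
    postfix_msg="for MyModel",
)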
kostyl_toolkit-0.1.35.dist-info/METADATA → kostyl_toolkit-0.1.37.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: kostyl-toolkit
-Version: 0.1.35
+Version: 0.1.37
 Summary: Kickass Orchestration System for Training, Yielding & Logging
 Requires-Dist: case-converter>=1.2.0
 Requires-Dist: loguru>=0.7.3
kostyl_toolkit-0.1.35.dist-info/RECORD → kostyl_toolkit-0.1.37.dist-info/RECORD CHANGED
@@ -6,32 +6,33 @@ kostyl/ml/clearml/logging_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWa
 kostyl/ml/clearml/pulling_utils.py,sha256=jMlVXcYRumwWnPlELRlgEdfq5L6Wir_EcfTmOoWBLTA,4077
 kostyl/ml/configs/__init__.py,sha256=IetcivbqYGutowLqxdKp7QR4tkXKBr4m8t4Zkk9jHZU,911
 kostyl/ml/configs/base_model.py,sha256=Eofn14J9RsjpVx_J4rp6C19pDDCANU4hr3JtX-d0FpQ,4820
-kostyl/ml/configs/hyperparams.py,sha256=2S_VEZ07RWquNFSWjHBb3OUpBlTznbUpFSchzMpSBOc,2879
-kostyl/ml/configs/training_settings.py,sha256=Sq2tiRuwkbmi9zKDG2JghZLXo5DDt_eQqN_KYJSdcTY,2509
+kostyl/ml/configs/hyperparams.py,sha256=lvtbvOFEoTBAJug7FR35xMQdPLgDQjRoP2fyDP-jD7E,3305
+kostyl/ml/configs/training_settings.py,sha256=wT9CHuLaKrLwonsc87Ee421EyFis_c9fqOgn9bSClm8,2747
 kostyl/ml/data_processing_utils.py,sha256=jjEjV0S0wREgZkzg27ip0LpI8cQqkwe2QwATmAqm9-g,3832
-kostyl/ml/dist_utils.py,sha256=Onf0KHVLA8oeUgZTcTdmR9qiM22f2uYLoNwgLbMGJWk,3495
+kostyl/ml/dist_utils.py,sha256=lK9_aAh9L1SvvXWzcWiBoFjczfDiKzEpcno5csImAYQ,4635
 kostyl/ml/lightning/__init__.py,sha256=R36PImjVvzBF9t_z9u6RYVnUFJJ-sNDUOdboWUojHmM,173
 kostyl/ml/lightning/callbacks/__init__.py,sha256=EnKkNwwNDZnEqKRlpY4FVrqP88ECPF6nlT2bSLUIKRk,194
-kostyl/ml/lightning/callbacks/checkpoint.py,sha256=sZ9OqudO-gXp7FqtWaOH46TXVpeCJxV-EowyRPN836k,18983
+kostyl/ml/lightning/callbacks/checkpoint.py,sha256=HI17gu-GxnfXUchflWBTwly7cCYnlpKcshuR-TgD6s4,19066
 kostyl/ml/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
 kostyl/ml/lightning/extensions/__init__.py,sha256=OY6QGv1agYgqqKf1xJBrxgp_i8FunVfPzYezfaRrGXU,182
-kostyl/ml/lightning/extensions/custom_module.py,sha256=iQrnPz-WTmRfvLo94C5fQc2Qwa1IpHtUh1sCpVwTSFM,6602
-kostyl/ml/lightning/extensions/pretrained_model.py,sha256=eRfQBzAjVernHl9A4PP5uTLvjjmcNKPdTu7ABFLq7HI,5196
+kostyl/ml/lightning/extensions/custom_module.py,sha256=qYffgPwIB_ePwK_MIaRruuDxPKJZb42kg2yy996eGwY,6415
+kostyl/ml/lightning/extensions/pretrained_model.py,sha256=hVIsIUu6Iryrz6S7GQTqog9vNq8LQyjJd2aoJ5Ws6KU,5253
 kostyl/ml/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
-kostyl/ml/lightning/loggers/tb_logger.py,sha256=j02HK5ue8yzXXV8FWKmmXyHkFlIxgHx-ahHWk_rFCZs,893
-kostyl/ml/lightning/training_utils.py,sha256=u7X9ysF9Gqy8CdwacdcDlNQNsbagYAhslbv-1WLJ45k,9052
+kostyl/ml/lightning/loggers/tb_logger.py,sha256=CpjlcEIT187cJXJgRYafqfzvcnwPgPaVZ0vLUflIr7k,899
+kostyl/ml/lightning/utils.py,sha256=DhLy_3JA5VyMQkB1v6xxRxDNHfisjXFYVjuIKPpO81M,1967
 kostyl/ml/metrics_formatting.py,sha256=U6vdNENZLvp2dT1L3HqFKtXrHwGKoDXN93hvamPGHjM,1341
 kostyl/ml/params_groups.py,sha256=nUyw5d06Pvy9QPiYtZzLYR87xwXqJLxbHthgQH8oSCM,3583
-kostyl/ml/registry_uploader.py,sha256=W90TYo_WKv2oBE6nqEJl4hecYmJyyuKwQJ9_uUPGnJQ,3346
-kostyl/ml/schedulers/__init__.py,sha256=bxXbsU_WYnVbhvNNnuI7cOAh2Axz7D25TaleBTZhYfc,197
-kostyl/ml/schedulers/base.py,sha256=9M2iOoOVSRojR_liPX1qo3Nn4iMXSM5ZJuAFWZTulUk,1327
+kostyl/ml/registry_uploader.py,sha256=BbyLXvF8AL145k7g6MRkJ7gf_3Um53p3Pn5280vVD9U,4384
+kostyl/ml/schedulers/__init__.py,sha256=_EtZu8DwTCSv4-eR84kRstEZblHylVqda7WQUOXIKfw,534
+kostyl/ml/schedulers/base.py,sha256=bjmwgdZpnSqpCnHPnKC6MEiRO79cwxMJpZq-eQVNs2M,1353
 kostyl/ml/schedulers/composite.py,sha256=ee4xlMDMMtjKPkbTF2ue9GTr9DuGCGjZWf11mHbi6aE,2387
-kostyl/ml/schedulers/cosine.py,sha256=t74_ByT22L5NQKpnBVU9UGzBVx1ZM2GTylb9ct3_PVg,7627
-kostyl/ml/schedulers/linear.py,sha256=7HPkVWcPa0lbaZywutXSDdVLLSihAyWk5XIE2Dzj_5Q,5168
+kostyl/ml/schedulers/cosine.py,sha256=y8ylrgVOkVcr2-ExoqqNW--tdDX88TBYPQCOppIf2_M,8685
+kostyl/ml/schedulers/cosine_with_plateu.py,sha256=0-X6wl3HgsTiLIbISb9lOxIVWXHDEND7rILitMWtIiM,10195
+kostyl/ml/schedulers/linear.py,sha256=RnnnblRuRXP3LT03QVIHUaK2kNsiMP1AedrMoeyh3qk,5843
 kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
 kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
 kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
-kostyl/utils/logging.py,sha256=Vye0u4-yeOSUc-f03gpQbxSktTbFiilTWLEVr00ZHvc,5796
-kostyl_toolkit-0.1.35.dist-info/WHEEL,sha256=ZyFSCYkV2BrxH6-HRVRg3R9Fo7MALzer9KiPYqNxSbo,79
-kostyl_toolkit-0.1.35.dist-info/METADATA,sha256=KL4-Z421DpchI6KUZ6tVATy99urk1OP2OY4Uf5r9R3U,4269
-kostyl_toolkit-0.1.35.dist-info/RECORD,,
+kostyl/utils/logging.py,sha256=CgNFNogcK0hoZmygvBWlTcq5A3m2Pfv9eOAP_gwx0pM,6633
+kostyl_toolkit-0.1.37.dist-info/WHEEL,sha256=eycQt0QpYmJMLKpE3X9iDk8R04v2ZF0x82ogq-zP6bQ,79
+kostyl_toolkit-0.1.37.dist-info/METADATA,sha256=yHPgSAhPnm5tDQjvDIfs213-bsVX6vMfVsUbX9GboGU,4269
+kostyl_toolkit-0.1.37.dist-info/RECORD,,
kostyl_toolkit-0.1.35.dist-info/WHEEL → kostyl_toolkit-0.1.37.dist-info/WHEEL CHANGED
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: uv 0.9.18
+Generator: uv 0.9.24
 Root-Is-Purelib: true
 Tag: py3-none-any
kostyl/ml/lightning/training_utils.py DELETED
@@ -1,241 +0,0 @@
-from dataclasses import dataclass
-from dataclasses import fields
-from pathlib import Path
-from typing import Literal
-from typing import cast
-
-import lightning as L
-import torch
-import torch.distributed as dist
-from clearml import OutputModel
-from clearml import Task
-from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.callbacks import EarlyStopping
-from lightning.pytorch.callbacks import LearningRateMonitor
-from lightning.pytorch.callbacks import ModelCheckpoint
-from lightning.pytorch.loggers import TensorBoardLogger
-from lightning.pytorch.strategies import DDPStrategy
-from lightning.pytorch.strategies import FSDPStrategy
-from torch.distributed import ProcessGroup
-from torch.distributed.fsdp import MixedPrecision
-from torch.nn import Module
-
-from kostyl.ml.configs import CheckpointConfig
-from kostyl.ml.configs import DDPStrategyConfig
-from kostyl.ml.configs import EarlyStoppingConfig
-from kostyl.ml.configs import FSDP1StrategyConfig
-from kostyl.ml.configs import SingleDeviceStrategyConfig
-from kostyl.ml.lightning.callbacks import setup_checkpoint_callback
-from kostyl.ml.lightning.callbacks import setup_early_stopping_callback
-from kostyl.ml.lightning.loggers import setup_tb_logger
-from kostyl.ml.registry_uploader import ClearMLRegistryUploaderCallback
-from kostyl.utils.logging import setup_logger
-
-
-TRAINING_STRATEGIES = (
-    FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
-)
-
-logger = setup_logger(add_rank=True)
-
-
-def estimate_total_steps(
-    trainer: L.Trainer, process_group: ProcessGroup | None = None
-) -> int:
-    """
-    Estimates the total number of training steps based on the
-    dataloader length, accumulation steps, and distributed world size.
-    """  # noqa: D205
-    if dist.is_initialized():
-        world_size = dist.get_world_size(process_group)
-    else:
-        world_size = 1
-
-    datamodule = trainer.datamodule  # type: ignore
-    if datamodule is None:
-        raise ValueError("Trainer must have a datamodule to estimate total steps.")
-    datamodule = cast(L.LightningDataModule, datamodule)
-
-    logger.info("Loading `train_dataloader` to estimate number of stepping batches.")
-    datamodule.setup("fit")
-
-    dataloader_len = len(datamodule.train_dataloader())
-    steps_per_epoch = dataloader_len // trainer.accumulate_grad_batches // world_size
-
-    if trainer.max_epochs is None:
-        raise ValueError("Trainer must have `max_epochs` set to estimate total steps.")
-    total_steps = steps_per_epoch * trainer.max_epochs
-
-    logger.info(
-        f"Total steps: {total_steps} (per-epoch: {steps_per_epoch})\n"
-        f"-> Dataloader len: {dataloader_len}\n"
-        f"-> Accumulate grad batches: {trainer.accumulate_grad_batches}\n"
-        f"-> Epochs: {trainer.max_epochs}\n "
-        f"-> World size: {world_size}"
-    )
-    return total_steps
-
-
-@dataclass
-class Callbacks:
-    """Dataclass to hold PyTorch Lightning callbacks."""
-
-    checkpoint: ModelCheckpoint
-    lr_monitor: LearningRateMonitor
-    early_stopping: EarlyStopping | None = None
-
-    def to_list(self) -> list[Callback]:
-        """Convert dataclass fields to a list of Callbacks. None values are omitted."""
-        callbacks: list[Callback] = [
-            getattr(self, field.name)
-            for field in fields(self)
-            if getattr(self, field.name) is not None
-        ]
-        return callbacks
-
-
-def setup_callbacks(
-    task: Task,
-    root_path: Path,
-    checkpoint_cfg: CheckpointConfig,
-    early_stopping_cfg: EarlyStoppingConfig | None,
-    output_model: OutputModel,
-    checkpoint_upload_strategy: Literal["only-best", "every-checkpoint"],
-    config_dict: dict[str, str] | None = None,
-    enable_tag_versioning: bool = False,
-) -> Callbacks:
-    """
-    Set up PyTorch Lightning callbacks for training.
-
-    Creates and configures a set of callbacks including checkpoint saving,
-    learning rate monitoring, model registry uploading, and optional early stopping.
-
-    Args:
-        task: ClearML task for organizing checkpoints by task name and ID.
-        root_path: Root directory for saving checkpoints.
-        checkpoint_cfg: Configuration for checkpoint saving behavior.
-        checkpoint_upload_strategy: Model upload strategy:
-            - `"only-best"`: Upload only the best checkpoint based on monitored metric.
-            - `"every-checkpoint"`: Upload every saved checkpoint.
-        output_model: ClearML OutputModel instance for model registry integration.
-        early_stopping_cfg: Configuration for early stopping. If None, early stopping
-            is disabled.
-        config_dict: Optional configuration dictionary to store with the model
-            in the registry.
-        enable_tag_versioning: Whether to auto-increment version tags (e.g., "v1.0")
-            on the uploaded model.
-
-    Returns:
-        Callbacks dataclass containing configured checkpoint, lr_monitor,
-        and optionally early_stopping callbacks.
-
-    """
-    lr_monitor = LearningRateMonitor(
-        logging_interval="step", log_weight_decay=True, log_momentum=False
-    )
-    model_uploader = ClearMLRegistryUploaderCallback(
-        output_model=output_model,
-        config_dict=config_dict,
-        verbose=True,
-        enable_tag_versioning=enable_tag_versioning,
-    )
-    checkpoint_callback = setup_checkpoint_callback(
-        root_path / "checkpoints" / task.name / task.id,
-        checkpoint_cfg,
-        registry_uploader_callback=model_uploader,
-        uploading_strategy=checkpoint_upload_strategy,
-    )
-    if early_stopping_cfg is not None:
-        early_stopping_callback = setup_early_stopping_callback(early_stopping_cfg)
-    else:
-        early_stopping_callback = None
-
-    callbacks = Callbacks(
-        checkpoint=checkpoint_callback,
-        lr_monitor=lr_monitor,
-        early_stopping=early_stopping_callback,
-    )
-    return callbacks
-
-
-def setup_loggers(task: Task, root_path: Path) -> list[TensorBoardLogger]:
-    """
-    Set up PyTorch Lightning loggers for training.
-
-    Args:
-        task: ClearML task used to organize log directories by task name and ID.
-        root_path: Root directory for storing TensorBoard logs.
-
-    Returns:
-        List of configured TensorBoard loggers.
-
-    """
-    loggers = [
-        setup_tb_logger(root_path / "runs" / task.name / task.id),
-    ]
-    return loggers
-
-
-def setup_strategy(
-    strategy_settings: TRAINING_STRATEGIES,
-    devices: list[int] | int,
-    auto_wrap_policy: set[type[Module]] | None = None,
-) -> Literal["auto"] | FSDPStrategy | DDPStrategy:
-    """
-    Configure and return a PyTorch Lightning training strategy.
-
-    Args:
-        strategy_settings: Strategy configuration object. Must be one of:
-            - `FSDP1StrategyConfig`: Fully Sharded Data Parallel strategy (requires 2+ devices).
-            - `DDPStrategyConfig`: Distributed Data Parallel strategy (requires 2+ devices).
-            - `SingleDeviceStrategyConfig`: Single device training (requires exactly 1 device).
-        devices: Device(s) to use for training. Either a list of device IDs or
-            a single integer representing the number of devices.
-        auto_wrap_policy: Set of module types that should be wrapped for FSDP.
-            Required when using `FSDP1StrategyConfig`, ignored otherwise.
-
-    Returns:
-        Configured strategy: `FSDPStrategy`, `DDPStrategy`, or `"auto"` for single device.
-
-    Raises:
-        ValueError: If device count doesn't match strategy requirements or
-            if `auto_wrap_policy` is missing for FSDP.
-
-    """
-    if isinstance(devices, list):
-        num_devices = len(devices)
-    else:
-        num_devices = devices
-
-    match strategy_settings:
-        case FSDP1StrategyConfig():
-            if num_devices < 2:
-                raise ValueError("FSDP strategy requires multiple devices.")
-
-            if auto_wrap_policy is None:
-                raise ValueError("auto_wrap_policy must be provided for FSDP strategy.")
-
-            mixed_precision_config = MixedPrecision(
-                param_dtype=getattr(torch, strategy_settings.param_dtype),
-                reduce_dtype=getattr(torch, strategy_settings.reduce_dtype),
-                buffer_dtype=getattr(torch, strategy_settings.buffer_dtype),
-            )
-            strategy = FSDPStrategy(
-                auto_wrap_policy=auto_wrap_policy,
-                mixed_precision=mixed_precision_config,
-            )
-        case DDPStrategyConfig():
-            if num_devices < 2:
-                raise ValueError("DDP strategy requires at least two devices.")
-            strategy = DDPStrategy(
-                find_unused_parameters=strategy_settings.find_unused_parameters
-            )
-        case SingleDeviceStrategyConfig():
-            if num_devices != 1:
-                raise ValueError("SingleDevice strategy requires exactly one device.")
-            strategy = "auto"
-        case _:
-            raise ValueError(
-                f"Unsupported strategy type: {type(strategy_settings.trainer.strategy)}"
-            )
-    return strategy
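
Migration note (not part of the package): the deleted training_utils.py created its module logger with setup_logger(add_rank=True), a parameter this release removes along with the torch.distributed shim in kostyl/utils/logging.py. A caller that still wants rank-prefixed channels can rebuild them itself; the helper below is a hypothetical sketch that mirrors the removed rank logic, not an API of kostyl-toolkit.

# Hypothetical caller-side replacement for the removed add_rank behaviour.
import torch.distributed as dist

from kostyl.utils.logging import setup_logger


def setup_rank_aware_logger(name: str):
    """Prefix the channel with the process rank, mirroring the removed add_rank branch."""
    try:
        add_rank = dist.is_available() and dist.is_initialized()
    except Exception:
        add_rank = False

    channel = f"rank:{dist.get_rank()} - {name}" if add_rank else name
    return setup_logger(name=channel, fmt="default")


logger = setup_rank_aware_logger("training_utils")
logger.info("rank-aware channel restored on the caller side")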