kostyl-toolkit 0.1.36__py3-none-any.whl → 0.1.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. kostyl/ml/base_uploader.py +17 -0
  2. kostyl/ml/configs/__init__.py +2 -2
  3. kostyl/ml/configs/mixins.py +50 -0
  4. kostyl/ml/{data_processing_utils.py → data_collator.py} +6 -3
  5. kostyl/ml/dist_utils.py +53 -33
  6. kostyl/ml/integrations/clearml/__init__.py +7 -0
  7. kostyl/ml/{registry_uploader.py → integrations/clearml/checkpoint_uploader.py} +3 -13
  8. kostyl/ml/{configs/base_model.py → integrations/clearml/config_mixin.py} +7 -63
  9. kostyl/ml/{clearml/pulling_utils.py → integrations/clearml/loading_utils.py} +32 -5
  10. kostyl/ml/integrations/lightning/__init__.py +14 -0
  11. kostyl/ml/{lightning → integrations/lightning}/callbacks/checkpoint.py +27 -42
  12. kostyl/ml/{lightning → integrations/lightning}/loggers/tb_logger.py +2 -2
  13. kostyl/ml/{lightning/extensions/pretrained_model.py → integrations/lightning/mixins.py} +6 -4
  14. kostyl/ml/{lightning/extensions/custom_module.py → integrations/lightning/module.py} +2 -38
  15. kostyl/ml/{lightning → integrations/lightning}/utils.py +1 -1
  16. kostyl/ml/schedulers/__init__.py +4 -4
  17. kostyl/ml/schedulers/{cosine_with_plateu.py → plateau.py} +59 -36
  18. kostyl/utils/logging.py +67 -52
  19. {kostyl_toolkit-0.1.36.dist-info → kostyl_toolkit-0.1.38.dist-info}/METADATA +1 -1
  20. kostyl_toolkit-0.1.38.dist-info/RECORD +40 -0
  21. {kostyl_toolkit-0.1.36.dist-info → kostyl_toolkit-0.1.38.dist-info}/WHEEL +2 -2
  22. kostyl/ml/lightning/__init__.py +0 -5
  23. kostyl/ml/lightning/extensions/__init__.py +0 -5
  24. kostyl_toolkit-0.1.36.dist-info/RECORD +0 -38
  25. /kostyl/ml/{clearml → integrations}/__init__.py +0 -0
  26. /kostyl/ml/{clearml → integrations/clearml}/dataset_utils.py +0 -0
  27. /kostyl/ml/{clearml/logging_utils.py → integrations/clearml/version_utils.py} +0 -0
  28. /kostyl/ml/{lightning → integrations/lightning}/callbacks/__init__.py +0 -0
  29. /kostyl/ml/{lightning → integrations/lightning}/callbacks/early_stopping.py +0 -0
  30. /kostyl/ml/{lightning → integrations/lightning}/loggers/__init__.py +0 -0
  31. /kostyl/ml/{metrics_formatting.py → integrations/lightning/metrics_formatting.py} +0 -0
kostyl/ml/{lightning/extensions/pretrained_model.py → integrations/lightning/mixins.py} RENAMED
@@ -12,12 +12,12 @@ from kostyl.utils.logging import setup_logger
 logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")


-class LightningCheckpointLoaderMixin(PreTrainedModel):
+class LightningCheckpointLoaderMixin:
     """A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""

     @classmethod
-    def from_lightning_checkpoint[TModelInstance: LightningCheckpointLoaderMixin](  # noqa: C901
-        cls: type[TModelInstance],
+    def from_lightning_checkpoint[TModelInstance: PreTrainedModel](  # noqa: C901
+        cls: type[TModelInstance],  # pyright: ignore[reportGeneralTypeIssues]
         checkpoint_path: str | Path,
         config_key: str = "config",
         weights_prefix: str | None = "model.",
@@ -78,7 +78,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
             mmap=True,
         )

-        # 1. Восстанавливаем конфиг
+        # Load config
         config_cls = cast(type[PretrainedConfig], cls.config_class)
         config_dict = checkpoint_dict[config_key]
         config_dict.update(kwargs)
@@ -91,6 +91,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):

         raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]

+        # Handle weights prefix
         if weights_prefix:
             if not weights_prefix.endswith("."):
                 weights_prefix = weights_prefix + "."
@@ -117,6 +118,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
         else:
             state_dict = raw_state_dict

+        # Instantiate model and load state dict
         model = cls.from_pretrained(
             pretrained_model_name_or_path=None,
             config=config,
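For orientation (not part of the diff): a minimal sketch of how the relocated mixin could be combined with a Hugging Face model now that it no longer inherits from PreTrainedModel. MyConfig, MyModel, and the checkpoint path are hypothetical; only checkpoint_path, config_key, and weights_prefix come from the signature shown above, and the import path simply follows the new package layout.

```python
# Hypothetical usage sketch; not taken from the package itself.
from transformers import PretrainedConfig, PreTrainedModel

from kostyl.ml.integrations.lightning.mixins import LightningCheckpointLoaderMixin


class MyConfig(PretrainedConfig):  # hypothetical config class
    model_type = "my-model"


class MyModel(LightningCheckpointLoaderMixin, PreTrainedModel):  # hypothetical model class
    config_class = MyConfig


# Rebuild the model from a Lightning checkpoint: the config is read from the
# checkpoint's "config" entry and the "model." prefix is stripped from the weights.
model = MyModel.from_lightning_checkpoint(
    "checkpoints/last.ckpt",  # hypothetical path
    config_key="config",
    weights_prefix="model.",
)
```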
kostyl/ml/{lightning/extensions/custom_module.py → integrations/lightning/module.py} RENAMED
@@ -5,17 +5,15 @@ from typing import override

 import lightning as L
 import torch
-import torch.distributed as dist
 from lightning.pytorch.strategies import FSDPStrategy
 from torch import nn
-from torch.distributed import ProcessGroup
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torchmetrics import Metric
 from torchmetrics import MetricCollection
 from transformers import PretrainedConfig
 from transformers import PreTrainedModel

-from kostyl.ml.metrics_formatting import apply_suffix
+from kostyl.ml.integrations.lightning.metrics_formatting import apply_suffix
 from kostyl.ml.schedulers.base import BaseScheduler
 from kostyl.utils import setup_logger

@@ -26,37 +24,6 @@ module_logger = setup_logger(fmt="only_message")
 class KostylLightningModule(L.LightningModule):
     """Custom PyTorch Lightning Module with logging, checkpointing, and distributed training utilities."""

-    @property
-    def process_group(self) -> ProcessGroup | None:
-        """Returns the data parallel process group for distributed training."""
-        return self.get_process_group()
-
-    def get_process_group(self) -> ProcessGroup | None:
-        """
-        Retrieves the data parallel process group for distributed training.
-
-        This method checks if distributed processing is initialized. If a device mesh is provided,
-        it extracts the data parallel mesh and returns its process group, unless the mesh size is 1,
-        in which case it logs a warning and returns None. If no device mesh is provided, it returns
-        the world process group.
-
-        Returns:
-            ProcessGroup | None: The data parallel process group if available and valid, otherwise None.
-
-        """
-        if not dist.is_initialized():
-            return None
-
-        if self.device_mesh is not None:
-            dp_mesh = self.device_mesh["data_parallel"]
-            if dp_mesh.size() == 1:
-                module_logger.warning("Data parallel mesh size is 1, returning None")
-                return None
-            dp_pg = dp_mesh.get_group()
-        else:
-            dp_pg = dist.group.WORLD
-        return dp_pg
-
     @property
     def model_instance(self) -> PreTrainedModel | nn.Module:
         """Returns the underlying model."""
@@ -65,10 +32,7 @@ class KostylLightningModule(L.LightningModule):
     @property
     def model_config(self) -> PretrainedConfig | None:
         """Returns the model configuration if available."""
-        model = self.model_instance
-        if hasattr(model, "config"):
-            return model.config  # type: ignore
-        return None
+        raise NotImplementedError

     @property
     def grad_clip_val(self) -> float | None:
kostyl/ml/dist_utils.py CHANGED
@@ -14,7 +14,7 @@ TRAINING_STRATEGIES = (
     FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
 )

-logger = setup_logger(add_rank=True)
+logger = setup_logger()


 def estimate_total_steps(
kostyl/ml/schedulers/__init__.py CHANGED
@@ -1,18 +1,18 @@
 from .composite import CompositeScheduler
 from .cosine import CosineParamScheduler
 from .cosine import CosineScheduler
-from .cosine_with_plateu import CosineWithPlateauParamScheduler
-from .cosine_with_plateu import CosineWithPlateuScheduler
 from .linear import LinearParamScheduler
 from .linear import LinearScheduler
+from .plateau import PlateauWithAnnealingParamScheduler
+from .plateau import PlateauWithAnnealingScheduler


 __all__ = [
     "CompositeScheduler",
     "CosineParamScheduler",
     "CosineScheduler",
-    "CosineWithPlateauParamScheduler",
-    "CosineWithPlateuScheduler",
     "LinearParamScheduler",
     "LinearScheduler",
+    "PlateauWithAnnealingParamScheduler",
+    "PlateauWithAnnealingScheduler",
 ]
kostyl/ml/schedulers/{cosine_with_plateu.py → plateau.py} RENAMED
@@ -1,4 +1,5 @@
 from typing import Any
+from typing import Literal
 from typing import override

 import numpy as np
@@ -7,20 +8,25 @@ import torch
 from .base import BaseScheduler


-class _CosineWithPlateauSchedulerCore(BaseScheduler):
-    """Core cosine with plateau scheduler logic."""
+class _PlateauWithAnnealingCore(BaseScheduler):
+    """Core annealing with plateau scheduler logic."""

     def __init__(
         self,
         param_name: str,
         num_iters: int,
-        base_value: float,
+        plateau_value: float,
         final_value: float,
         plateau_ratio: float,
         warmup_value: float | None = None,
         warmup_ratio: float | None = None,
         freeze_ratio: float | None = None,
+        annealing_type: Literal["cosine", "linear"] = "cosine",
     ) -> None:
+        if annealing_type not in ("cosine", "linear"):
+            raise ValueError(
+                f"Annealing type must be 'cosine' or 'linear', got {annealing_type}."
+            )
         if warmup_ratio is not None:
             if not (0 < warmup_ratio < 1):
                 raise ValueError(f"Warmup ratio must be in (0, 1), got {warmup_ratio}.")
@@ -47,16 +53,17 @@ class _CosineWithPlateauSchedulerCore(BaseScheduler):

         self.param_name = param_name
         self.num_iters = num_iters
-        self.base_value = base_value
+        self.plateau_value = plateau_value
         self.final_value = final_value
-        self.cosine_annealing_ratio = 1 - pre_annealing_ratio
+        self.annealing_ratio = 1 - pre_annealing_ratio
         self.plateau_ratio = plateau_ratio
         self.warmup_ratio = warmup_ratio
         self.warmup_value = warmup_value
         self.freeze_ratio = freeze_ratio
+        self.annealing_type = annealing_type

         self.scheduled_values: np.ndarray = np.array([], dtype=np.float64)
-        self.current_value_ = self.base_value
+        self.current_value_ = self.plateau_value
         return

     def _create_scheduler(self) -> None:
@@ -72,28 +79,41 @@ class _CosineWithPlateauSchedulerCore(BaseScheduler):
         if self.warmup_ratio is not None and self.warmup_value is not None:
             warmup_iters = int(self.num_iters * self.warmup_ratio)
             warmup_schedule = np.linspace(
-                self.warmup_value, self.base_value, warmup_iters, dtype=np.float64
+                self.warmup_value, self.plateau_value, warmup_iters, dtype=np.float64
             )
         else:
             warmup_iters = 0
             warmup_schedule = np.array([], dtype=np.float64)

-        # Create cosine annealing schedule
-        if self.cosine_annealing_ratio > 0:
-            cosine_annealing_iters = int(self.num_iters * self.cosine_annealing_ratio)
-            iters = np.arange(cosine_annealing_iters)
-            cosine_annealing_schedule = self.final_value + 0.5 * (
-                self.base_value - self.final_value
-            ) * (1 + np.cos(np.pi * iters / len(iters)))
+        # Create annealing schedule
+        if self.annealing_ratio > 0:
+            annealing_iters = int(self.num_iters * self.annealing_ratio)
+            match self.annealing_type:
+                case "cosine":
+                    iters = np.arange(annealing_iters)
+                    annealing_schedule = self.final_value + 0.5 * (
+                        self.plateau_value - self.final_value
+                    ) * (1 + np.cos(np.pi * iters / len(iters)))
+                case "linear":
+                    annealing_schedule = np.linspace(
+                        self.plateau_value,
+                        self.final_value,
+                        annealing_iters,
+                        dtype=np.float64,
+                    )
+                case _:
+                    raise ValueError(
+                        f"Unsupported annealing type: {self.annealing_type}"
+                    )
         else:
-            cosine_annealing_iters = 0
-            cosine_annealing_schedule = np.array([], dtype=np.float64)
+            annealing_iters = 0
+            annealing_schedule = np.array([], dtype=np.float64)

-        plateau_iters = (
-            self.num_iters - warmup_iters - freeze_iters - cosine_annealing_iters
-        )
+        plateau_iters = self.num_iters - warmup_iters - freeze_iters - annealing_iters
         if plateau_iters > 0:
-            plateau_schedule = np.full(plateau_iters, self.base_value, dtype=np.float64)
+            plateau_schedule = np.full(
+                plateau_iters, self.plateau_value, dtype=np.float64
+            )
         else:
             plateau_schedule = np.array([], dtype=np.float64)

@@ -103,7 +123,7 @@ class _CosineWithPlateauSchedulerCore(BaseScheduler):
                 freeze_schedule,
                 warmup_schedule,
                 plateau_schedule,
-                cosine_annealing_schedule,
+                annealing_schedule,
             )
         )
         self._verify()
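For illustration (not package code): a standalone sketch of the phase arithmetic implied by _create_scheduler above, assuming pre_annealing_ratio is the sum of the freeze, warmup, and plateau ratios (that computation sits outside this diff). The concrete numbers are arbitrary.

```python
# Standalone sketch of the freeze -> warmup -> plateau -> annealing split.
import numpy as np

num_iters = 1000
freeze_ratio, warmup_ratio, plateau_ratio = 0.1, 0.1, 0.3
warmup_value, plateau_value, final_value = 1e-6, 1e-3, 1e-5

# Assumption: pre_annealing_ratio = freeze_ratio + warmup_ratio + plateau_ratio.
annealing_ratio = 1 - (freeze_ratio + warmup_ratio + plateau_ratio)        # 0.5
freeze_iters = int(num_iters * freeze_ratio)                               # 100
warmup_iters = int(num_iters * warmup_ratio)                               # 100
annealing_iters = int(num_iters * annealing_ratio)                         # 500
plateau_iters = num_iters - warmup_iters - freeze_iters - annealing_iters  # 300

t = np.arange(annealing_iters)
cosine = final_value + 0.5 * (plateau_value - final_value) * (1 + np.cos(np.pi * t / annealing_iters))

schedule = np.concatenate([
    np.zeros(freeze_iters),                                  # freeze at 0
    np.linspace(warmup_value, plateau_value, warmup_iters),  # linear warmup
    np.full(plateau_iters, plateau_value),                   # hold plateau_value
    cosine,                                                  # anneal down to final_value
])
assert len(schedule) == num_iters
```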
@@ -137,12 +157,12 @@ class _CosineWithPlateauSchedulerCore(BaseScheduler):
         return {self.param_name: self.current_value_}


-class CosineWithPlateuScheduler(_CosineWithPlateauSchedulerCore):
+class PlateauWithAnnealingScheduler(_PlateauWithAnnealingCore):
     """
-    Applies a cosine schedule with plateau to an optimizer param-group field.
+    Applies an annealing schedule with plateau to an optimizer param-group field.

-    Schedule phases: freeze (0) → warmup → plateau (base_value) → cosine annealing to final_value.
-    The plateau phase maintains the base_value before cosine annealing begins.
+    Schedule phases: freeze (0) → warmup → plateau (plateau_value) → annealing (cosine/linear) to final_value.
+    The plateau phase maintains the plateau_value before annealing begins.
     """

     def __init__(
@@ -150,30 +170,32 @@ class CosineWithPlateuScheduler(_CosineWithPlateauSchedulerCore):
         optimizer: torch.optim.Optimizer,
         param_group_field: str,
         num_iters: int,
-        base_value: float,
+        plateau_value: float,
         final_value: float,
         plateau_ratio: float,
         warmup_value: float | None = None,
         warmup_ratio: float | None = None,
         freeze_ratio: float | None = None,
+        annealing_type: Literal["cosine", "linear"] = "cosine",
         multiplier_field: str | None = None,
         skip_if_zero: bool = False,
         apply_if_field: str | None = None,
         ignore_if_field: str | None = None,
     ) -> None:
         """
-        Configure cosine scheduling for matching optimizer groups.
+        Configure annealing scheduling for matching optimizer groups.

         Args:
             optimizer: Optimizer whose param groups are updated in-place.
             param_group_field: Name of the field that receives the scheduled value.
             num_iters: Number of scheduler iterations before clamping at ``final_value``.
-            base_value: Value maintained during plateau phase and used as cosine start.
-            final_value: Value approached as iterations progress during cosine annealing.
-            plateau_ratio: Fraction of iterations to maintain ``base_value`` before cosine annealing.
-            warmup_ratio: Optional fraction of iterations to linearly ramp from ``warmup_value`` to ``base_value``.
+            plateau_value: Value maintained during plateau phase and used as annealing start.
+            final_value: Value approached as iterations progress during annealing.
+            plateau_ratio: Fraction of iterations to maintain ``plateau_value`` before annealing.
+            warmup_ratio: Optional fraction of iterations to linearly ramp from ``warmup_value`` to ``plateau_value``.
             warmup_value: Starting value for the warmup ramp.
             freeze_ratio: Optional fraction of iterations to keep the value frozen at zero at the beginning.
+            annealing_type: Type of annealing from plateau to final value ("cosine" or "linear").
             multiplier_field: Optional per-group multiplier applied to the scheduled value.
             skip_if_zero: Leave groups untouched when their target field equals zero.
             apply_if_field: Require this flag to be present in a param group before updating.
@@ -188,12 +210,13 @@ class CosineWithPlateuScheduler(_CosineWithPlateauSchedulerCore):
         super().__init__(
             param_name=param_group_field,
             num_iters=num_iters,
-            base_value=base_value,
+            plateau_value=plateau_value,
             final_value=final_value,
             plateau_ratio=plateau_ratio,
             warmup_ratio=warmup_ratio,
             warmup_value=warmup_value,
             freeze_ratio=freeze_ratio,
+            annealing_type=annealing_type,
         )
         self.param_group_field = param_group_field
         return
@@ -242,12 +265,12 @@ class CosineWithPlateuScheduler(_CosineWithPlateauSchedulerCore):
         return


-class CosineWithPlateauParamScheduler(_CosineWithPlateauSchedulerCore):
+class PlateauWithAnnealingParamScheduler(_PlateauWithAnnealingCore):
     """
-    Standalone cosine scheduler with plateau for non-optimizer parameters.
+    Standalone annealing scheduler with plateau for non-optimizer parameters.

-    Schedule phases: freeze (0) → warmup → plateau (base_value) → cosine annealing to final_value.
-    The plateau phase maintains the base_value before cosine annealing begins.
+    Schedule phases: freeze (0) → warmup → plateau (plateau_value) → annealing (cosine/linear) to final_value.
+    The plateau phase maintains the plateau_value before annealing begins.
     """

     @override
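For orientation (not part of the diff): a construction-only sketch of the renamed optimizer scheduler, mirroring the signature shown above. The model, optimizer, and values are hypothetical, and how the scheduler is advanced each iteration is defined by BaseScheduler, which this diff does not show.

```python
# Hypothetical construction sketch; values are arbitrary.
import torch

from kostyl.ml.schedulers import PlateauWithAnnealingScheduler

model = torch.nn.Linear(16, 16)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

lr_scheduler = PlateauWithAnnealingScheduler(
    optimizer=optimizer,
    param_group_field="lr",    # optimizer field to schedule
    num_iters=10_000,
    plateau_value=1e-3,        # held during the plateau phase
    final_value=1e-5,          # reached at the end of annealing
    plateau_ratio=0.4,
    warmup_value=1e-6,
    warmup_ratio=0.1,
    annealing_type="linear",   # new option; "cosine" remains the default
)
```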
kostyl/utils/logging.py CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import inspect
+import os
 import sys
 import uuid
 from collections import namedtuple
@@ -18,32 +19,18 @@ from loguru import logger as _base_logger
 if TYPE_CHECKING:
     from loguru import Logger

-    class CustomLogger(Logger):  # noqa: D101
+    class KostylLogger(Logger):  # noqa: D101
         def log_once(self, level: str, message: str, *args, **kwargs) -> None: ...  # noqa: ANN003, D102
         def warning_once(self, message: str, *args, **kwargs) -> None: ...  # noqa: ANN003, D102
 else:
-    CustomLogger = type(_base_logger)
+    KostylLogger = type(_base_logger)

 try:
-    import torch.distributed as dist
     from torch.nn.modules.module import (
         _IncompatibleKeys,  # pyright: ignore[reportAssignmentType]
     )
 except Exception:

-    class _Dummy:
-        @staticmethod
-        def is_available() -> bool:
-            return False
-
-        @staticmethod
-        def is_initialized() -> bool:
-            return False
-
-        @staticmethod
-        def get_rank() -> int:
-            return 0
-
     class _IncompatibleKeys(
         namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
     ):
@@ -56,14 +43,13 @@ except Exception:

         __str__ = __repr__

-    dist = _Dummy()
     _IncompatibleKeys = _IncompatibleKeys

 _once_lock = Lock()
 _once_keys: set[tuple[str, str]] = set()


-def _log_once(self: CustomLogger, level: str, message: str, *args, **kwargs) -> None:  # noqa: ANN003
+def _log_once(self: KostylLogger, level: str, message: str, *args, **kwargs) -> None:  # noqa: ANN003
     key = (message, level)

     with _once_lock:
@@ -75,7 +61,7 @@ def _log_once(self: CustomLogger, level: str, message: str, *args, **kwargs) ->
         return


-_base_logger = cast(CustomLogger, _base_logger)
+_base_logger = cast(KostylLogger, _base_logger)
 _base_logger.log_once = _log_once  # pyright: ignore[reportAttributeAccessIssue]
 _base_logger.warning_once = partialmethod(_log_once, "WARNING")  # pyright: ignore[reportAttributeAccessIssue]

@@ -91,44 +77,83 @@ _DEFAULT_FMT = "<level>{level: <8}</level> {time:HH:mm:ss.SSS} [{extra[channel]}
 _ONLY_MESSAGE_FMT = "<level>{message}</level>"
 _PRESETS = {"default": _DEFAULT_FMT, "only_message": _ONLY_MESSAGE_FMT}

+KOSTYL_LOG_LEVEL = os.getenv("KOSTYL_LOG_LEVEL", "INFO")
+

 def setup_logger(
     name: str | None = None,
     fmt: Literal["default", "only_message"] | str = "only_message",
-    level: str = "INFO",
-    add_rank: bool | None = None,
+    level: str | None = None,
     sink=sys.stdout,
     colorize: bool = True,
     serialize: bool = False,
-) -> CustomLogger:
+) -> KostylLogger:
     """
-    Returns a bound logger with its own sink and formatting.
+    Creates and configures a logger with custom formatting and output.
+
+    The function automatically removes the default sink on first call and creates
+    an isolated logger with a unique identifier for message filtering.
+
+    Args:
+        name (str | None, optional): Logger channel name. If None, automatically
+            uses the calling function's filename. Defaults to None.
+        fmt (Literal["default", "only_message"] | str, optional): Log message format.
+            Available presets:
+            - "default": includes level, time, and channel
+            - "only_message": outputs only the message itself
+            Custom format strings are also supported. Defaults to "only_message".
+        level (str | None, optional): Logging level (TRACE, DEBUG, INFO, SUCCESS,
+            WARNING, ERROR, CRITICAL). If None, uses the KOSTYL_LOG_LEVEL environment
+            variable or "INFO" by default. Defaults to None.
+        sink: Output object for logs (file, sys.stdout, sys.stderr, etc.).
+            Defaults to sys.stdout.
+        colorize (bool, optional): Enable colored output formatting.
+            Defaults to True.
+        serialize (bool, optional): Serialize logs to JSON format.
+            Defaults to False.
+
+    Returns:
+        CustomLogger: Configured logger instance with additional methods
+            log_once() and warning_once().
+
+    Example:
+        >>> # Basic usage with automatic name detection
+        >>> logger = setup_logger()
+        >>> logger.info("Hello World")

-    Note: If name=None, the caller's filename (similar to __file__) is used automatically.
+        >>> # With custom name and level
+        >>> logger = setup_logger(name="MyApp", level="DEBUG")
+
+        >>> # With custom format
+        >>> logger = setup_logger(
+        ...     name="API",
+        ...     fmt="{level} | {time:YYYY-MM-DD HH:mm:ss} | {message}"
+        ... )

-    Format example: "{level} {time:MM-DD HH:mm:ss} [{extra[channel]}] {message}"
     """
     global _DEFAULT_SINK_REMOVED
     if not _DEFAULT_SINK_REMOVED:
         _base_logger.remove()
         _DEFAULT_SINK_REMOVED = True

-    if name is None:
-        base = _caller_filename()
-    else:
-        base = name
+    if level is None:
+        if KOSTYL_LOG_LEVEL not in {
+            "TRACE",
+            "DEBUG",
+            "INFO",
+            "SUCCESS",
+            "WARNING",
+            "ERROR",
+            "CRITICAL",
+        }:
+            level = "INFO"
+        else:
+            level = KOSTYL_LOG_LEVEL

-    if (add_rank is None) or add_rank:
-        try:
-            add_rank = dist.is_available() and dist.is_initialized()
-        except Exception:
-            add_rank = False
-
-    if add_rank:
-        rank = dist.get_rank()
-        channel = f"rank:{rank} - {base}"
+    if name is None:
+        channel = _caller_filename()
     else:
-        channel = base
+        channel = name

     if fmt in _PRESETS:
         fmt = _PRESETS[fmt]
@@ -146,7 +171,7 @@ def setup_logger(
         filter=lambda r: r["extra"].get("logger_id") == logger_id,
     )
     logger = _base_logger.bind(logger_id=logger_id, channel=channel)
-    return cast(CustomLogger, logger)
+    return cast(KostylLogger, logger)


 def log_incompatible_keys(
@@ -154,22 +179,12 @@
     incompatible_keys: _IncompatibleKeys
     | tuple[list[str], list[str]]
     | dict[str, list[str]],
-    model_specific_msg: str = "",
+    postfix_msg: str = "",
 ) -> None:
     """
     Logs warnings for incompatible keys encountered during model loading or state dict operations.

     Note: If incompatible_keys is of an unsupported type, an error message is logged and the function returns early.
-
-    Args:
-        logger (Logger): The logger instance used to output warning messages.
-        incompatible_keys (_IncompatibleKeys | tuple[list[str], list[str]] | dict[str, list[str]]): An object containing lists of missing and unexpected keys.
-        model_specific_msg (str, optional): A custom message to append to the log output, typically
-            indicating the model or context. Defaults to an empty string.
-
-    Returns:
-        None
-
     """
     incompatible_keys_: dict[str, list[str]] = {}
     match incompatible_keys:
@@ -192,5 +207,5 @@
         return

     for name, keys in incompatible_keys_.items():
-        logger.warning(f"{name} {model_specific_msg}: {', '.join(keys)}")
+        logger.warning(f"{name} {postfix_msg}: {', '.join(keys)}")
     return
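For orientation (not part of the diff): a short sketch of the updated logging API, assuming the behavior documented above. The channel name and key lists are hypothetical, and KOSTYL_LOG_LEVEL is read once at import time, so it has to be exported before the process starts (e.g. KOSTYL_LOG_LEVEL=DEBUG python train.py).

```python
# Hypothetical usage sketch of setup_logger and log_incompatible_keys.
from kostyl.utils.logging import log_incompatible_keys, setup_logger

# level=None falls back to KOSTYL_LOG_LEVEL, or "INFO" if the variable is unset/invalid.
logger = setup_logger(name="MyTrainer", fmt="default")
logger.warning_once("Emitted only the first time this exact message is logged.")

# E.g. missing/unexpected keys returned by model.load_state_dict(..., strict=False).
missing, unexpected = ["head.weight"], ["old_head.weight"]
log_incompatible_keys(logger, (missing, unexpected), postfix_msg="while loading the backbone")
```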
{kostyl_toolkit-0.1.36.dist-info → kostyl_toolkit-0.1.38.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: kostyl-toolkit
-Version: 0.1.36
+Version: 0.1.38
 Summary: Kickass Orchestration System for Training, Yielding & Logging
 Requires-Dist: case-converter>=1.2.0
 Requires-Dist: loguru>=0.7.3
kostyl_toolkit-0.1.38.dist-info/RECORD ADDED
@@ -0,0 +1,40 @@
+kostyl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+kostyl/ml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+kostyl/ml/base_uploader.py,sha256=KxHuohCcNK18kTVFBBqDu_IOQefluhSXOzwC56O66wc,484
+kostyl/ml/configs/__init__.py,sha256=djYjLxA7riFcSibAKfWHns-BCESEPrqSz_ZY2rJO-cc,913
+kostyl/ml/configs/hyperparams.py,sha256=lvtbvOFEoTBAJug7FR35xMQdPLgDQjRoP2fyDP-jD7E,3305
+kostyl/ml/configs/mixins.py,sha256=xHHAoRoPbzP9ECFP9duzg6SzegHcoLI8Pr9NrLoWNHs,1411
+kostyl/ml/configs/training_settings.py,sha256=wT9CHuLaKrLwonsc87Ee421EyFis_c9fqOgn9bSClm8,2747
+kostyl/ml/data_collator.py,sha256=kxiaMDKwSKXGBtrF8yXxHcypf7t_6syU-NwO1LcX50k,4062
+kostyl/ml/dist_utils.py,sha256=UFNMLEHc0A5F6KvTRG8GQPpRDwG4m5dvM__UvXNc2aQ,4526
+kostyl/ml/integrations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+kostyl/ml/integrations/clearml/__init__.py,sha256=3TBVI-3fE9ZzuvOLEohW9TOK0BZTLD5JiYalAVDkocc,217
+kostyl/ml/integrations/clearml/checkpoint_uploader.py,sha256=PupFi7jKROsIddOz7X5DhV7nUNdDZg5kKaaLvzdCHlY,4012
+kostyl/ml/integrations/clearml/config_mixin.py,sha256=70QRicU7etiDzLX-MplqVX8uFm5siuPrM8KbTOriZnQ,3308
+kostyl/ml/integrations/clearml/dataset_utils.py,sha256=eij_sr2KDhm8GxEbVbK8aBjPsuVvLl9-PIGGaKVgXLA,1729
+kostyl/ml/integrations/clearml/loading_utils.py,sha256=NAMmB9NTGCXCHh-bR_nrQZyqImUVZqicNjExDyPM2mU,5224
+kostyl/ml/integrations/clearml/version_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWaJTyq8,1210
+kostyl/ml/integrations/lightning/__init__.py,sha256=r96os8kTuKIAymx3k9Td1JBrO2PH7nQAWUC54NsY5yY,392
+kostyl/ml/integrations/lightning/callbacks/__init__.py,sha256=EnKkNwwNDZnEqKRlpY4FVrqP88ECPF6nlT2bSLUIKRk,194
+kostyl/ml/integrations/lightning/callbacks/checkpoint.py,sha256=SfcaQRkXviMUej0UgrfXcqMDlRKYaAN3rgYCMKI97Os,18433
+kostyl/ml/integrations/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
+kostyl/ml/integrations/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
+kostyl/ml/integrations/lightning/loggers/tb_logger.py,sha256=CpjlcEIT187cJXJgRYafqfzvcnwPgPaVZ0vLUflIr7k,899
+kostyl/ml/integrations/lightning/metrics_formatting.py,sha256=U6vdNENZLvp2dT1L3HqFKtXrHwGKoDXN93hvamPGHjM,1341
+kostyl/ml/integrations/lightning/mixins.py,sha256=hVIsIUu6Iryrz6S7GQTqog9vNq8LQyjJd2aoJ5Ws6KU,5253
+kostyl/ml/integrations/lightning/module.py,sha256=39hcVNZSGyj5tLpXyX8IoqMGWt5vf6-Bx5JnNJ2-Wag,5218
+kostyl/ml/integrations/lightning/utils.py,sha256=DhLy_3JA5VyMQkB1v6xxRxDNHfisjXFYVjuIKPpO81M,1967
+kostyl/ml/params_groups.py,sha256=nUyw5d06Pvy9QPiYtZzLYR87xwXqJLxbHthgQH8oSCM,3583
+kostyl/ml/schedulers/__init__.py,sha256=VIo8MOP4w5Ll24XqFb3QGi2rKvys6c0dEFYPIdDoPlw,526
+kostyl/ml/schedulers/base.py,sha256=bjmwgdZpnSqpCnHPnKC6MEiRO79cwxMJpZq-eQVNs2M,1353
+kostyl/ml/schedulers/composite.py,sha256=ee4xlMDMMtjKPkbTF2ue9GTr9DuGCGjZWf11mHbi6aE,2387
+kostyl/ml/schedulers/cosine.py,sha256=y8ylrgVOkVcr2-ExoqqNW--tdDX88TBYPQCOppIf2_M,8685
+kostyl/ml/schedulers/linear.py,sha256=RnnnblRuRXP3LT03QVIHUaK2kNsiMP1AedrMoeyh3qk,5843
+kostyl/ml/schedulers/plateau.py,sha256=N-hiostPtTR0W4xnEJYB_1dv0DRx39iufLkGUrSIoWE,11235
+kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
+kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
+kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
+kostyl/utils/logging.py,sha256=CgNFNogcK0hoZmygvBWlTcq5A3m2Pfv9eOAP_gwx0pM,6633
+kostyl_toolkit-0.1.38.dist-info/WHEEL,sha256=e_m4S054HL0hyR3CpOk-b7Q7fDX6BuFkgL5OjAExXas,80
+kostyl_toolkit-0.1.38.dist-info/METADATA,sha256=nz5AzlWjKBqh7OZCklk-efWZ1jVDihw3YrrpLyoII3k,4269
+kostyl_toolkit-0.1.38.dist-info/RECORD,,
{kostyl_toolkit-0.1.36.dist-info → kostyl_toolkit-0.1.38.dist-info}/WHEEL RENAMED
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: uv 0.9.24
+Generator: uv 0.9.27
 Root-Is-Purelib: true
-Tag: py3-none-any
+Tag: py3-none-any
kostyl/ml/lightning/__init__.py DELETED
@@ -1,5 +0,0 @@
-from .extensions import KostylLightningModule
-from .extensions import LightningCheckpointLoaderMixin
-
-
-__all__ = ["KostylLightningModule", "LightningCheckpointLoaderMixin"]
kostyl/ml/lightning/extensions/__init__.py DELETED
@@ -1,5 +0,0 @@
-from .custom_module import KostylLightningModule
-from .pretrained_model import LightningCheckpointLoaderMixin
-
-
-__all__ = ["KostylLightningModule", "LightningCheckpointLoaderMixin"]