kostyl-toolkit 0.1.22__tar.gz → 0.1.24__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/PKG-INFO +1 -1
  2. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/clearml/pulling_utils.py +1 -1
  3. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/callbacks/__init__.py +1 -1
  4. kostyl_toolkit-0.1.24/kostyl/ml/lightning/callbacks/checkpoint.py +327 -0
  5. kostyl_toolkit-0.1.22/kostyl/ml/lightning/callbacks/registry_uploading.py → kostyl_toolkit-0.1.24/kostyl/ml/lightning/callbacks/registry_uploader.py +55 -18
  6. kostyl_toolkit-0.1.24/kostyl/ml/lightning/extenstions/pretrained_model.py +105 -0
  7. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/pyproject.toml +1 -1
  8. kostyl_toolkit-0.1.22/kostyl/ml/lightning/callbacks/checkpoint.py +0 -56
  9. kostyl_toolkit-0.1.22/kostyl/ml/lightning/extenstions/pretrained_model.py +0 -202
  10. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/README.md +0 -0
  11. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/__init__.py +0 -0
  12. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/__init__.py +0 -0
  13. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/clearml/__init__.py +0 -0
  14. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/clearml/dataset_utils.py +0 -0
  15. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/clearml/logging_utils.py +0 -0
  16. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/configs/__init__.py +0 -0
  17. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/configs/base_model.py +0 -0
  18. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/configs/hyperparams.py +0 -0
  19. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/configs/training_settings.py +0 -0
  20. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/dist_utils.py +0 -0
  21. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/__init__.py +0 -0
  22. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
  23. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/extenstions/__init__.py +0 -0
  24. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/extenstions/custom_module.py +0 -0
  25. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/loggers/__init__.py +0 -0
  26. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/loggers/tb_logger.py +0 -0
  27. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/steps_estimation.py +0 -0
  28. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/metrics_formatting.py +0 -0
  29. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/params_groups.py +0 -0
  30. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/schedulers/__init__.py +0 -0
  31. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/schedulers/base.py +0 -0
  32. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/schedulers/composite.py +0 -0
  33. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/ml/schedulers/cosine.py +0 -0
  34. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/utils/__init__.py +0 -0
  35. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/utils/dict_manipulations.py +0 -0
  36. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/utils/fs.py +0 -0
  37. {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.24}/kostyl/utils/logging.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: kostyl-toolkit
- Version: 0.1.22
+ Version: 0.1.24
  Summary: Kickass Orchestration System for Training, Yielding & Logging
  Requires-Dist: case-converter>=1.2.0
  Requires-Dist: loguru>=0.7.3
@@ -95,7 +95,7 @@ def get_model_from_clearml[
  raise ValueError(
  f"Model class {model.__name__} is not compatible with Lightning checkpoints."
  )
- model_instance = model.from_lighting_checkpoint(local_path, **kwargs)
+ model_instance = model.from_lightning_checkpoint(local_path, **kwargs)
  else:
  raise ValueError(
  f"Unsupported model format for path: {local_path}. "
@@ -1,6 +1,6 @@
  from .checkpoint import setup_checkpoint_callback
  from .early_stopping import setup_early_stopping_callback
- from .registry_uploading import ClearMLRegistryUploaderCallback
+ from .registry_uploader import ClearMLRegistryUploaderCallback


  __all__ = [
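
The module rename from registry_uploading to registry_uploader is absorbed by this re-export, so the package-level import path appears to stay stable (assuming __all__ continues to list the callback, as the truncated hunk suggests); only direct imports of the old module path would break:

# Still works after the rename (assumption based on the re-export above).
from kostyl.ml.lightning.callbacks import ClearMLRegistryUploaderCallback

# A direct import of the old module path would now raise ModuleNotFoundError:
# from kostyl.ml.lightning.callbacks.registry_uploading import ClearMLRegistryUploaderCallback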
@@ -0,0 +1,327 @@
+ from datetime import timedelta
+ from pathlib import Path
+ from shutil import rmtree
+ from typing import Literal
+ from typing import cast
+
+ from lightning.fabric.utilities.types import _PATH
+ from lightning.pytorch.callbacks import ModelCheckpoint
+
+ from kostyl.ml.configs import CheckpointConfig
+ from kostyl.ml.dist_utils import is_main_process
+ from kostyl.utils import setup_logger
+
+ from .registry_uploader import RegistryUploaderCallback
+
+
+ logger = setup_logger("callbacks/checkpoint.py")
+
+
+ class CustomModelCheckpoint(ModelCheckpoint):
+ r"""
+ Save the model after every epoch by monitoring a quantity. Every logged metrics are passed to the
+ :class:`~lightning.pytorch.loggers.logger.Logger` for the version it gets saved in the same directory as the
+ checkpoint.
+
+ After training finishes, use :attr:`best_model_path` to retrieve the path to the
+ best checkpoint file and :attr:`best_model_score` to get its score.
+
+ .. note::
+ When using manual optimization with ``every_n_train_steps``, you should save the model state
+ in your ``training_step`` before the optimizer step if you want the checkpoint to reflect
+ the pre-optimization state. Example:
+
+ .. code-block:: python
+
+ def training_step(self, batch, batch_idx):
+ # ... forward pass, loss calculation, backward pass ...
+
+ # Save model state before optimization
+ if not hasattr(self, 'saved_models'):
+ self.saved_models = {}
+ self.saved_models[batch_idx] = {
+ k: v.detach().clone()
+ for k, v in self.layer.state_dict().items()
+ }
+
+ # Then perform optimization
+ optimizer.zero_grad()
+ self.manual_backward(loss)
+ optimizer.step()
+
+ # Optional: Clean up old states to save memory
+ if batch_idx > 10: # Keep last 10 states
+ del self.saved_models[batch_idx - 10]
+
+ Args:
+ dirpath: Directory to save the model file.
+ Example: ``dirpath='my/path/'``.
+
+ .. warning::
+ In a distributed environment like DDP, it's recommended to provide a `dirpath` to avoid race conditions.
+ When using manual optimization with ``every_n_train_steps``, make sure to save the model state
+ in your training loop as shown in the example above.
+
+ Can be remote file paths such as `s3://mybucket/path/` or 'hdfs://path/'
+ (default: ``None``). If dirpath is ``None``, we only keep the ``k`` best checkpoints
+ in memory, and do not save anything to disk.
+
+ filename: Checkpoint filename. Can contain named formatting options to be auto-filled.
+ If no name is provided, it will be ``None`` and the checkpoint will be saved to
+ ``{epoch}``.and if the Trainer uses a logger, the path will also contain logger name and version.
+
+ filename: checkpoint filename. Can contain named formatting options to be auto-filled.
+
+ Example::
+
+ # save any arbitrary metrics like `val_loss`, etc. in name
+ # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
+ >>> checkpoint_callback = ModelCheckpoint(
+ ... dirpath='my/path',
+ ... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
+ ... )
+
+ By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``, where "epoch" and "step" match
+ the number of finished epoch and optimizer steps respectively.
+ monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch.
+ verbose: verbosity mode. Default: ``False``.
+ save_last: When ``True``, saves a `last.ckpt` copy whenever a checkpoint file gets saved. Can be set to
+ ``'link'`` on a local filesystem to create a symbolic link. This allows accessing the latest checkpoint
+ in a deterministic manner. Default: ``None``.
+ save_top_k: if ``save_top_k == k``,
+ the best k models according to the quantity monitored will be saved.
+ If ``save_top_k == 0``, no models are saved.
+ If ``save_top_k == -1``, all models are saved.
+ Please note that the monitors are checked every ``every_n_epochs`` epochs.
+ If ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, and the filename remains
+ unchanged, the name of the saved file will be appended with a version count starting with ``v1`` to avoid
+ collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the top-k
+ ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid
+ collisions.
+ save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``False``.
+ mode: one of {min, max}.
+ If ``save_top_k != 0``, the decision to overwrite the current save file is made
+ based on either the maximization or the minimization of the monitored quantity.
+ For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
+ auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name.
+ For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}`` with epoch ``1`` and acc ``1.12`` will resolve
+ to ``checkpoint_epoch=01-acc=01.ckpt``. Is useful to set it to ``False`` when metric names contain ``/``
+ as this will result in extra folders.
+ For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False``
+ save_weights_only: if ``True``, then only the model's weights will be
+ saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
+ every_n_train_steps: How many training steps to wait before saving a checkpoint. This does not take into account
+ the steps of the current epoch. If ``every_n_train_steps == None or every_n_train_steps == 0``,
+ no checkpoints
+ will be saved during training. Mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
+
+ .. note::
+ When using with manual optimization, the checkpoint will be saved after the optimizer step by default.
+ To save the model state before the optimizer step, you need to save the model state in your
+ ``training_step`` before calling ``optimizer.step()``. See the class docstring for an example.
+ train_time_interval: Checkpoints are monitored at the specified time interval.
+ For all practical purposes, this cannot be smaller than the amount
+ of time it takes to process a single training batch. This is not
+ guaranteed to execute at the exact time specified, but should be close.
+ This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
+ every_n_epochs: Number of epochs between checkpoints.
+ This value must be ``None`` or non-negative.
+ To disable saving top-k checkpoints, set ``every_n_epochs = 0``.
+ This argument does not impact the saving of ``save_last=True`` checkpoints.
+ If all of ``every_n_epochs``, ``every_n_train_steps`` and
+ ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch
+ (equivalent to ``every_n_epochs = 1``).
+ If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or ``train_time_interval != None``,
+ saving at the end of each epoch is disabled
+ (equivalent to ``every_n_epochs = 0``).
+ This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
+ Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and
+ ``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
+ will only save checkpoints at epochs 0 < E <= N
+ where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
+ save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
+ If ``True``, checkpoints are saved at the end of every training epoch.
+ If ``False``, checkpoints are saved at the end of validation.
+ If ``None`` (default), checkpointing behavior is determined based on training configuration.
+ If ``val_check_interval`` is a str, dict, or `timedelta` (time-based), checkpointing is performed after
+ validation.
+ If ``check_val_every_n_epoch != 1``, checkpointing will not be performed at the end of
+ every training epoch. If there are no validation batches of data, checkpointing will occur at the
+ end of the training epoch. If there is a non-default number of validation runs per training epoch
+ (``val_check_interval != 1``), checkpointing is performed after validation.
+ enable_version_counter: Whether to append a version to the existing file name.
+ If ``False``, then the checkpoint files will be overwritten.
+
+ Note:
+ For extra customization, ModelCheckpoint includes the following attributes:
+
+ - ``CHECKPOINT_JOIN_CHAR = "-"``
+ - ``CHECKPOINT_EQUALS_CHAR = "="``
+ - ``CHECKPOINT_NAME_LAST = "last"``
+ - ``FILE_EXTENSION = ".ckpt"``
+ - ``STARTING_VERSION = 1``
+
+ For example, you can change the default last checkpoint name by doing
+ ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
+
+ If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
+ then you should create multiple ``ModelCheckpoint`` callbacks.
+
+ If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
+ only ``best_model_path`` will be reloaded and a warning will be issued.
+
+ If you provide a ``filename`` on a mounted device where changing permissions is not allowed (causing ``chmod``
+ to raise a ``PermissionError``), install `fsspec>=2025.5.0`. Then the error is caught, the file's permissions
+ remain unchanged, and the checkpoint is still saved. Otherwise, no checkpoint will be saved and training stops.
+
+ Raises:
+ MisconfigurationException:
+ If ``save_top_k`` is smaller than ``-1``,
+ if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or
+ if ``mode`` is none of ``"min"`` or ``"max"``.
+ ValueError:
+ If ``trainer.save_checkpoint`` is ``None``.
+
+ Example::
+
+ >>> from lightning.pytorch import Trainer
+ >>> from lightning.pytorch.callbacks import ModelCheckpoint
+
+ # saves checkpoints to 'my/path/' at every epoch
+ >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
+ >>> trainer = Trainer(callbacks=[checkpoint_callback])
+
+ # save epoch and val_loss in name
+ # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
+ >>> checkpoint_callback = ModelCheckpoint(
+ ... monitor='val_loss',
+ ... dirpath='my/path/',
+ ... filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
+ ... )
+
+ # save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
+ # or Neptune, due to the presence of characters like '=' or '/')
+ # saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
+ >>> checkpoint_callback = ModelCheckpoint(
+ ... monitor='val/loss',
+ ... dirpath='my/path/',
+ ... filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
+ ... auto_insert_metric_name=False
+ ... )
+
+ # retrieve the best checkpoint after training
+ >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
+ >>> trainer = Trainer(callbacks=[checkpoint_callback])
+ >>> model = ... # doctest: +SKIP
+ >>> trainer.fit(model) # doctest: +SKIP
+ >>> print(checkpoint_callback.best_model_path) # doctest: +SKIP
+
+ .. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the
+ following arguments:
+
+ *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval*
+
+ Read more: :ref:`Persisting Callback State <extensions/callbacks_state:save callback state>`
+
+ """ # noqa: D205
+
+ def __init__( # noqa: D107
+ self,
+ dirpath: _PATH | None = None,
+ filename: str | None = None,
+ monitor: str | None = None,
+ verbose: bool = False,
+ save_last: bool | Literal["link"] | None = None,
+ save_top_k: int = 1,
+ save_on_exception: bool = False,
+ save_weights_only: bool = False,
+ mode: str = "min",
+ auto_insert_metric_name: bool = True,
+ every_n_train_steps: int | None = None,
+ train_time_interval: timedelta | None = None,
+ every_n_epochs: int | None = None,
+ save_on_train_epoch_end: bool | None = None,
+ enable_version_counter: bool = True,
+ registry_uploader_callback: RegistryUploaderCallback | None = None,
+ ) -> None:
+ super().__init__(
+ dirpath=dirpath,
+ filename=filename,
+ monitor=monitor,
+ verbose=verbose,
+ save_last=save_last,
+ save_top_k=save_top_k,
+ save_on_exception=save_on_exception,
+ save_weights_only=save_weights_only,
+ mode=mode,
+ auto_insert_metric_name=auto_insert_metric_name,
+ every_n_train_steps=every_n_train_steps,
+ train_time_interval=train_time_interval,
+ every_n_epochs=every_n_epochs,
+ save_on_train_epoch_end=save_on_train_epoch_end,
+ enable_version_counter=enable_version_counter,
+ )
+ self.registry_uploader_callback = registry_uploader_callback
+ self._custom_best_model_path = cast(str, self.best_model_path)
+ return
+
+ @property
+ def best_model_path(self) -> str:
+ """Best model path."""
+ return self._custom_best_model_path
+
+ @best_model_path.setter
+ def best_model_path(self, value: str) -> None:
+ self._custom_best_model_path = value
+ if self.registry_uploader_callback is not None:
+ self.registry_uploader_callback.best_model_path = value
+ return
+
+
+ def setup_checkpoint_callback(
+ dirpath: Path,
+ ckpt_cfg: CheckpointConfig,
+ save_weights_only: bool = True,
+ registry_uploader_callback: RegistryUploaderCallback | None = None,
+ ) -> CustomModelCheckpoint:
+ """
+ Sets up a ModelCheckpoint callback for PyTorch Lightning.
+
+ This function prepares a checkpoint directory and configures a ModelCheckpoint
+ callback based on the provided configuration. If the directory already exists,
+ it is removed (only by the main process) to ensure a clean start. Otherwise,
+ the directory is created.
+
+ Args:
+ dirpath (Path): The path to the directory where checkpoints will be saved.
+ ckpt_cfg (CheckpointConfig): Configuration object containing checkpoint
+ settings such as filename, save_top_k, monitor, and mode.
+ save_weights_only (bool, optional): Whether to save only the model weights
+ or the full model. Defaults to True.
+ registry_uploader_callback (RegistryUploaderCallback | None, optional):
+ An optional callback for uploading checkpoints to a registry. Defaults to None.
+
+ Returns:
+ ModelCheckpoint: The configured ModelCheckpoint callback instance.
+
+ """
+ if dirpath.exists():
+ if is_main_process():
+ logger.warning(f"Checkpoint directory {dirpath} already exists.")
+ rmtree(dirpath)
+ logger.warning(f"Removed existing checkpoint directory {dirpath}.")
+ else:
+ logger.info(f"Creating checkpoint directory {dirpath}.")
+ dirpath.mkdir(parents=True, exist_ok=True)
+
+ checkpoint_callback = CustomModelCheckpoint(
+ dirpath=dirpath,
+ filename=ckpt_cfg.filename,
+ save_top_k=ckpt_cfg.save_top_k,
+ monitor=ckpt_cfg.monitor,
+ mode=ckpt_cfg.mode,
+ verbose=True,
+ save_weights_only=save_weights_only,
+ registry_uploader_callback=registry_uploader_callback,
+ )
+ return checkpoint_callback
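
The new CustomModelCheckpoint inverts the old coupling: instead of the uploader polling a ModelCheckpoint for its best path, the checkpoint callback now pushes every best_model_path update into the uploader through the property setter. A minimal wiring sketch under the assumption that a ClearML Task, a CheckpointConfig, a LightningModule, and a DataModule already exist (all other names here are hypothetical, and the uploader constructor accepts further options not shown):

from pathlib import Path

from lightning import Trainer

from kostyl.ml.lightning.callbacks import ClearMLRegistryUploaderCallback
from kostyl.ml.lightning.callbacks import setup_checkpoint_callback

# `task`, `ckpt_cfg`, `module`, and `datamodule` are assumed to exist already.
uploader = ClearMLRegistryUploaderCallback(
    task=task,
    output_model_name="my-model",  # hypothetical registry name
)

# setup_checkpoint_callback hands the uploader to CustomModelCheckpoint, whose
# best_model_path setter forwards each new best path to the uploader.
checkpoint_cb = setup_checkpoint_callback(
    dirpath=Path("checkpoints/run-001"),  # hypothetical directory
    ckpt_cfg=ckpt_cfg,
    registry_uploader_callback=uploader,
)

# The uploader is itself a Lightning Callback, so it is presumably registered
# with the Trainer as well so its own hooks still fire.
trainer = Trainer(callbacks=[checkpoint_cb, uploader])
trainer.fit(module, datamodule=datamodule)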
@@ -1,3 +1,7 @@
+ from abc import ABC
+ from abc import abstractmethod
+ from collections.abc import Callable
+ from functools import partial
  from typing import Literal
  from typing import override

@@ -5,7 +9,6 @@ from clearml import OutputModel
  from clearml import Task
  from lightning import Trainer
  from lightning.pytorch.callbacks import Callback
- from lightning.pytorch.callbacks import ModelCheckpoint

  from kostyl.ml.clearml.logging_utils import find_version_in_tags
  from kostyl.ml.clearml.logging_utils import increment_version
@@ -16,13 +19,31 @@ from kostyl.utils.logging import setup_logger
  logger = setup_logger()


- class ClearMLRegistryUploaderCallback(Callback):
+ class RegistryUploaderCallback(Callback, ABC):
+ """Abstract Lightning callback responsible for tracking and uploading the best-performing model checkpoint."""
+
+ @property
+ @abstractmethod
+ def best_model_path(self) -> str:
+ """Return the file system path pointing to the best model artifact produced during training."""
+ raise NotImplementedError
+
+ @best_model_path.setter
+ @abstractmethod
+ def best_model_path(self, value: str) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
+ raise NotImplementedError
+
+
+ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
  """PyTorch Lightning callback to upload the best model checkpoint to ClearML."""

  def __init__(
  self,
  task: Task,
- ckpt_callback: ModelCheckpoint,
  output_model_name: str,
  output_model_tags: list[str] | None = None,
  verbose: bool = True,
@@ -56,7 +77,6 @@ class ClearMLRegistryUploaderCallback(Callback):
  output_model_tags = []

  self.task = task
- self.ckpt_callback = ckpt_callback
  self.output_model_name = output_model_name
  self.output_model_tags = output_model_tags
  self.config_dict = config_dict
@@ -66,7 +86,22 @@ class ClearMLRegistryUploaderCallback(Callback):
  self.enable_tag_versioning = enable_tag_versioning

  self._output_model: OutputModel | None = None
- self._last_best_model_path: str = ""
+ self._last_uploaded_model_path: str = ""
+ self._best_model_path: str = ""
+ self._upload_callback: Callable | None = None
+ return
+
+ @property
+ @override
+ def best_model_path(self) -> str:
+ return self._best_model_path
+
+ @best_model_path.setter
+ @override
+ def best_model_path(self, value: str) -> None:
+ self._best_model_path = value
+ if self._upload_callback is not None:
+ self._upload_callback()
  return

  def _create_output_model(self, pl_module: "KostylLightningModule") -> OutputModel:
@@ -98,27 +133,29 @@ class ClearMLRegistryUploaderCallback(Callback):
  label_enumeration=self.label_enumeration,
  )

+ @override
  def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
- current_best = self.ckpt_callback.best_model_path
-
- if not current_best:
- if self.verbose:
- logger.info("No best model found yet to upload")
- return
-
- if current_best == self._last_best_model_path:
- if self.verbose:
- logger.info("Best model unchanged since last upload")
+ if not self._best_model_path or (
+ self._best_model_path == self._last_uploaded_model_path
+ ):
+ if not self._best_model_path:
+ if self.verbose:
+ logger.info("No best model found yet to upload")
+ elif self._best_model_path == self._last_uploaded_model_path:
+ if self.verbose:
+ logger.info("Best model unchanged since last upload")
+ self._upload_callback = partial(self._upload_best_checkpoint, pl_module)
  return
+ self._upload_callback = None

  if self._output_model is None:
  self._output_model = self._create_output_model(pl_module)

  if self.verbose:
- logger.info(f"Uploading best model from {current_best}")
+ logger.info(f"Uploading best model from {self._best_model_path}")

  self._output_model.update_weights(
- current_best,
+ self._best_model_path,
  auto_delete_file=False,
  async_enable=False,
  )
@@ -130,7 +167,7 @@ class ClearMLRegistryUploaderCallback(Callback):
  config = self.config_dict
  self._output_model.update_design(config_dict=config)

- self._last_best_model_path = current_best
+ self._last_uploaded_model_path = self._best_model_path
  return

  @override
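
The new RegistryUploaderCallback ABC is what CustomModelCheckpoint is typed against, so any registry backend that exposes the best_model_path property pair and _upload_best_checkpoint can be swapped in for the ClearML implementation. A rough sketch of a custom subclass under that assumption (the local-directory backend and its names are hypothetical):

import shutil
from typing import override

from kostyl.ml.lightning.callbacks.registry_uploader import RegistryUploaderCallback


class LocalDirUploaderCallback(RegistryUploaderCallback):
    """Hypothetical uploader that just copies the best checkpoint into a local directory."""

    def __init__(self, target_dir: str) -> None:
        super().__init__()
        self.target_dir = target_dir
        self._best_model_path = ""

    @property
    @override
    def best_model_path(self) -> str:
        return self._best_model_path

    @best_model_path.setter
    @override
    def best_model_path(self, value: str) -> None:
        # CustomModelCheckpoint pushes every new best path through this setter.
        self._best_model_path = value

    @override
    def _upload_best_checkpoint(self, pl_module) -> None:
        # "Upload" here is simply a copy; a real backend would publish the file.
        if self._best_model_path:
            shutil.copy(self._best_model_path, self.target_dir)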
@@ -0,0 +1,105 @@
+ from pathlib import Path
+ from typing import Any
+ from typing import cast
+
+ import torch
+ from transformers import PretrainedConfig
+ from transformers import PreTrainedModel
+
+ from kostyl.utils.logging import setup_logger
+
+
+ logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
+
+
+ class LightningCheckpointLoaderMixin(PreTrainedModel):
+ """A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""
+
+ @classmethod
+ def from_lightning_checkpoint[TModelInstance: LightningCheckpointLoaderMixin]( # noqa: C901
+ cls: type[TModelInstance],
+ checkpoint_path: str | Path,
+ config_key: str = "config",
+ weights_prefix: str = "model.",
+ **kwargs: Any,
+ ) -> TModelInstance:
+ """
+ Load a model from a Lightning checkpoint file.
+
+ This class method loads a pretrained model from a PyTorch Lightning checkpoint file (.ckpt).
+ It extracts the model configuration from the checkpoint, instantiates the model, and loads
+ the state dictionary, handling any incompatible keys.
+
+ Note:
+ The method uses `torch.load` with `weights_only=False` and `mmap=True` for loading.
+ Incompatible keys (missing, unexpected, mismatched) are collected and optionally logged.
+
+ Args:
+ cls (type["LightningPretrainedModelMixin"]): The class of the model to instantiate.
+ checkpoint_path (str | Path): Path to the checkpoint file. Must be a .ckpt file.
+ config_key (str, optional): Key in the checkpoint dictionary where the config is stored.
+ Defaults to "config".
+ weights_prefix (str, optional): Prefix to strip from state dict keys. Defaults to "model.".
+ If not empty and doesn't end with ".", a "." is appended.
+ kwargs: Additional keyword arguments to pass to the model's `from_pretrained` method.
+
+ Returns:
+ TModelInstance: The loaded model instance.
+
+ Raises:
+ ValueError: If checkpoint_path is a directory, not a .ckpt file, or invalid.
+ FileNotFoundError: If the checkpoint file does not exist.
+
+ """
+ if isinstance(checkpoint_path, str):
+ checkpoint_path = Path(checkpoint_path)
+
+ if checkpoint_path.is_dir():
+ raise ValueError(f"{checkpoint_path} is a directory")
+ if not checkpoint_path.exists():
+ raise FileNotFoundError(f"{checkpoint_path} does not exist")
+ if checkpoint_path.suffix != ".ckpt":
+ raise ValueError(f"{checkpoint_path} is not a .ckpt file")
+
+ checkpoint_dict = torch.load(
+ checkpoint_path,
+ map_location="cpu",
+ weights_only=False,
+ mmap=True,
+ )
+
+ # 1. Restore the config
+ config_cls = cast(type[PretrainedConfig], cls.config_class)
+ config_dict = checkpoint_dict[config_key]
+ config_dict.update(kwargs)
+ config = config_cls.from_dict(config_dict)
+
+ kwargs_for_model: dict[str, Any] = {}
+ for key, value in kwargs.items():
+ if not hasattr(config, key):
+ kwargs_for_model[key] = value
+
+ raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
+
+ if weights_prefix:
+ if not weights_prefix.endswith("."):
+ weights_prefix = weights_prefix + "."
+ state_dict: dict[str, torch.Tensor] = {}
+
+ for key, value in raw_state_dict.items():
+ if key.startswith(weights_prefix):
+ new_key = key[len(weights_prefix) :]
+ state_dict[new_key] = value
+ else:
+ state_dict[key] = value
+ else:
+ state_dict = raw_state_dict
+
+ model = cls.from_pretrained(
+ pretrained_model_name_or_path=None,
+ config=config,
+ state_dict=state_dict,
+ **kwargs_for_model,
+ )
+
+ return model
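
Compared with the removed implementation further down, the rewritten loader hands the remapped state_dict straight to Hugging Face's from_pretrained instead of materializing the model on the meta device and reconciling keys by hand; PEFT adapter restoration and incompatible-key logging appear to be dropped in the process. A minimal loading sketch, assuming a hypothetical MyEncoder that defines config_class and inherits the mixin (paths and class names below are illustrative only):

import torch.nn as nn
from transformers import PretrainedConfig

from kostyl.ml.lightning.extenstions.pretrained_model import LightningCheckpointLoaderMixin


class MyEncoderConfig(PretrainedConfig):  # hypothetical config
    model_type = "my-encoder"

    def __init__(self, hidden_size: int = 16, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


class MyEncoder(LightningCheckpointLoaderMixin):  # hypothetical model
    config_class = MyEncoderConfig

    def __init__(self, config: MyEncoderConfig):
        super().__init__(config)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)


# Reads the HF config from checkpoint_dict["config"] and strips the "model."
# prefix that a LightningModule typically adds to state_dict keys, then defers
# to PreTrainedModel.from_pretrained(config=..., state_dict=...).
model = MyEncoder.from_lightning_checkpoint(
    "checkpoints/run-001/best.ckpt",  # hypothetical path; must be a .ckpt file
    weights_prefix="model.",
)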
@@ -1,6 +1,6 @@
  [project]
  name = "kostyl-toolkit"
- version = "0.1.22"
+ version = "0.1.24"
  description = "Kickass Orchestration System for Training, Yielding & Logging "
  readme = "README.md"
  requires-python = ">=3.12"
@@ -1,56 +0,0 @@
- from pathlib import Path
- from shutil import rmtree
-
- from lightning.pytorch.callbacks import ModelCheckpoint
-
- from kostyl.ml.configs import CheckpointConfig
- from kostyl.ml.dist_utils import is_main_process
- from kostyl.utils import setup_logger
-
-
- logger = setup_logger("callbacks/checkpoint.py")
-
-
- def setup_checkpoint_callback(
- dirpath: Path,
- ckpt_cfg: CheckpointConfig,
- save_weights_only: bool = True,
- ) -> ModelCheckpoint:
- """
- Sets up a ModelCheckpoint callback for PyTorch Lightning.
-
- This function prepares a checkpoint directory and configures a ModelCheckpoint
- callback based on the provided configuration. If the directory already exists,
- it is removed (only by the main process) to ensure a clean start. Otherwise,
- the directory is created.
-
- Args:
- dirpath (Path): The path to the directory where checkpoints will be saved.
- ckpt_cfg (CheckpointConfig): Configuration object containing checkpoint
- settings such as filename, save_top_k, monitor, and mode.
- save_weights_only (bool, optional): Whether to save only the model weights
- or the full model. Defaults to True.
-
- Returns:
- ModelCheckpoint: The configured ModelCheckpoint callback instance.
-
- """
- if dirpath.exists():
- if is_main_process():
- logger.warning(f"Checkpoint directory {dirpath} already exists.")
- rmtree(dirpath)
- logger.warning(f"Removed existing checkpoint directory {dirpath}.")
- else:
- logger.info(f"Creating checkpoint directory {dirpath}.")
- dirpath.mkdir(parents=True, exist_ok=True)
-
- checkpoint_callback = ModelCheckpoint(
- dirpath=dirpath,
- filename=ckpt_cfg.filename,
- save_top_k=ckpt_cfg.save_top_k,
- monitor=ckpt_cfg.monitor,
- mode=ckpt_cfg.mode,
- verbose=True,
- save_weights_only=save_weights_only,
- )
- return checkpoint_callback
@@ -1,202 +0,0 @@
- from pathlib import Path
- from typing import Any
- from typing import cast
-
- import torch
- from transformers import PretrainedConfig
- from transformers import PreTrainedModel
-
-
- try:
- from peft import PeftConfig
- except ImportError:
- PeftConfig = None # ty: ignore
-
- from kostyl.utils.logging import log_incompatible_keys
- from kostyl.utils.logging import setup_logger
- from torch import nn
-
-
- logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
-
-
- class LightningCheckpointLoaderMixin(PreTrainedModel):
- """A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""
-
- @classmethod
- def from_lighting_checkpoint[TModelInstance: LightningCheckpointLoaderMixin]( # noqa: C901
- cls: type[TModelInstance],
- checkpoint_path: str | Path,
- config_key: str = "config",
- weights_prefix: str = "model.",
- should_log_incompatible_keys: bool = True,
- **kwargs: Any,
- ) -> TModelInstance:
- """
- Load a model from a Lightning checkpoint file.
-
- This class method loads a pretrained model from a PyTorch Lightning checkpoint file (.ckpt).
- It extracts the model configuration from the checkpoint, instantiates the model, and loads
- the state dictionary, handling any incompatible keys.
-
- Note:
- The method uses `torch.load` with `weights_only=False` and `mmap=True` for loading.
- Incompatible keys (missing, unexpected, mismatched) are collected and optionally logged.
-
- Args:
- cls (type["LightningPretrainedModelMixin"]): The class of the model to instantiate.
- checkpoint_path (str | Path): Path to the checkpoint file. Must be a .ckpt file.
- config_key (str, optional): Key in the checkpoint dictionary where the config is stored.
- Defaults to "config".
- weights_prefix (str, optional): Prefix to strip from state dict keys. Defaults to "model.".
- If not empty and doesn't end with ".", a "." is appended.
- should_log_incompatible_keys (bool, optional): Whether to log incompatible keys. Defaults to True.
- **kwargs: Additional keyword arguments to pass to the model loading method.
-
- Returns:
- TModelInstance: The loaded model instance.
-
- Raises:
- ValueError: If checkpoint_path is a directory, not a .ckpt file, or invalid.
- FileNotFoundError: If the checkpoint file does not exist.
-
- """
- if isinstance(checkpoint_path, str):
- checkpoint_path = Path(checkpoint_path)
-
- if checkpoint_path.is_dir():
- raise ValueError(f"{checkpoint_path} is a directory")
- if not checkpoint_path.exists():
- raise FileNotFoundError(f"{checkpoint_path} does not exist")
- if checkpoint_path.suffix != ".ckpt":
- raise ValueError(f"{checkpoint_path} is not a .ckpt file")
-
- checkpoint_dict = torch.load(
- checkpoint_path,
- map_location="cpu",
- weights_only=False,
- mmap=True,
- )
-
- # 1. Restore the config
- config_cls = cast(type[PretrainedConfig], cls.config_class)
- config_dict = checkpoint_dict[config_key]
- config_dict.update(kwargs)
- config = config_cls.from_dict(config_dict)
-
- kwargs_for_model: dict[str, Any] = {}
- for key, value in kwargs.items():
- if not hasattr(config, key):
- kwargs_for_model[key] = value
-
- with torch.device("meta"):
- model = cls(config, **kwargs_for_model)
-
- # PEFT adapters (keeping your logic as is)
- if "peft_config" in checkpoint_dict:
- if PeftConfig is None:
- raise ImportError(
- "peft is not installed. Please install it to load PEFT models."
- )
- for name, adapter_dict in checkpoint_dict["peft_config"].items():
- peft_cfg = PeftConfig.from_peft_type(**adapter_dict)
- model.add_adapter(peft_cfg, adapter_name=name)
-
- incompatible_keys: dict[str, list[str]] = {}
-
- raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
-
- if weights_prefix:
- if not weights_prefix.endswith("."):
- weights_prefix = weights_prefix + "."
- state_dict: dict[str, torch.Tensor] = {}
- mismatched_keys: list[str] = []
-
- for key, value in raw_state_dict.items():
- if key.startswith(weights_prefix):
- new_key = key[len(weights_prefix) :]
- state_dict[new_key] = value
- else:
- mismatched_keys.append(key)
-
- if mismatched_keys:
- incompatible_keys["mismatched_keys"] = mismatched_keys
- else:
- state_dict = raw_state_dict
-
- # 5. base_model_prefix logic as in HF:
- # supports loading a base model <-> a model with a head
- #
- # cls.base_model_prefix is usually "model" / "bert" / "encoder", etc.
- base_prefix: str = getattr(cls, "base_model_prefix", "") or ""
- model_to_load: nn.Module = model
-
- if base_prefix:
- prefix_with_dot = base_prefix + "."
- loaded_keys = list(state_dict.keys())
- full_model_state = model.state_dict()
- expected_keys = list(full_model_state.keys())
-
- has_prefix_module = any(k.startswith(prefix_with_dot) for k in loaded_keys)
- expects_prefix_module = any(
- k.startswith(prefix_with_dot) for k in expected_keys
- )
-
- # Case 1: loading a base model into a model with a head.
- # Example: StaticEmbeddingsForSequenceClassification (has .model)
- # state_dict with keys "embeddings.weight", "token_pos_weights", ...
- if (
- hasattr(model, base_prefix)
- and not has_prefix_module
- and expects_prefix_module
- ):
- # Weights without the prefix -> load only into model.<base_prefix>
- model_to_load = getattr(model, base_prefix)
-
- # Case 2: loading a checkpoint of a model with a head into a base model.
- # Example: BertModel, while the state_dict has keys "bert.encoder.layer.0..."
- elif (
- not hasattr(model, base_prefix)
- and has_prefix_module
- and not expects_prefix_module
- ):
- new_state_dict: dict[str, torch.Tensor] = {}
- for key, value in state_dict.items():
- if key.startswith(prefix_with_dot):
- new_key = key[len(prefix_with_dot) :]
- else:
- new_key = key
- new_state_dict[new_key] = value
- state_dict = new_state_dict
-
- load_result = model_to_load.load_state_dict(
- state_dict, strict=False, assign=True
- )
- missing_keys, unexpected_keys = (
- load_result.missing_keys,
- load_result.unexpected_keys,
- )
-
- # If we only loaded into the base submodule, extend missing_keys
- # to the full list (base + head), as in older HF versions.
- if model_to_load is not model and base_prefix:
- base_keys = set(model_to_load.state_dict().keys())
- # Normalize the full model's keys to the "prefix-free" form
- head_like_keys = set()
- prefix_with_dot = base_prefix + "."
- for k in model.state_dict().keys():
- if k.startswith(prefix_with_dot):
- # strip "model."
- head_like_keys.add(k[len(prefix_with_dot) :])
- else:
- head_like_keys.add(k)
- extra_missing = sorted(head_like_keys - base_keys)
- missing_keys = list(missing_keys) + extra_missing
-
- incompatible_keys["missing_keys"] = missing_keys
- incompatible_keys["unexpected_keys"] = unexpected_keys
-
- if should_log_incompatible_keys:
- log_incompatible_keys(incompatible_keys=incompatible_keys, logger=logger)
-
- return model