kostyl-toolkit 0.1.21__tar.gz → 0.1.23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/PKG-INFO +1 -1
  2. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/callbacks/__init__.py +1 -1
  3. kostyl_toolkit-0.1.23/kostyl/ml/lightning/callbacks/checkpoint.py +327 -0
  4. kostyl_toolkit-0.1.21/kostyl/ml/lightning/callbacks/registry_uploading.py → kostyl_toolkit-0.1.23/kostyl/ml/lightning/callbacks/registry_uploader.py +78 -23
  5. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/extenstions/pretrained_model.py +93 -16
  6. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/pyproject.toml +1 -1
  7. kostyl_toolkit-0.1.21/kostyl/ml/lightning/callbacks/checkpoint.py +0 -56
  8. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/README.md +0 -0
  9. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/__init__.py +0 -0
  10. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/__init__.py +0 -0
  11. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/clearml/__init__.py +0 -0
  12. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/clearml/dataset_utils.py +0 -0
  13. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/clearml/logging_utils.py +0 -0
  14. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/clearml/pulling_utils.py +0 -0
  15. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/configs/__init__.py +0 -0
  16. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/configs/base_model.py +0 -0
  17. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/configs/hyperparams.py +0 -0
  18. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/configs/training_settings.py +0 -0
  19. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/dist_utils.py +0 -0
  20. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/__init__.py +0 -0
  21. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
  22. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/extenstions/__init__.py +0 -0
  23. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/extenstions/custom_module.py +0 -0
  24. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/loggers/__init__.py +0 -0
  25. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/loggers/tb_logger.py +0 -0
  26. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/steps_estimation.py +0 -0
  27. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/metrics_formatting.py +0 -0
  28. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/params_groups.py +0 -0
  29. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/schedulers/__init__.py +0 -0
  30. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/schedulers/base.py +0 -0
  31. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/schedulers/composite.py +0 -0
  32. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/schedulers/cosine.py +0 -0
  33. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/utils/__init__.py +0 -0
  34. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/utils/dict_manipulations.py +0 -0
  35. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/utils/fs.py +0 -0
  36. {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/utils/logging.py +0 -0
{kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: kostyl-toolkit
- Version: 0.1.21
+ Version: 0.1.23
  Summary: Kickass Orchestration System for Training, Yielding & Logging
  Requires-Dist: case-converter>=1.2.0
  Requires-Dist: loguru>=0.7.3
{kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/callbacks/__init__.py
@@ -1,6 +1,6 @@
  from .checkpoint import setup_checkpoint_callback
  from .early_stopping import setup_early_stopping_callback
- from .registry_uploading import ClearMLRegistryUploaderCallback
+ from .registry_uploader import ClearMLRegistryUploaderCallback


  __all__ = [
kostyl_toolkit-0.1.23/kostyl/ml/lightning/callbacks/checkpoint.py
@@ -0,0 +1,327 @@
+ from datetime import timedelta
+ from pathlib import Path
+ from shutil import rmtree
+ from typing import Literal
+ from typing import cast
+
+ from lightning.fabric.utilities.types import _PATH
+ from lightning.pytorch.callbacks import ModelCheckpoint
+
+ from kostyl.ml.configs import CheckpointConfig
+ from kostyl.ml.dist_utils import is_main_process
+ from kostyl.utils import setup_logger
+
+ from .registry_uploader import RegistryUploaderCallback
+
+
+ logger = setup_logger("callbacks/checkpoint.py")
+
+
+ class CustomModelCheckpoint(ModelCheckpoint):
+     r"""
+     Save the model after every epoch by monitoring a quantity. Every logged metric is passed to the
+     :class:`~lightning.pytorch.loggers.logger.Logger`, and for each version the checkpoint gets saved in the same
+     directory.
+
+     After training finishes, use :attr:`best_model_path` to retrieve the path to the
+     best checkpoint file and :attr:`best_model_score` to get its score.
+
+     .. note::
+         When using manual optimization with ``every_n_train_steps``, you should save the model state
+         in your ``training_step`` before the optimizer step if you want the checkpoint to reflect
+         the pre-optimization state. Example:
+
+         .. code-block:: python
+
+             def training_step(self, batch, batch_idx):
+                 # ... forward pass, loss calculation, backward pass ...
+
+                 # Save model state before optimization
+                 if not hasattr(self, 'saved_models'):
+                     self.saved_models = {}
+                 self.saved_models[batch_idx] = {
+                     k: v.detach().clone()
+                     for k, v in self.layer.state_dict().items()
+                 }
+
+                 # Then perform optimization
+                 optimizer.zero_grad()
+                 self.manual_backward(loss)
+                 optimizer.step()
+
+                 # Optional: Clean up old states to save memory
+                 if batch_idx > 10:  # Keep last 10 states
+                     del self.saved_models[batch_idx - 10]
+
+     Args:
+         dirpath: Directory to save the model file.
+             Example: ``dirpath='my/path/'``.
+
+             .. warning::
+                 In a distributed environment like DDP, it's recommended to provide a `dirpath` to avoid race conditions.
+                 When using manual optimization with ``every_n_train_steps``, make sure to save the model state
+                 in your training loop as shown in the example above.
+
+             Can be remote file paths such as `s3://mybucket/path/` or 'hdfs://path/'
+             (default: ``None``). If dirpath is ``None``, we only keep the ``k`` best checkpoints
+             in memory, and do not save anything to disk.
+
+         filename: checkpoint filename. Can contain named formatting options to be auto-filled.
+             If no name is provided, it will be ``None`` and the checkpoint will be saved to
+             ``{epoch}``, and if the Trainer uses a logger, the path will also contain the logger name and version.
+
+             Example::
+
+                 # save any arbitrary metrics like `val_loss`, etc. in name
+                 # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
+                 >>> checkpoint_callback = ModelCheckpoint(
+                 ...     dirpath='my/path',
+                 ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
+                 ... )
+
+             By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``, where "epoch" and "step" match
+             the number of finished epochs and optimizer steps respectively.
+         monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch.
+         verbose: verbosity mode. Default: ``False``.
+         save_last: When ``True``, saves a `last.ckpt` copy whenever a checkpoint file gets saved. Can be set to
+             ``'link'`` on a local filesystem to create a symbolic link. This allows accessing the latest checkpoint
+             in a deterministic manner. Default: ``None``.
+         save_top_k: if ``save_top_k == k``,
+             the best k models according to the quantity monitored will be saved.
+             If ``save_top_k == 0``, no models are saved.
+             If ``save_top_k == -1``, all models are saved.
+             Please note that the monitors are checked every ``every_n_epochs`` epochs.
+             If ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, and the filename remains
+             unchanged, the name of the saved file will be appended with a version count starting with ``v1`` to avoid
+             collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the top-k
+             ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid
+             collisions.
+         save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``False``.
+         mode: one of {min, max}.
+             If ``save_top_k != 0``, the decision to overwrite the current save file is made
+             based on either the maximization or the minimization of the monitored quantity.
+             For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
+         auto_insert_metric_name: When ``True``, the checkpoint filenames will contain the metric name.
+             For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}`` with epoch ``1`` and acc ``1.12`` will resolve
+             to ``checkpoint_epoch=01-acc=01.ckpt``. It is useful to set it to ``False`` when metric names contain ``/``
+             as this will result in extra folders.
+             For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False``
+         save_weights_only: if ``True``, then only the model's weights will be
+             saved. Otherwise, the optimizer states, lr-scheduler states, etc. are added to the checkpoint too.
+         every_n_train_steps: How many training steps to wait before saving a checkpoint. This does not take into account
+             the steps of the current epoch. If ``every_n_train_steps == None or every_n_train_steps == 0``, no checkpoints
+             will be saved during training. Mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
+
+             .. note::
+                 When used with manual optimization, the checkpoint will be saved after the optimizer step by default.
+                 To save the model state before the optimizer step, you need to save the model state in your
+                 ``training_step`` before calling ``optimizer.step()``. See the class docstring for an example.
+         train_time_interval: Checkpoints are monitored at the specified time interval.
+             For all practical purposes, this cannot be smaller than the amount
+             of time it takes to process a single training batch. This is not
+             guaranteed to execute at the exact time specified, but should be close.
+             This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
+         every_n_epochs: Number of epochs between checkpoints.
+             This value must be ``None`` or non-negative.
+             To disable saving top-k checkpoints, set ``every_n_epochs = 0``.
+             This argument does not impact the saving of ``save_last=True`` checkpoints.
+             If all of ``every_n_epochs``, ``every_n_train_steps`` and
+             ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch
+             (equivalent to ``every_n_epochs = 1``).
+             If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or ``train_time_interval != None``,
+             saving at the end of each epoch is disabled
+             (equivalent to ``every_n_epochs = 0``).
+             This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
+             Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and
+             ``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
+             will only save checkpoints at epochs 0 < E <= N
+             where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
+         save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
+             If ``True``, checkpoints are saved at the end of every training epoch.
+             If ``False``, checkpoints are saved at the end of validation.
+             If ``None`` (default), checkpointing behavior is determined based on the training configuration.
+             If ``val_check_interval`` is a str, dict, or `timedelta` (time-based), checkpointing is performed after
+             validation.
+             If ``check_val_every_n_epoch != 1``, checkpointing will not be performed at the end of
+             every training epoch. If there are no validation batches of data, checkpointing will occur at the
+             end of the training epoch. If there is a non-default number of validation runs per training epoch
+             (``val_check_interval != 1``), checkpointing is performed after validation.
+         enable_version_counter: Whether to append a version to the existing file name.
+             If ``False``, then the checkpoint files will be overwritten.
+
+     Note:
+         For extra customization, ModelCheckpoint includes the following attributes:
+
+         - ``CHECKPOINT_JOIN_CHAR = "-"``
+         - ``CHECKPOINT_EQUALS_CHAR = "="``
+         - ``CHECKPOINT_NAME_LAST = "last"``
+         - ``FILE_EXTENSION = ".ckpt"``
+         - ``STARTING_VERSION = 1``
+
+         For example, you can change the default last checkpoint name by doing
+         ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
+
+         If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
+         then you should create multiple ``ModelCheckpoint`` callbacks.
+
+         If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
+         only ``best_model_path`` will be reloaded and a warning will be issued.
+
+         If you provide a ``filename`` on a mounted device where changing permissions is not allowed (causing ``chmod``
+         to raise a ``PermissionError``), install `fsspec>=2025.5.0`. Then the error is caught, the file's permissions
+         remain unchanged, and the checkpoint is still saved. Otherwise, no checkpoint will be saved and training stops.
+
+     Raises:
+         MisconfigurationException:
+             If ``save_top_k`` is smaller than ``-1``,
+             if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or
+             if ``mode`` is none of ``"min"`` or ``"max"``.
+         ValueError:
+             If ``trainer.save_checkpoint`` is ``None``.
+
+     Example::
+
+         >>> from lightning.pytorch import Trainer
+         >>> from lightning.pytorch.callbacks import ModelCheckpoint
+
+         # saves checkpoints to 'my/path/' at every epoch
+         >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
+         >>> trainer = Trainer(callbacks=[checkpoint_callback])
+
+         # save epoch and val_loss in name
+         # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
+         >>> checkpoint_callback = ModelCheckpoint(
+         ...     monitor='val_loss',
+         ...     dirpath='my/path/',
+         ...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
+         ... )
+
+         # save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
+         # or Neptune, due to the presence of characters like '=' or '/')
+         # saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
+         >>> checkpoint_callback = ModelCheckpoint(
+         ...     monitor='val/loss',
+         ...     dirpath='my/path/',
+         ...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
+         ...     auto_insert_metric_name=False
+         ... )
+
+         # retrieve the best checkpoint after training
+         >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
+         >>> trainer = Trainer(callbacks=[checkpoint_callback])
+         >>> model = ...  # doctest: +SKIP
+         >>> trainer.fit(model)  # doctest: +SKIP
+         >>> print(checkpoint_callback.best_model_path)  # doctest: +SKIP
+
+     .. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the
+         following arguments:
+
+         *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval*
+
+         Read more: :ref:`Persisting Callback State <extensions/callbacks_state:save callback state>`
+
+     """  # noqa: D205
+
+     def __init__(  # noqa: D107
+         self,
+         dirpath: _PATH | None = None,
+         filename: str | None = None,
+         monitor: str | None = None,
+         verbose: bool = False,
+         save_last: bool | Literal["link"] | None = None,
+         save_top_k: int = 1,
+         save_on_exception: bool = False,
+         save_weights_only: bool = False,
+         mode: str = "min",
+         auto_insert_metric_name: bool = True,
+         every_n_train_steps: int | None = None,
+         train_time_interval: timedelta | None = None,
+         every_n_epochs: int | None = None,
+         save_on_train_epoch_end: bool | None = None,
+         enable_version_counter: bool = True,
+         registry_uploader_callback: RegistryUploaderCallback | None = None,
+     ) -> None:
+         super().__init__(
+             dirpath=dirpath,
+             filename=filename,
+             monitor=monitor,
+             verbose=verbose,
+             save_last=save_last,
+             save_top_k=save_top_k,
+             save_on_exception=save_on_exception,
+             save_weights_only=save_weights_only,
+             mode=mode,
+             auto_insert_metric_name=auto_insert_metric_name,
+             every_n_train_steps=every_n_train_steps,
+             train_time_interval=train_time_interval,
+             every_n_epochs=every_n_epochs,
+             save_on_train_epoch_end=save_on_train_epoch_end,
+             enable_version_counter=enable_version_counter,
+         )
+         self.registry_uploader_callback = registry_uploader_callback
+         self._custom_best_model_path = cast(str, self.best_model_path)
+         return
+
+     @property
+     def best_model_path(self) -> str:
+         """Best model path."""
+         return self._custom_best_model_path
+
+     @best_model_path.setter
+     def best_model_path(self, value: str) -> None:
+         self._custom_best_model_path = value
+         if self.registry_uploader_callback is not None:
+             self.registry_uploader_callback.best_model_path = value
+         return
+
+
+ def setup_checkpoint_callback(
+     dirpath: Path,
+     ckpt_cfg: CheckpointConfig,
+     save_weights_only: bool = True,
+     registry_uploader_callback: RegistryUploaderCallback | None = None,
+ ) -> CustomModelCheckpoint:
+     """
+     Sets up a ModelCheckpoint callback for PyTorch Lightning.
+
+     This function prepares a checkpoint directory and configures a ModelCheckpoint
+     callback based on the provided configuration. If the directory already exists,
+     it is removed (only by the main process) to ensure a clean start. Otherwise,
+     the directory is created.
+
+     Args:
+         dirpath (Path): The path to the directory where checkpoints will be saved.
+         ckpt_cfg (CheckpointConfig): Configuration object containing checkpoint
+             settings such as filename, save_top_k, monitor, and mode.
+         save_weights_only (bool, optional): Whether to save only the model weights
+             or the full model. Defaults to True.
+         registry_uploader_callback (RegistryUploaderCallback | None, optional):
+             An optional callback for uploading checkpoints to a registry. Defaults to None.
+
+     Returns:
+         ModelCheckpoint: The configured ModelCheckpoint callback instance.
+
+     """
+     if dirpath.exists():
+         if is_main_process():
+             logger.warning(f"Checkpoint directory {dirpath} already exists.")
+             rmtree(dirpath)
+             logger.warning(f"Removed existing checkpoint directory {dirpath}.")
+     else:
+         logger.info(f"Creating checkpoint directory {dirpath}.")
+         dirpath.mkdir(parents=True, exist_ok=True)
+
+     checkpoint_callback = CustomModelCheckpoint(
+         dirpath=dirpath,
+         filename=ckpt_cfg.filename,
+         save_top_k=ckpt_cfg.save_top_k,
+         monitor=ckpt_cfg.monitor,
+         mode=ckpt_cfg.mode,
+         verbose=True,
+         save_weights_only=save_weights_only,
+         registry_uploader_callback=registry_uploader_callback,
+     )
+     return checkpoint_callback
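
Editor's note: a minimal wiring sketch (not from the package) showing how the new setup_checkpoint_callback and ClearMLRegistryUploaderCallback introduced in 0.1.23 fit together. The CheckpointConfig constructor arguments, the ClearML project/task names, and the directory are illustrative assumptions; only the four config fields that setup_checkpoint_callback reads are set.

    # Hypothetical usage sketch; CheckpointConfig fields and all names are assumptions.
    from pathlib import Path

    from clearml import Task
    from lightning import Trainer

    from kostyl.ml.configs import CheckpointConfig
    from kostyl.ml.lightning.callbacks import (
        ClearMLRegistryUploaderCallback,
        setup_checkpoint_callback,
    )

    task = Task.init(project_name="demo", task_name="train")
    uploader = ClearMLRegistryUploaderCallback(
        task=task,
        output_model_name="my-model",
        uploading_frequency="on-train-end",
    )
    ckpt_cfg = CheckpointConfig(  # assumed constructor; only the fields read below are shown
        filename="{epoch}-{val_loss:.2f}",
        save_top_k=1,
        monitor="val_loss",
        mode="min",
    )
    checkpoint_cb = setup_checkpoint_callback(
        dirpath=Path("checkpoints/run-1"),
        ckpt_cfg=ckpt_cfg,
        registry_uploader_callback=uploader,
    )
    trainer = Trainer(callbacks=[checkpoint_cb, uploader])

Passing the uploader into setup_checkpoint_callback is what lets CustomModelCheckpoint push best_model_path updates into it through the property setter shown above.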
kostyl_toolkit-0.1.21/kostyl/ml/lightning/callbacks/registry_uploading.py → kostyl_toolkit-0.1.23/kostyl/ml/lightning/callbacks/registry_uploader.py
@@ -1,3 +1,7 @@
+ from abc import ABC
+ from abc import abstractmethod
+ from collections.abc import Callable
+ from functools import partial
  from typing import Literal
  from typing import override

@@ -5,7 +9,6 @@ from clearml import OutputModel
  from clearml import Task
  from lightning import Trainer
  from lightning.pytorch.callbacks import Callback
- from lightning.pytorch.callbacks import ModelCheckpoint

  from kostyl.ml.clearml.logging_utils import find_version_in_tags
  from kostyl.ml.clearml.logging_utils import increment_version
@@ -16,17 +19,37 @@ from kostyl.utils.logging import setup_logger
  logger = setup_logger()


- class ClearMLRegistryUploaderCallback(Callback):
+ class RegistryUploaderCallback(Callback, ABC):
+     """Abstract Lightning callback responsible for tracking and uploading the best-performing model checkpoint."""
+
+     @property
+     @abstractmethod
+     def best_model_path(self) -> str:
+         """Return the file system path pointing to the best model artifact produced during training."""
+         raise NotImplementedError
+
+     @best_model_path.setter
+     @abstractmethod
+     def best_model_path(self, value: str) -> None:
+         raise NotImplementedError
+
+     @abstractmethod
+     def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
+         raise NotImplementedError
+
+
+ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
      """PyTorch Lightning callback to upload the best model checkpoint to ClearML."""

      def __init__(
          self,
          task: Task,
-         ckpt_callback: ModelCheckpoint,
          output_model_name: str,
          output_model_tags: list[str] | None = None,
          verbose: bool = True,
          enable_tag_versioning: bool = True,
+         label_enumeration: dict[str, int] | None = None,
+         config_dict: dict[str, str] | None = None,
          uploading_frequency: Literal[
              "after-every-eval", "on-train-end"
          ] = "on-train-end",
@@ -40,6 +63,8 @@ class ClearMLRegistryUploaderCallback(Callback):
              output_model_name: Name for the ClearML output model.
              output_model_tags: Tags for the output model.
              verbose: Whether to log messages.
+             label_enumeration: Optional mapping of label names to integer IDs.
+             config_dict: Optional configuration dictionary to associate with the model.
              enable_tag_versioning: Whether to enable versioning in tags. If True,
                  the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".
              uploading_frequency: When to upload:
@@ -52,15 +77,31 @@ class ClearMLRegistryUploaderCallback(Callback):
              output_model_tags = []

          self.task = task
-         self.ckpt_callback = ckpt_callback
          self.output_model_name = output_model_name
          self.output_model_tags = output_model_tags
+         self.config_dict = config_dict
+         self.label_enumeration = label_enumeration
          self.verbose = verbose
          self.uploading_frequency = uploading_frequency
          self.enable_tag_versioning = enable_tag_versioning

          self._output_model: OutputModel | None = None
-         self._last_best_model_path: str = ""
+         self._last_uploaded_model_path: str = ""
+         self._best_model_path: str = ""
+         self._upload_callback: Callable | None = None
+         return
+
+     @property
+     @override
+     def best_model_path(self) -> str:
+         return self._best_model_path
+
+     @best_model_path.setter
+     @override
+     def best_model_path(self, value: str) -> None:
+         self._best_model_path = value
+         if self._upload_callback is not None:
+             self._upload_callback()
          return

      def _create_output_model(self, pl_module: "KostylLightningModule") -> OutputModel:
@@ -75,44 +116,58 @@ class ClearMLRegistryUploaderCallback(Callback):

          if "LightningCheckpoint" not in self.output_model_tags:
              self.output_model_tags.append("LightningCheckpoint")
-         config = pl_module.model_config
-         if config is not None:
-             config = config.to_dict()
+
+         if self.config_dict is None:
+             config = pl_module.model_config
+             if config is not None:
+                 config = config.to_dict()
+         else:
+             config = self.config_dict

          return OutputModel(
              task=self.task,
              name=self.output_model_name,
              framework="PyTorch",
              tags=self.output_model_tags,
-             config_dict=config,
+             config_dict=None,
+             label_enumeration=self.label_enumeration,
          )

+     @override
      def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
-         current_best = self.ckpt_callback.best_model_path
-
-         if not current_best:
-             if self.verbose:
-                 logger.info("No best model found yet to upload")
-             return
-
-         if current_best == self._last_best_model_path:
-             if self.verbose:
-                 logger.info("Best model unchanged since last upload")
+         if not self._best_model_path or (
+             self._best_model_path == self._last_uploaded_model_path
+         ):
+             if not self._best_model_path:
+                 if self.verbose:
+                     logger.info("No best model found yet to upload")
+             elif self._best_model_path == self._last_uploaded_model_path:
+                 if self.verbose:
+                     logger.info("Best model unchanged since last upload")
+             self._upload_callback = partial(self._upload_best_checkpoint, pl_module)
              return
+         self._upload_callback = None

          if self._output_model is None:
              self._output_model = self._create_output_model(pl_module)

          if self.verbose:
-             logger.info(f"Uploading best model from {current_best}")
+             logger.info(f"Uploading best model from {self._best_model_path}")

          self._output_model.update_weights(
-             current_best,
+             self._best_model_path,
              auto_delete_file=False,
              async_enable=False,
          )
-
-         self._last_best_model_path = current_best
+         if self.config_dict is None:
+             config = pl_module.model_config
+             if config is not None:
+                 config = config.to_dict()
+         else:
+             config = self.config_dict
+         self._output_model.update_design(config_dict=config)
+
+         self._last_uploaded_model_path = self._best_model_path
          return

      @override
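
Editor's note: compared with 0.1.21, the uploader no longer holds a ckpt_callback reference and polls it. Instead, CustomModelCheckpoint (see checkpoint.py above) pushes the path into the uploader through the abstract best_model_path setter, and an upload skipped because the path is missing or unchanged is remembered as a functools.partial and retried on the next push. Below is a minimal sketch of a custom subclass satisfying the RegistryUploaderCallback contract; the LocalCopyUploader class and its destination directory are made up for illustration and are not part of the package.

    # Hypothetical subclass illustrating the RegistryUploaderCallback contract.
    import shutil
    from pathlib import Path

    from kostyl.ml.lightning.callbacks.registry_uploader import RegistryUploaderCallback


    class LocalCopyUploader(RegistryUploaderCallback):
        """Copies the best checkpoint into a local directory instead of ClearML."""

        def __init__(self, destination: Path) -> None:
            super().__init__()
            self.destination = destination
            self._best_model_path: str = ""

        @property
        def best_model_path(self) -> str:
            return self._best_model_path

        @best_model_path.setter
        def best_model_path(self, value: str) -> None:
            # CustomModelCheckpoint assigns this whenever its own best path changes.
            self._best_model_path = value

        def _upload_best_checkpoint(self, pl_module) -> None:
            if self._best_model_path:
                self.destination.mkdir(parents=True, exist_ok=True)
                shutil.copy2(self._best_model_path, self.destination)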
{kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/extenstions/pretrained_model.py
@@ -14,6 +14,7 @@ except ImportError:

  from kostyl.utils.logging import log_incompatible_keys
  from kostyl.utils.logging import setup_logger
+ from torch import nn


  logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
@@ -67,7 +68,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
              raise ValueError(f"{checkpoint_path} is a directory")
          if not checkpoint_path.exists():
              raise FileNotFoundError(f"{checkpoint_path} does not exist")
-         if not checkpoint_path.suffix == ".ckpt":
+         if checkpoint_path.suffix != ".ckpt":
              raise ValueError(f"{checkpoint_path} is not a .ckpt file")

          checkpoint_dict = torch.load(
@@ -77,19 +78,21 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
              mmap=True,
          )

-         config_cls = cast(PretrainedConfig, type(cls.config_class))
+         # 1. Restore the config
+         config_cls = cast(type[PretrainedConfig], cls.config_class)
          config_dict = checkpoint_dict[config_key]
          config_dict.update(kwargs)
          config = config_cls.from_dict(config_dict)

-         kwargs_for_model = {}
-         for key in kwargs:
+         kwargs_for_model: dict[str, Any] = {}
+         for key, value in kwargs.items():
              if not hasattr(config, key):
-                 kwargs_for_model[key] = kwargs[key]
+                 kwargs_for_model[key] = value

          with torch.device("meta"):
              model = cls(config, **kwargs_for_model)

+         # PEFT adapters (keeping the existing logic as is)
          if "peft_config" in checkpoint_dict:
              if PeftConfig is None:
                  raise ImportError(
@@ -100,26 +103,100 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
                  model.add_adapter(peft_cfg, adapter_name=name)

          incompatible_keys: dict[str, list[str]] = {}
-         if weights_prefix != "":
-             if weights_prefix[-1] != ".":
-                 weights_prefix += "."
-             model_state_dict = {}
-             mismatched_keys = []
-             for key, value in checkpoint_dict["state_dict"].items():
+
+         raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
+
+         if weights_prefix:
+             if not weights_prefix.endswith("."):
+                 weights_prefix = weights_prefix + "."
+             state_dict: dict[str, torch.Tensor] = {}
+             mismatched_keys: list[str] = []
+
+             for key, value in raw_state_dict.items():
                  if key.startswith(weights_prefix):
                      new_key = key[len(weights_prefix) :]
-                     model_state_dict[new_key] = value
+                     state_dict[new_key] = value
                  else:
                      mismatched_keys.append(key)
+
+             if mismatched_keys:
                  incompatible_keys["mismatched_keys"] = mismatched_keys
          else:
-             model_state_dict = checkpoint_dict["state_dict"]
-
-         missing_keys, unexpected_keys = model.load_state_dict(
-             model_state_dict, strict=False, assign=True
+             state_dict = raw_state_dict
+
+         # 5. base_model_prefix logic, as in HF:
+         # supports loading a base model <-> a model with a head
+         #
+         # cls.base_model_prefix is usually "model" / "bert" / "encoder", etc.
+         base_prefix: str = getattr(cls, "base_model_prefix", "") or ""
+         model_to_load: nn.Module = model
+
+         if base_prefix:
+             prefix_with_dot = base_prefix + "."
+             loaded_keys = list(state_dict.keys())
+             full_model_state = model.state_dict()
+             expected_keys = list(full_model_state.keys())
+
+             has_prefix_module = any(k.startswith(prefix_with_dot) for k in loaded_keys)
+             expects_prefix_module = any(
+                 k.startswith(prefix_with_dot) for k in expected_keys
+             )
+
+             # Case 1: loading a base model into a model with a head.
+             # Example: StaticEmbeddingsForSequenceClassification (has .model)
+             # state_dict with keys "embeddings.weight", "token_pos_weights", ...
+             if (
+                 hasattr(model, base_prefix)
+                 and not has_prefix_module
+                 and expects_prefix_module
+             ):
+                 # Weights without the prefix -> load only into model.<base_prefix>
+                 model_to_load = getattr(model, base_prefix)
+
+             # Case 2: loading a checkpoint of a model with a head into a base model.
+             # Example: BertModel, while the state_dict has keys "bert.encoder.layer.0..."
+             elif (
+                 not hasattr(model, base_prefix)
+                 and has_prefix_module
+                 and not expects_prefix_module
+             ):
+                 new_state_dict: dict[str, torch.Tensor] = {}
+                 for key, value in state_dict.items():
+                     if key.startswith(prefix_with_dot):
+                         new_key = key[len(prefix_with_dot) :]
+                     else:
+                         new_key = key
+                     new_state_dict[new_key] = value
+                 state_dict = new_state_dict
+
+         load_result = model_to_load.load_state_dict(
+             state_dict, strict=False, assign=True
+         )
+         missing_keys, unexpected_keys = (
+             load_result.missing_keys,
+             load_result.unexpected_keys,
          )
+
+         # If we loaded only into the base submodule, extend missing_keys
+         # to the full list (base + head), as in older HF versions.
+         if model_to_load is not model and base_prefix:
+             base_keys = set(model_to_load.state_dict().keys())
+             # Normalize the full model's keys to their prefix-free form
+             head_like_keys = set()
+             prefix_with_dot = base_prefix + "."
+             for k in model.state_dict().keys():
+                 if k.startswith(prefix_with_dot):
+                     # strip "model."
+                     head_like_keys.add(k[len(prefix_with_dot) :])
+                 else:
+                     head_like_keys.add(k)
+             extra_missing = sorted(head_like_keys - base_keys)
+             missing_keys = list(missing_keys) + extra_missing
+
          incompatible_keys["missing_keys"] = missing_keys
          incompatible_keys["unexpected_keys"] = unexpected_keys
+
          if should_log_incompatible_keys:
              log_incompatible_keys(incompatible_keys=incompatible_keys, logger=logger)
+
          return model
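
Editor's note: the new base_model_prefix handling mirrors the Hugging Face convention for moving weights between a base model and a model with a task head. Below is a self-contained illustration of the case-2 remapping above (a checkpoint saved from a model with a head, loaded into the bare base model); the Base and WithHead modules are toy stand-ins, not package code.

    # Standalone sketch of the "strip the base_model_prefix" remapping.
    from torch import nn


    class Base(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.encoder = nn.Linear(4, 4)


    class WithHead(nn.Module):
        base_model_prefix = "model"

        def __init__(self) -> None:
            super().__init__()
            self.model = Base()
            self.head = nn.Linear(4, 2)


    ckpt_state = WithHead().state_dict()  # keys: "model.encoder.*", "head.*"
    base = Base()
    prefix = WithHead.base_model_prefix + "."
    remapped = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in ckpt_state.items()
    }
    result = base.load_state_dict(remapped, strict=False)
    print(result.unexpected_keys)  # the head weights stay behind as unexpected keys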
{kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "kostyl-toolkit"
- version = "0.1.21"
+ version = "0.1.23"
  description = "Kickass Orchestration System for Training, Yielding & Logging "
  readme = "README.md"
  requires-python = ">=3.12"
kostyl_toolkit-0.1.21/kostyl/ml/lightning/callbacks/checkpoint.py
@@ -1,56 +0,0 @@
- from pathlib import Path
- from shutil import rmtree
-
- from lightning.pytorch.callbacks import ModelCheckpoint
-
- from kostyl.ml.configs import CheckpointConfig
- from kostyl.ml.dist_utils import is_main_process
- from kostyl.utils import setup_logger
-
-
- logger = setup_logger("callbacks/checkpoint.py")
-
-
- def setup_checkpoint_callback(
-     dirpath: Path,
-     ckpt_cfg: CheckpointConfig,
-     save_weights_only: bool = True,
- ) -> ModelCheckpoint:
-     """
-     Sets up a ModelCheckpoint callback for PyTorch Lightning.
-
-     This function prepares a checkpoint directory and configures a ModelCheckpoint
-     callback based on the provided configuration. If the directory already exists,
-     it is removed (only by the main process) to ensure a clean start. Otherwise,
-     the directory is created.
-
-     Args:
-         dirpath (Path): The path to the directory where checkpoints will be saved.
-         ckpt_cfg (CheckpointConfig): Configuration object containing checkpoint
-             settings such as filename, save_top_k, monitor, and mode.
-         save_weights_only (bool, optional): Whether to save only the model weights
-             or the full model. Defaults to True.
-
-     Returns:
-         ModelCheckpoint: The configured ModelCheckpoint callback instance.
-
-     """
-     if dirpath.exists():
-         if is_main_process():
-             logger.warning(f"Checkpoint directory {dirpath} already exists.")
-             rmtree(dirpath)
-             logger.warning(f"Removed existing checkpoint directory {dirpath}.")
-     else:
-         logger.info(f"Creating checkpoint directory {dirpath}.")
-         dirpath.mkdir(parents=True, exist_ok=True)
-
-     checkpoint_callback = ModelCheckpoint(
-         dirpath=dirpath,
-         filename=ckpt_cfg.filename,
-         save_top_k=ckpt_cfg.save_top_k,
-         monitor=ckpt_cfg.monitor,
-         mode=ckpt_cfg.mode,
-         verbose=True,
-         save_weights_only=save_weights_only,
-     )
-     return checkpoint_callback