kostyl-toolkit 0.1.21__tar.gz → 0.1.23__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/PKG-INFO +1 -1
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/callbacks/__init__.py +1 -1
- kostyl_toolkit-0.1.23/kostyl/ml/lightning/callbacks/checkpoint.py +327 -0
- kostyl_toolkit-0.1.21/kostyl/ml/lightning/callbacks/registry_uploading.py → kostyl_toolkit-0.1.23/kostyl/ml/lightning/callbacks/registry_uploader.py +78 -23
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/extenstions/pretrained_model.py +93 -16
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/pyproject.toml +1 -1
- kostyl_toolkit-0.1.21/kostyl/ml/lightning/callbacks/checkpoint.py +0 -56
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/README.md +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/__init__.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/__init__.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/clearml/__init__.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/clearml/dataset_utils.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/clearml/logging_utils.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/clearml/pulling_utils.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/configs/__init__.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/configs/base_model.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/configs/hyperparams.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/configs/training_settings.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/dist_utils.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/__init__.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/extenstions/__init__.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/extenstions/custom_module.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/loggers/__init__.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/loggers/tb_logger.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/steps_estimation.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/metrics_formatting.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/params_groups.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/schedulers/__init__.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/schedulers/base.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/schedulers/composite.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/schedulers/cosine.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/utils/__init__.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/utils/dict_manipulations.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/utils/fs.py +0 -0
- {kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/utils/logging.py +0 -0
kostyl_toolkit-0.1.23/kostyl/ml/lightning/callbacks/checkpoint.py
ADDED
@@ -0,0 +1,327 @@
+from datetime import timedelta
+from pathlib import Path
+from shutil import rmtree
+from typing import Literal
+from typing import cast
+
+from lightning.fabric.utilities.types import _PATH
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+from kostyl.ml.configs import CheckpointConfig
+from kostyl.ml.dist_utils import is_main_process
+from kostyl.utils import setup_logger
+
+from .registry_uploader import RegistryUploaderCallback
+
+
+logger = setup_logger("callbacks/checkpoint.py")
+
+
+class CustomModelCheckpoint(ModelCheckpoint):
+    r"""
+    Save the model after every epoch by monitoring a quantity. Every logged metrics are passed to the
+    :class:`~lightning.pytorch.loggers.logger.Logger` for the version it gets saved in the same directory as the
+    checkpoint.
+
+    After training finishes, use :attr:`best_model_path` to retrieve the path to the
+    best checkpoint file and :attr:`best_model_score` to get its score.
+
+    .. note::
+        When using manual optimization with ``every_n_train_steps``, you should save the model state
+        in your ``training_step`` before the optimizer step if you want the checkpoint to reflect
+        the pre-optimization state. Example:
+
+        .. code-block:: python
+
+            def training_step(self, batch, batch_idx):
+                # ... forward pass, loss calculation, backward pass ...
+
+                # Save model state before optimization
+                if not hasattr(self, 'saved_models'):
+                    self.saved_models = {}
+                self.saved_models[batch_idx] = {
+                    k: v.detach().clone()
+                    for k, v in self.layer.state_dict().items()
+                }
+
+                # Then perform optimization
+                optimizer.zero_grad()
+                self.manual_backward(loss)
+                optimizer.step()
+
+                # Optional: Clean up old states to save memory
+                if batch_idx > 10:  # Keep last 10 states
+                    del self.saved_models[batch_idx - 10]
+
+    Args:
+        dirpath: Directory to save the model file.
+            Example: ``dirpath='my/path/'``.
+
+            .. warning::
+                In a distributed environment like DDP, it's recommended to provide a `dirpath` to avoid race conditions.
+                When using manual optimization with ``every_n_train_steps``, make sure to save the model state
+                in your training loop as shown in the example above.
+
+            Can be remote file paths such as `s3://mybucket/path/` or 'hdfs://path/'
+            (default: ``None``). If dirpath is ``None``, we only keep the ``k`` best checkpoints
+            in memory, and do not save anything to disk.
+
+        filename: Checkpoint filename. Can contain named formatting options to be auto-filled.
+            If no name is provided, it will be ``None`` and the checkpoint will be saved to
+            ``{epoch}``.and if the Trainer uses a logger, the path will also contain logger name and version.
+
+        filename: checkpoint filename. Can contain named formatting options to be auto-filled.
+
+            Example::
+
+                # save any arbitrary metrics like `val_loss`, etc. in name
+                # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
+                >>> checkpoint_callback = ModelCheckpoint(
+                ...     dirpath='my/path',
+                ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
+                ... )
+
+            By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``, where "epoch" and "step" match
+            the number of finished epoch and optimizer steps respectively.
+        monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch.
+        verbose: verbosity mode. Default: ``False``.
+        save_last: When ``True``, saves a `last.ckpt` copy whenever a checkpoint file gets saved. Can be set to
+            ``'link'`` on a local filesystem to create a symbolic link. This allows accessing the latest checkpoint
+            in a deterministic manner. Default: ``None``.
+        save_top_k: if ``save_top_k == k``,
+            the best k models according to the quantity monitored will be saved.
+            If ``save_top_k == 0``, no models are saved.
+            If ``save_top_k == -1``, all models are saved.
+            Please note that the monitors are checked every ``every_n_epochs`` epochs.
+            If ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, and the filename remains
+            unchanged, the name of the saved file will be appended with a version count starting with ``v1`` to avoid
+            collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the top-k
+            ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid
+            collisions.
+        save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``False``.
+        mode: one of {min, max}.
+            If ``save_top_k != 0``, the decision to overwrite the current save file is made
+            based on either the maximization or the minimization of the monitored quantity.
+            For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
+        auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name.
+            For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}`` with epoch ``1`` and acc ``1.12`` will resolve
+            to ``checkpoint_epoch=01-acc=01.ckpt``. Is useful to set it to ``False`` when metric names contain ``/``
+            as this will result in extra folders.
+            For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False``
+        save_weights_only: if ``True``, then only the model's weights will be
+            saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
+        every_n_train_steps: How many training steps to wait before saving a checkpoint. This does not take into account
+            the steps of the current epoch. If ``every_n_train_steps == None or every_n_train_steps == 0``,
+            no checkpoints
+            will be saved during training. Mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
+
+            .. note::
+                When using with manual optimization, the checkpoint will be saved after the optimizer step by default.
+                To save the model state before the optimizer step, you need to save the model state in your
+                ``training_step`` before calling ``optimizer.step()``. See the class docstring for an example.
+        train_time_interval: Checkpoints are monitored at the specified time interval.
+            For all practical purposes, this cannot be smaller than the amount
+            of time it takes to process a single training batch. This is not
+            guaranteed to execute at the exact time specified, but should be close.
+            This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
+        every_n_epochs: Number of epochs between checkpoints.
+            This value must be ``None`` or non-negative.
+            To disable saving top-k checkpoints, set ``every_n_epochs = 0``.
+            This argument does not impact the saving of ``save_last=True`` checkpoints.
+            If all of ``every_n_epochs``, ``every_n_train_steps`` and
+            ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch
+            (equivalent to ``every_n_epochs = 1``).
+            If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or ``train_time_interval != None``,
+            saving at the end of each epoch is disabled
+            (equivalent to ``every_n_epochs = 0``).
+            This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
+            Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and
+            ``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
+            will only save checkpoints at epochs 0 < E <= N
+            where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
+        save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
+            If ``True``, checkpoints are saved at the end of every training epoch.
+            If ``False``, checkpoints are saved at the end of validation.
+            If ``None`` (default), checkpointing behavior is determined based on training configuration.
+            If ``val_check_interval`` is a str, dict, or `timedelta` (time-based), checkpointing is performed after
+            validation.
+            If ``check_val_every_n_epoch != 1``, checkpointing will not be performed at the end of
+            every training epoch. If there are no validation batches of data, checkpointing will occur at the
+            end of the training epoch. If there is a non-default number of validation runs per training epoch
+            (``val_check_interval != 1``), checkpointing is performed after validation.
+        enable_version_counter: Whether to append a version to the existing file name.
+            If ``False``, then the checkpoint files will be overwritten.
+
+    Note:
+        For extra customization, ModelCheckpoint includes the following attributes:
+
+        - ``CHECKPOINT_JOIN_CHAR = "-"``
+        - ``CHECKPOINT_EQUALS_CHAR = "="``
+        - ``CHECKPOINT_NAME_LAST = "last"``
+        - ``FILE_EXTENSION = ".ckpt"``
+        - ``STARTING_VERSION = 1``
+
+        For example, you can change the default last checkpoint name by doing
+        ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
+
+        If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
+        then you should create multiple ``ModelCheckpoint`` callbacks.
+
+        If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
+        only ``best_model_path`` will be reloaded and a warning will be issued.
+
+        If you provide a ``filename`` on a mounted device where changing permissions is not allowed (causing ``chmod``
+        to raise a ``PermissionError``), install `fsspec>=2025.5.0`. Then the error is caught, the file's permissions
+        remain unchanged, and the checkpoint is still saved. Otherwise, no checkpoint will be saved and training stops.
+
+    Raises:
+        MisconfigurationException:
+            If ``save_top_k`` is smaller than ``-1``,
+            if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or
+            if ``mode`` is none of ``"min"`` or ``"max"``.
+        ValueError:
+            If ``trainer.save_checkpoint`` is ``None``.
+
+    Example::
+
+        >>> from lightning.pytorch import Trainer
+        >>> from lightning.pytorch.callbacks import ModelCheckpoint
+
+        # saves checkpoints to 'my/path/' at every epoch
+        >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
+        >>> trainer = Trainer(callbacks=[checkpoint_callback])
+
+        # save epoch and val_loss in name
+        # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
+        >>> checkpoint_callback = ModelCheckpoint(
+        ...     monitor='val_loss',
+        ...     dirpath='my/path/',
+        ...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
+        ... )
+
+        # save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
+        # or Neptune, due to the presence of characters like '=' or '/')
+        # saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
+        >>> checkpoint_callback = ModelCheckpoint(
+        ...     monitor='val/loss',
+        ...     dirpath='my/path/',
+        ...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
+        ...     auto_insert_metric_name=False
+        ... )
+
+        # retrieve the best checkpoint after training
+        >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
+        >>> trainer = Trainer(callbacks=[checkpoint_callback])
+        >>> model = ...  # doctest: +SKIP
+        >>> trainer.fit(model)  # doctest: +SKIP
+        >>> print(checkpoint_callback.best_model_path)  # doctest: +SKIP
+
+    .. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the
+        following arguments:
+
+        *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval*
+
+        Read more: :ref:`Persisting Callback State <extensions/callbacks_state:save callback state>`
+
+    """  # noqa: D205
+
+    def __init__(  # noqa: D107
+        self,
+        dirpath: _PATH | None = None,
+        filename: str | None = None,
+        monitor: str | None = None,
+        verbose: bool = False,
+        save_last: bool | Literal["link"] | None = None,
+        save_top_k: int = 1,
+        save_on_exception: bool = False,
+        save_weights_only: bool = False,
+        mode: str = "min",
+        auto_insert_metric_name: bool = True,
+        every_n_train_steps: int | None = None,
+        train_time_interval: timedelta | None = None,
+        every_n_epochs: int | None = None,
+        save_on_train_epoch_end: bool | None = None,
+        enable_version_counter: bool = True,
+        registry_uploader_callback: RegistryUploaderCallback | None = None,
+    ) -> None:
+        super().__init__(
+            dirpath=dirpath,
+            filename=filename,
+            monitor=monitor,
+            verbose=verbose,
+            save_last=save_last,
+            save_top_k=save_top_k,
+            save_on_exception=save_on_exception,
+            save_weights_only=save_weights_only,
+            mode=mode,
+            auto_insert_metric_name=auto_insert_metric_name,
+            every_n_train_steps=every_n_train_steps,
+            train_time_interval=train_time_interval,
+            every_n_epochs=every_n_epochs,
+            save_on_train_epoch_end=save_on_train_epoch_end,
+            enable_version_counter=enable_version_counter,
+        )
+        self.registry_uploader_callback = registry_uploader_callback
+        self._custom_best_model_path = cast(str, self.best_model_path)
+        return
+
+    @property
+    def best_model_path(self) -> str:
+        """Best model path."""
+        return self._custom_best_model_path
+
+    @best_model_path.setter
+    def best_model_path(self, value: str) -> None:
+        self._custom_best_model_path = value
+        if self.registry_uploader_callback is not None:
+            self.registry_uploader_callback.best_model_path = value
+        return
+
+
+def setup_checkpoint_callback(
+    dirpath: Path,
+    ckpt_cfg: CheckpointConfig,
+    save_weights_only: bool = True,
+    registry_uploader_callback: RegistryUploaderCallback | None = None,
+) -> CustomModelCheckpoint:
+    """
+    Sets up a ModelCheckpoint callback for PyTorch Lightning.
+
+    This function prepares a checkpoint directory and configures a ModelCheckpoint
+    callback based on the provided configuration. If the directory already exists,
+    it is removed (only by the main process) to ensure a clean start. Otherwise,
+    the directory is created.
+
+    Args:
+        dirpath (Path): The path to the directory where checkpoints will be saved.
+        ckpt_cfg (CheckpointConfig): Configuration object containing checkpoint
+            settings such as filename, save_top_k, monitor, and mode.
+        save_weights_only (bool, optional): Whether to save only the model weights
+            or the full model. Defaults to True.
+        registry_uploader_callback (RegistryUploaderCallback | None, optional):
+            An optional callback for uploading checkpoints to a registry. Defaults to None.
+
+    Returns:
+        ModelCheckpoint: The configured ModelCheckpoint callback instance.
+
+    """
+    if dirpath.exists():
+        if is_main_process():
+            logger.warning(f"Checkpoint directory {dirpath} already exists.")
+            rmtree(dirpath)
+            logger.warning(f"Removed existing checkpoint directory {dirpath}.")
+    else:
+        logger.info(f"Creating checkpoint directory {dirpath}.")
+        dirpath.mkdir(parents=True, exist_ok=True)
+
+    checkpoint_callback = CustomModelCheckpoint(
+        dirpath=dirpath,
+        filename=ckpt_cfg.filename,
+        save_top_k=ckpt_cfg.save_top_k,
+        monitor=ckpt_cfg.monitor,
+        mode=ckpt_cfg.mode,
+        verbose=True,
+        save_weights_only=save_weights_only,
+        registry_uploader_callback=registry_uploader_callback,
+    )
+    return checkpoint_callback
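
For orientation, a minimal wiring sketch of how the new factory and the ClearML uploader could fit together; the CheckpointConfig constructor call and the ClearML project/model names are assumptions, not taken from the package:

    # Hypothetical wiring sketch (not from the package): combine the new
    # CustomModelCheckpoint factory with the ClearML uploader callback.
    from pathlib import Path

    from clearml import Task
    from lightning import Trainer

    from kostyl.ml.configs import CheckpointConfig
    from kostyl.ml.lightning.callbacks.checkpoint import setup_checkpoint_callback
    from kostyl.ml.lightning.callbacks.registry_uploader import ClearMLRegistryUploaderCallback

    task = Task.init(project_name="demo", task_name="train")  # placeholder names
    uploader = ClearMLRegistryUploaderCallback(
        task=task,
        output_model_name="demo-model",
        uploading_frequency="after-every-eval",
    )
    ckpt_cfg = CheckpointConfig(  # field names come from the diff; the constructor call is assumed
        filename="{epoch}-{val_loss:.2f}",
        save_top_k=3,
        monitor="val_loss",
        mode="min",
    )
    checkpoint_cb = setup_checkpoint_callback(
        dirpath=Path("checkpoints"),
        ckpt_cfg=ckpt_cfg,
        registry_uploader_callback=uploader,
    )
    trainer = Trainer(callbacks=[checkpoint_cb, uploader])

With this wiring, whenever the checkpoint callback updates best_model_path, its setter forwards the new path to the uploader, as shown in registry_uploader.py below.
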
kostyl_toolkit-0.1.21/kostyl/ml/lightning/callbacks/registry_uploading.py → kostyl_toolkit-0.1.23/kostyl/ml/lightning/callbacks/registry_uploader.py
RENAMED
@@ -1,3 +1,7 @@
+from abc import ABC
+from abc import abstractmethod
+from collections.abc import Callable
+from functools import partial
 from typing import Literal
 from typing import override
 
@@ -5,7 +9,6 @@ from clearml import OutputModel
 from clearml import Task
 from lightning import Trainer
 from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.callbacks import ModelCheckpoint
 
 from kostyl.ml.clearml.logging_utils import find_version_in_tags
 from kostyl.ml.clearml.logging_utils import increment_version
@@ -16,17 +19,37 @@ from kostyl.utils.logging import setup_logger
 logger = setup_logger()
 
 
-class ClearMLRegistryUploaderCallback(Callback):
+class RegistryUploaderCallback(Callback, ABC):
+    """Abstract Lightning callback responsible for tracking and uploading the best-performing model checkpoint."""
+
+    @property
+    @abstractmethod
+    def best_model_path(self) -> str:
+        """Return the file system path pointing to the best model artifact produced during training."""
+        raise NotImplementedError
+
+    @best_model_path.setter
+    @abstractmethod
+    def best_model_path(self, value: str) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
+        raise NotImplementedError
+
+
+class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
     """PyTorch Lightning callback to upload the best model checkpoint to ClearML."""
 
     def __init__(
         self,
         task: Task,
-        ckpt_callback: ModelCheckpoint,
         output_model_name: str,
         output_model_tags: list[str] | None = None,
         verbose: bool = True,
         enable_tag_versioning: bool = True,
+        label_enumeration: dict[str, int] | None = None,
+        config_dict: dict[str, str] | None = None,
         uploading_frequency: Literal[
             "after-every-eval", "on-train-end"
         ] = "on-train-end",
@@ -40,6 +63,8 @@ class ClearMLRegistryUploaderCallback(Callback):
            output_model_name: Name for the ClearML output model.
            output_model_tags: Tags for the output model.
            verbose: Whether to log messages.
+           label_enumeration: Optional mapping of label names to integer IDs.
+           config_dict: Optional configuration dictionary to associate with the model.
            enable_tag_versioning: Whether to enable versioning in tags. If True,
                the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".
            uploading_frequency: When to upload:
@@ -52,15 +77,31 @@ class ClearMLRegistryUploaderCallback(Callback):
            output_model_tags = []
 
        self.task = task
-       self.ckpt_callback = ckpt_callback
        self.output_model_name = output_model_name
        self.output_model_tags = output_model_tags
+       self.config_dict = config_dict
+       self.label_enumeration = label_enumeration
        self.verbose = verbose
        self.uploading_frequency = uploading_frequency
        self.enable_tag_versioning = enable_tag_versioning
 
        self._output_model: OutputModel | None = None
-       self.
+       self._last_uploaded_model_path: str = ""
+       self._best_model_path: str = ""
+       self._upload_callback: Callable | None = None
+       return
+
+    @property
+    @override
+    def best_model_path(self) -> str:
+        return self._best_model_path
+
+    @best_model_path.setter
+    @override
+    def best_model_path(self, value: str) -> None:
+        self._best_model_path = value
+        if self._upload_callback is not None:
+            self._upload_callback()
        return
 
    def _create_output_model(self, pl_module: "KostylLightningModule") -> OutputModel:
@@ -75,44 +116,58 @@ class ClearMLRegistryUploaderCallback(Callback):
 
        if "LightningCheckpoint" not in self.output_model_tags:
            self.output_model_tags.append("LightningCheckpoint")
-
-       if
-       config =
+
+       if self.config_dict is None:
+           config = pl_module.model_config
+           if config is not None:
+               config = config.to_dict()
+       else:
+           config = self.config_dict
 
        return OutputModel(
            task=self.task,
            name=self.output_model_name,
            framework="PyTorch",
            tags=self.output_model_tags,
-           config_dict=
+           config_dict=None,
+           label_enumeration=self.label_enumeration,
        )
 
+    @override
    def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
-
-
-
-       if self.
-
-
-
-
-
-
+       if not self._best_model_path or (
+           self._best_model_path == self._last_uploaded_model_path
+       ):
+           if not self._best_model_path:
+               if self.verbose:
+                   logger.info("No best model found yet to upload")
+           elif self._best_model_path == self._last_uploaded_model_path:
+               if self.verbose:
+                   logger.info("Best model unchanged since last upload")
+           self._upload_callback = partial(self._upload_best_checkpoint, pl_module)
            return
+       self._upload_callback = None
 
        if self._output_model is None:
            self._output_model = self._create_output_model(pl_module)
 
        if self.verbose:
-           logger.info(f"Uploading best model from {
+           logger.info(f"Uploading best model from {self._best_model_path}")
 
        self._output_model.update_weights(
-
+           self._best_model_path,
            auto_delete_file=False,
            async_enable=False,
        )
-
-
+       if self.config_dict is None:
+           config = pl_module.model_config
+           if config is not None:
+               config = config.to_dict()
+       else:
+           config = self.config_dict
+       self._output_model.update_design(config_dict=config)
+
+       self._last_uploaded_model_path = self._best_model_path
        return
 
    @override
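
The behavioural change introduced here is easiest to see in isolation: an upload attempted before a new best checkpoint exists is remembered as a functools.partial and re-fired from the best_model_path setter. A toy sketch of the same deferred-retry pattern, with illustrative names that are not from the package:

    from functools import partial


    class ToyUploader:
        def __init__(self) -> None:
            self._best_model_path = ""
            self._last_uploaded = ""
            self._pending: partial | None = None

        @property
        def best_model_path(self) -> str:
            return self._best_model_path

        @best_model_path.setter
        def best_model_path(self, value: str) -> None:
            self._best_model_path = value
            if self._pending is not None:
                self._pending()  # retry the upload that was skipped earlier

        def upload(self, module_name: str) -> None:
            if not self._best_model_path or self._best_model_path == self._last_uploaded:
                # Nothing new to ship yet: queue the attempt and wait for the setter.
                self._pending = partial(self.upload, module_name)
                return
            self._pending = None
            print(f"uploading {self._best_model_path} for {module_name}")
            self._last_uploaded = self._best_model_path


    uploader = ToyUploader()
    uploader.upload("my_module")            # queued: no best checkpoint yet
    uploader.best_model_path = "best.ckpt"  # setter fires the queued upload

This removes the old direct dependency on a ModelCheckpoint instance: the uploader only needs something to push the best path into its setter.
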
{kostyl_toolkit-0.1.21 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/extenstions/pretrained_model.py
RENAMED
@@ -14,6 +14,7 @@ except ImportError:
 
 from kostyl.utils.logging import log_incompatible_keys
 from kostyl.utils.logging import setup_logger
+from torch import nn
 
 
 logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
@@ -67,7 +68,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
            raise ValueError(f"{checkpoint_path} is a directory")
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"{checkpoint_path} does not exist")
-       if
+       if checkpoint_path.suffix != ".ckpt":
            raise ValueError(f"{checkpoint_path} is not a .ckpt file")
 
        checkpoint_dict = torch.load(
@@ -77,19 +78,21 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
            mmap=True,
        )
 
-
+       # 1. Restore the config
+       config_cls = cast(type[PretrainedConfig], cls.config_class)
        config_dict = checkpoint_dict[config_key]
        config_dict.update(kwargs)
        config = config_cls.from_dict(config_dict)
 
-       kwargs_for_model = {}
-       for key in kwargs:
+       kwargs_for_model: dict[str, Any] = {}
+       for key, value in kwargs.items():
            if not hasattr(config, key):
-               kwargs_for_model[key] =
+               kwargs_for_model[key] = value
 
        with torch.device("meta"):
            model = cls(config, **kwargs_for_model)
 
+       # PEFT adapters (leaving your logic as is)
        if "peft_config" in checkpoint_dict:
            if PeftConfig is None:
                raise ImportError(
@@ -100,26 +103,100 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
                model.add_adapter(peft_cfg, adapter_name=name)
 
        incompatible_keys: dict[str, list[str]] = {}
-
-
-
-
-
-
+
+       raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
+
+       if weights_prefix:
+           if not weights_prefix.endswith("."):
+               weights_prefix = weights_prefix + "."
+           state_dict: dict[str, torch.Tensor] = {}
+           mismatched_keys: list[str] = []
+
+           for key, value in raw_state_dict.items():
               if key.startswith(weights_prefix):
                   new_key = key[len(weights_prefix) :]
-
+                   state_dict[new_key] = value
               else:
                   mismatched_keys.append(key)
+
+           if mismatched_keys:
               incompatible_keys["mismatched_keys"] = mismatched_keys
       else:
-
-
-
-
+           state_dict = raw_state_dict
+
+       # 5. base_model_prefix logic as in HF:
+       # support loading a base model <-> a model with a head
+       #
+       # cls.base_model_prefix is usually "model" / "bert" / "encoder", etc.
+       base_prefix: str = getattr(cls, "base_model_prefix", "") or ""
+       model_to_load: nn.Module = model
+
+       if base_prefix:
+           prefix_with_dot = base_prefix + "."
+           loaded_keys = list(state_dict.keys())
+           full_model_state = model.state_dict()
+           expected_keys = list(full_model_state.keys())
+
+           has_prefix_module = any(k.startswith(prefix_with_dot) for k in loaded_keys)
+           expects_prefix_module = any(
+               k.startswith(prefix_with_dot) for k in expected_keys
+           )
+
+           # Case 1: loading a base model into a model with a head.
+           # Example: StaticEmbeddingsForSequenceClassification (has .model)
+           # with a state_dict whose keys are "embeddings.weight", "token_pos_weights", ...
+           if (
+               hasattr(model, base_prefix)
+               and not has_prefix_module
+               and expects_prefix_module
+           ):
+               # Weights have no prefix -> load only into model.<base_prefix>
+               model_to_load = getattr(model, base_prefix)
+
+           # Case 2: loading a checkpoint of a model with a head into a base model.
+           # Example: BertModel, while the state_dict keys are "bert.encoder.layer.0..."
+           elif (
+               not hasattr(model, base_prefix)
+               and has_prefix_module
+               and not expects_prefix_module
+           ):
+               new_state_dict: dict[str, torch.Tensor] = {}
+               for key, value in state_dict.items():
+                   if key.startswith(prefix_with_dot):
+                       new_key = key[len(prefix_with_dot) :]
+                   else:
+                       new_key = key
+                   new_state_dict[new_key] = value
+               state_dict = new_state_dict
+
+       load_result = model_to_load.load_state_dict(
+           state_dict, strict=False, assign=True
+       )
+       missing_keys, unexpected_keys = (
+           load_result.missing_keys,
+           load_result.unexpected_keys,
       )
+
+       # If we loaded only into the base submodule, extend missing_keys
+       # to the full list (base + head), as in older HF versions.
+       if model_to_load is not model and base_prefix:
+           base_keys = set(model_to_load.state_dict().keys())
+           # Normalize the full model's keys to their prefix-free form
+           head_like_keys = set()
+           prefix_with_dot = base_prefix + "."
+           for k in model.state_dict().keys():
+               if k.startswith(prefix_with_dot):
+                   # strip "model."
+                   head_like_keys.add(k[len(prefix_with_dot) :])
+               else:
+                   head_like_keys.add(k)
+           extra_missing = sorted(head_like_keys - base_keys)
+           missing_keys = list(missing_keys) + extra_missing
+
       incompatible_keys["missing_keys"] = missing_keys
       incompatible_keys["unexpected_keys"] = unexpected_keys
+
       if should_log_incompatible_keys:
           log_incompatible_keys(incompatible_keys=incompatible_keys, logger=logger)
+
       return model
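
The added prefix handling mirrors Hugging Face-style base_model_prefix remapping. A dictionary-only sketch of case 2 (checkpoint of a model with a head loaded into a bare base model), using illustrative keys and no real modules:

    base_prefix = "model"
    prefix = base_prefix + "."

    # Checkpoint saved from a model with a classification head...
    loaded = {"model.encoder.weight": 1.0, "classifier.weight": 2.0}
    # ...being loaded into a bare base model that expects unprefixed keys.
    expected = {"encoder.weight": 0.0}

    has_prefix_module = any(k.startswith(prefix) for k in loaded)
    expects_prefix_module = any(k.startswith(prefix) for k in expected)

    if has_prefix_module and not expects_prefix_module:
        # Strip the base prefix so the keys line up with the target module.
        loaded = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in loaded.items()
        }

    print(loaded)  # {'encoder.weight': 1.0, 'classifier.weight': 2.0}

Case 1 is the mirror image: when the checkpoint keys carry no prefix but the target model wraps its backbone in an attribute named after base_model_prefix, the weights are loaded only into that submodule.
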
kostyl_toolkit-0.1.21/kostyl/ml/lightning/callbacks/checkpoint.py
DELETED
@@ -1,56 +0,0 @@
-from pathlib import Path
-from shutil import rmtree
-
-from lightning.pytorch.callbacks import ModelCheckpoint
-
-from kostyl.ml.configs import CheckpointConfig
-from kostyl.ml.dist_utils import is_main_process
-from kostyl.utils import setup_logger
-
-
-logger = setup_logger("callbacks/checkpoint.py")
-
-
-def setup_checkpoint_callback(
-    dirpath: Path,
-    ckpt_cfg: CheckpointConfig,
-    save_weights_only: bool = True,
-) -> ModelCheckpoint:
-    """
-    Sets up a ModelCheckpoint callback for PyTorch Lightning.
-
-    This function prepares a checkpoint directory and configures a ModelCheckpoint
-    callback based on the provided configuration. If the directory already exists,
-    it is removed (only by the main process) to ensure a clean start. Otherwise,
-    the directory is created.
-
-    Args:
-        dirpath (Path): The path to the directory where checkpoints will be saved.
-        ckpt_cfg (CheckpointConfig): Configuration object containing checkpoint
-            settings such as filename, save_top_k, monitor, and mode.
-        save_weights_only (bool, optional): Whether to save only the model weights
-            or the full model. Defaults to True.
-
-    Returns:
-        ModelCheckpoint: The configured ModelCheckpoint callback instance.
-
-    """
-    if dirpath.exists():
-        if is_main_process():
-            logger.warning(f"Checkpoint directory {dirpath} already exists.")
-            rmtree(dirpath)
-            logger.warning(f"Removed existing checkpoint directory {dirpath}.")
-    else:
-        logger.info(f"Creating checkpoint directory {dirpath}.")
-        dirpath.mkdir(parents=True, exist_ok=True)
-
-    checkpoint_callback = ModelCheckpoint(
-        dirpath=dirpath,
-        filename=ckpt_cfg.filename,
-        save_top_k=ckpt_cfg.save_top_k,
-        monitor=ckpt_cfg.monitor,
-        mode=ckpt_cfg.mode,
-        verbose=True,
-        save_weights_only=save_weights_only,
-    )
-    return checkpoint_callback