kostyl-toolkit 0.1.22__py3-none-any.whl → 0.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kostyl/ml/clearml/pulling_utils.py +1 -1
- kostyl/ml/lightning/callbacks/__init__.py +1 -1
- kostyl/ml/lightning/callbacks/checkpoint.py +273 -2
- kostyl/ml/lightning/callbacks/{registry_uploading.py → registry_uploader.py} +55 -18
- kostyl/ml/lightning/extenstions/pretrained_model.py +8 -105
- {kostyl_toolkit-0.1.22.dist-info → kostyl_toolkit-0.1.24.dist-info}/METADATA +1 -1
- {kostyl_toolkit-0.1.22.dist-info → kostyl_toolkit-0.1.24.dist-info}/RECORD +8 -8
- {kostyl_toolkit-0.1.22.dist-info → kostyl_toolkit-0.1.24.dist-info}/WHEEL +0 -0
kostyl/ml/clearml/pulling_utils.py

@@ -95,7 +95,7 @@ def get_model_from_clearml[
             raise ValueError(
                 f"Model class {model.__name__} is not compatible with Lightning checkpoints."
             )
-        model_instance = model.
+        model_instance = model.from_lightning_checkpoint(local_path, **kwargs)
     else:
         raise ValueError(
             f"Unsupported model format for path: {local_path}. "
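
The new call only works for model classes that expose `from_lightning_checkpoint`, i.e. classes that mix in `LightningCheckpointLoaderMixin` (see the `pretrained_model.py` diff below). A minimal sketch of a compatible caller-side setup; the model class, base model, and checkpoint path are hypothetical and not part of this package, and the exact wiring may need adjustment:

```python
from pathlib import Path

from transformers import BertForSequenceClassification

from kostyl.ml.lightning.extenstions.pretrained_model import LightningCheckpointLoaderMixin


class ReviewClassifier(LightningCheckpointLoaderMixin, BertForSequenceClassification):
    """Hypothetical HF model made loadable from Lightning checkpoints via the mixin."""


local_path = Path("artifacts/best.ckpt")  # hypothetical path pulled from ClearML

# Mirrors the compatibility gate in get_model_from_clearml: only mixin subclasses
# provide from_lightning_checkpoint; anything else is rejected.
if issubclass(ReviewClassifier, LightningCheckpointLoaderMixin):
    model_instance = ReviewClassifier.from_lightning_checkpoint(local_path)
else:
    raise ValueError(
        f"Model class {ReviewClassifier.__name__} is not compatible with Lightning checkpoints."
    )
```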
kostyl/ml/lightning/callbacks/checkpoint.py

@@ -1,21 +1,289 @@
+from datetime import timedelta
 from pathlib import Path
 from shutil import rmtree
+from typing import Literal
+from typing import cast
 
+from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.callbacks import ModelCheckpoint
 
 from kostyl.ml.configs import CheckpointConfig
 from kostyl.ml.dist_utils import is_main_process
 from kostyl.utils import setup_logger
 
+from .registry_uploader import RegistryUploaderCallback
+
 
 logger = setup_logger("callbacks/checkpoint.py")
 
 
+class CustomModelCheckpoint(ModelCheckpoint):
+    r"""
+    Save the model after every epoch by monitoring a quantity. Every logged metric is passed to the
+    :class:`~lightning.pytorch.loggers.logger.Logger` for the version it gets; the checkpoint is saved in the same
+    directory.
+
+    After training finishes, use :attr:`best_model_path` to retrieve the path to the
+    best checkpoint file and :attr:`best_model_score` to get its score.
+
+    .. note::
+        When using manual optimization with ``every_n_train_steps``, you should save the model state
+        in your ``training_step`` before the optimizer step if you want the checkpoint to reflect
+        the pre-optimization state. Example:
+
+        .. code-block:: python
+
+            def training_step(self, batch, batch_idx):
+                # ... forward pass, loss calculation, backward pass ...
+
+                # Save model state before optimization
+                if not hasattr(self, 'saved_models'):
+                    self.saved_models = {}
+                self.saved_models[batch_idx] = {
+                    k: v.detach().clone()
+                    for k, v in self.layer.state_dict().items()
+                }
+
+                # Then perform optimization
+                optimizer.zero_grad()
+                self.manual_backward(loss)
+                optimizer.step()
+
+                # Optional: Clean up old states to save memory
+                if batch_idx > 10:  # Keep last 10 states
+                    del self.saved_models[batch_idx - 10]
+
+    Args:
+        dirpath: Directory to save the model file.
+            Example: ``dirpath='my/path/'``.
+
+            .. warning::
+                In a distributed environment like DDP, it's recommended to provide a `dirpath` to avoid race
+                conditions. When using manual optimization with ``every_n_train_steps``, make sure to save the
+                model state in your training loop as shown in the example above.
+
+            Can be remote file paths such as `s3://mybucket/path/` or 'hdfs://path/'
+            (default: ``None``). If dirpath is ``None``, we only keep the ``k`` best checkpoints
+            in memory, and do not save anything to disk.
+
+        filename: Checkpoint filename. Can contain named formatting options to be auto-filled.
+            If no name is provided, it will be ``None`` and the checkpoint will be saved to
+            ``{epoch}``, and if the Trainer uses a logger, the path will also contain logger name and version.
+
+            Example::
+
+                # save any arbitrary metrics like `val_loss`, etc. in name
+                # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
+                >>> checkpoint_callback = ModelCheckpoint(
+                ...     dirpath='my/path',
+                ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
+                ... )
+
+            By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``, where "epoch" and "step" match
+            the number of finished epochs and optimizer steps respectively.
+        monitor: quantity to monitor. By default it is ``None``, which saves a checkpoint only for the last epoch.
+        verbose: verbosity mode. Default: ``False``.
+        save_last: When ``True``, saves a `last.ckpt` copy whenever a checkpoint file gets saved. Can be set to
+            ``'link'`` on a local filesystem to create a symbolic link. This allows accessing the latest checkpoint
+            in a deterministic manner. Default: ``None``.
+        save_top_k: if ``save_top_k == k``,
+            the best k models according to the quantity monitored will be saved.
+            If ``save_top_k == 0``, no models are saved.
+            If ``save_top_k == -1``, all models are saved.
+            Please note that the monitors are checked every ``every_n_epochs`` epochs.
+            If ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, and the filename remains
+            unchanged, the name of the saved file will be appended with a version count starting with ``v1`` to avoid
+            collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the
+            top-k ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric
+            to avoid collisions.
+        save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``False``.
+        mode: one of {min, max}.
+            If ``save_top_k != 0``, the decision to overwrite the current save file is made
+            based on either the maximization or the minimization of the monitored quantity.
+            For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
+        auto_insert_metric_name: When ``True``, the checkpoint filenames will contain the metric name.
+            For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}'`` with epoch ``1`` and acc ``1.12`` will
+            resolve to ``checkpoint_epoch=01-acc=01.ckpt``. It is useful to set this to ``False`` when metric names
+            contain ``/``, as this would otherwise result in extra folders.
+            For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False``
+        save_weights_only: if ``True``, then only the model's weights will be
+            saved. Otherwise, the optimizer states, lr-scheduler states, etc. are added to the checkpoint too.
+        every_n_train_steps: How many training steps to wait before saving a checkpoint. This does not take into
+            account the steps of the current epoch. If ``every_n_train_steps == None or every_n_train_steps == 0``,
+            no checkpoints will be saved during training.
+            Mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
+
+            .. note::
+                When using manual optimization, the checkpoint will be saved after the optimizer step by default.
+                To save the model state before the optimizer step, you need to save the model state in your
+                ``training_step`` before calling ``optimizer.step()``. See the class docstring for an example.
+        train_time_interval: Checkpoints are monitored at the specified time interval.
+            For all practical purposes, this cannot be smaller than the amount
+            of time it takes to process a single training batch. This is not
+            guaranteed to execute at the exact time specified, but should be close.
+            This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
+        every_n_epochs: Number of epochs between checkpoints.
+            This value must be ``None`` or non-negative.
+            To disable saving top-k checkpoints, set ``every_n_epochs = 0``.
+            This argument does not impact the saving of ``save_last=True`` checkpoints.
+            If all of ``every_n_epochs``, ``every_n_train_steps`` and
+            ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch
+            (equivalent to ``every_n_epochs = 1``).
+            If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or
+            ``train_time_interval != None``, saving at the end of each epoch is disabled
+            (equivalent to ``every_n_epochs = 0``).
+            This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
+            Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and
+            ``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
+            will only save checkpoints at epochs 0 < E <= N
+            where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
+        save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
+            If ``True``, checkpoints are saved at the end of every training epoch.
+            If ``False``, checkpoints are saved at the end of validation.
+            If ``None`` (default), checkpointing behavior is determined based on training configuration.
+            If ``val_check_interval`` is a str, dict, or `timedelta` (time-based), checkpointing is performed after
+            validation.
+            If ``check_val_every_n_epoch != 1``, checkpointing will not be performed at the end of
+            every training epoch. If there are no validation batches of data, checkpointing will occur at the
+            end of the training epoch. If there is a non-default number of validation runs per training epoch
+            (``val_check_interval != 1``), checkpointing is performed after validation.
+        enable_version_counter: Whether to append a version to the existing file name.
+            If ``False``, then the checkpoint files will be overwritten.
+
+    Note:
+        For extra customization, ModelCheckpoint includes the following attributes:
+
+        - ``CHECKPOINT_JOIN_CHAR = "-"``
+        - ``CHECKPOINT_EQUALS_CHAR = "="``
+        - ``CHECKPOINT_NAME_LAST = "last"``
+        - ``FILE_EXTENSION = ".ckpt"``
+        - ``STARTING_VERSION = 1``
+
+        For example, you can change the default last checkpoint name by doing
+        ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
+
+        If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
+        then you should create multiple ``ModelCheckpoint`` callbacks.
+
+        If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
+        only ``best_model_path`` will be reloaded and a warning will be issued.
+
+        If you provide a ``filename`` on a mounted device where changing permissions is not allowed (causing ``chmod``
+        to raise a ``PermissionError``), install `fsspec>=2025.5.0`. Then the error is caught, the file's permissions
+        remain unchanged, and the checkpoint is still saved. Otherwise, no checkpoint will be saved and training stops.
+
+    Raises:
+        MisconfigurationException:
+            If ``save_top_k`` is smaller than ``-1``,
+            if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or
+            if ``mode`` is none of ``"min"`` or ``"max"``.
+        ValueError:
+            If ``trainer.save_checkpoint`` is ``None``.
+
+    Example::
+
+        >>> from lightning.pytorch import Trainer
+        >>> from lightning.pytorch.callbacks import ModelCheckpoint
+
+        # saves checkpoints to 'my/path/' at every epoch
+        >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
+        >>> trainer = Trainer(callbacks=[checkpoint_callback])
+
+        # save epoch and val_loss in name
+        # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
+        >>> checkpoint_callback = ModelCheckpoint(
+        ...     monitor='val_loss',
+        ...     dirpath='my/path/',
+        ...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
+        ... )
+
+        # save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
+        # or Neptune, due to the presence of characters like '=' or '/')
+        # saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
+        >>> checkpoint_callback = ModelCheckpoint(
+        ...     monitor='val/loss',
+        ...     dirpath='my/path/',
+        ...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
+        ...     auto_insert_metric_name=False
+        ... )
+
+        # retrieve the best checkpoint after training
+        >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
+        >>> trainer = Trainer(callbacks=[checkpoint_callback])
+        >>> model = ...  # doctest: +SKIP
+        >>> trainer.fit(model)  # doctest: +SKIP
+        >>> print(checkpoint_callback.best_model_path)  # doctest: +SKIP
+
+    .. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the
+        following arguments:
+
+        *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval*
+
+        Read more: :ref:`Persisting Callback State <extensions/callbacks_state:save callback state>`
+
+    """  # noqa: D205
+
+    def __init__(  # noqa: D107
+        self,
+        dirpath: _PATH | None = None,
+        filename: str | None = None,
+        monitor: str | None = None,
+        verbose: bool = False,
+        save_last: bool | Literal["link"] | None = None,
+        save_top_k: int = 1,
+        save_on_exception: bool = False,
+        save_weights_only: bool = False,
+        mode: str = "min",
+        auto_insert_metric_name: bool = True,
+        every_n_train_steps: int | None = None,
+        train_time_interval: timedelta | None = None,
+        every_n_epochs: int | None = None,
+        save_on_train_epoch_end: bool | None = None,
+        enable_version_counter: bool = True,
+        registry_uploader_callback: RegistryUploaderCallback | None = None,
+    ) -> None:
+        super().__init__(
+            dirpath=dirpath,
+            filename=filename,
+            monitor=monitor,
+            verbose=verbose,
+            save_last=save_last,
+            save_top_k=save_top_k,
+            save_on_exception=save_on_exception,
+            save_weights_only=save_weights_only,
+            mode=mode,
+            auto_insert_metric_name=auto_insert_metric_name,
+            every_n_train_steps=every_n_train_steps,
+            train_time_interval=train_time_interval,
+            every_n_epochs=every_n_epochs,
+            save_on_train_epoch_end=save_on_train_epoch_end,
+            enable_version_counter=enable_version_counter,
+        )
+        self.registry_uploader_callback = registry_uploader_callback
+        self._custom_best_model_path = cast(str, self.best_model_path)
+        return
+
+    @property
+    def best_model_path(self) -> str:
+        """Best model path."""
+        return self._custom_best_model_path
+
+    @best_model_path.setter
+    def best_model_path(self, value: str) -> None:
+        self._custom_best_model_path = value
+        if self.registry_uploader_callback is not None:
+            self.registry_uploader_callback.best_model_path = value
+        return
+
+
 def setup_checkpoint_callback(
     dirpath: Path,
     ckpt_cfg: CheckpointConfig,
     save_weights_only: bool = True,
-
+    registry_uploader_callback: RegistryUploaderCallback | None = None,
+) -> CustomModelCheckpoint:
     """
     Sets up a ModelCheckpoint callback for PyTorch Lightning.
 
@@ -30,6 +298,8 @@ def setup_checkpoint_callback(
             settings such as filename, save_top_k, monitor, and mode.
         save_weights_only (bool, optional): Whether to save only the model weights
             or the full model. Defaults to True.
+        registry_uploader_callback (RegistryUploaderCallback | None, optional):
+            An optional callback for uploading checkpoints to a registry. Defaults to None.
 
     Returns:
         ModelCheckpoint: The configured ModelCheckpoint callback instance.
@@ -44,7 +314,7 @@ def setup_checkpoint_callback(
         logger.info(f"Creating checkpoint directory {dirpath}.")
         dirpath.mkdir(parents=True, exist_ok=True)
 
-    checkpoint_callback =
+    checkpoint_callback = CustomModelCheckpoint(
         dirpath=dirpath,
         filename=ckpt_cfg.filename,
         save_top_k=ckpt_cfg.save_top_k,
@@ -52,5 +322,6 @@ def setup_checkpoint_callback(
         mode=ckpt_cfg.mode,
         verbose=True,
        save_weights_only=save_weights_only,
+        registry_uploader_callback=registry_uploader_callback,
     )
     return checkpoint_callback
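
Taken together, `setup_checkpoint_callback` now returns a `CustomModelCheckpoint` whose overridden `best_model_path` setter forwards every new best path to the attached uploader. A rough usage sketch, assuming a ClearML task and a `CheckpointConfig` with the fields named in the docstring above ("filename, save_top_k, monitor, and mode"); the exact constructor signatures may differ:

```python
from pathlib import Path

from clearml import Task

from kostyl.ml.configs import CheckpointConfig
from kostyl.ml.lightning.callbacks.checkpoint import setup_checkpoint_callback
from kostyl.ml.lightning.callbacks.registry_uploader import ClearMLRegistryUploaderCallback

task = Task.init(project_name="demo", task_name="train")  # hypothetical ClearML task

# Assumed CheckpointConfig fields; the real config class may take more or other arguments.
ckpt_cfg = CheckpointConfig(
    filename="{epoch}-{val_loss:.2f}",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

uploader = ClearMLRegistryUploaderCallback(
    task=task,
    output_model_name="my-model",  # hypothetical registry entry name
    # remaining keyword arguments (tags, config_dict, ...) omitted; defaults assumed
)

# The checkpoint callback pushes each new best_model_path into the uploader
# through its overridden property setter.
checkpoint_cb = setup_checkpoint_callback(
    dirpath=Path("checkpoints"),
    ckpt_cfg=ckpt_cfg,
    registry_uploader_callback=uploader,
)

# Both callbacks are then passed to the Trainer:
#   trainer = Trainer(callbacks=[checkpoint_cb, uploader], ...)
```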
kostyl/ml/lightning/callbacks/{registry_uploading.py → registry_uploader.py}

@@ -1,3 +1,7 @@
+from abc import ABC
+from abc import abstractmethod
+from collections.abc import Callable
+from functools import partial
 from typing import Literal
 from typing import override
 
@@ -5,7 +9,6 @@ from clearml import OutputModel
 from clearml import Task
 from lightning import Trainer
 from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.callbacks import ModelCheckpoint
 
 from kostyl.ml.clearml.logging_utils import find_version_in_tags
 from kostyl.ml.clearml.logging_utils import increment_version
@@ -16,13 +19,31 @@ from kostyl.utils.logging import setup_logger
 logger = setup_logger()
 
 
-class
+class RegistryUploaderCallback(Callback, ABC):
+    """Abstract Lightning callback responsible for tracking and uploading the best-performing model checkpoint."""
+
+    @property
+    @abstractmethod
+    def best_model_path(self) -> str:
+        """Return the file system path pointing to the best model artifact produced during training."""
+        raise NotImplementedError
+
+    @best_model_path.setter
+    @abstractmethod
+    def best_model_path(self, value: str) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
+        raise NotImplementedError
+
+
+class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
     """PyTorch Lightning callback to upload the best model checkpoint to ClearML."""
 
     def __init__(
         self,
         task: Task,
-        ckpt_callback: ModelCheckpoint,
         output_model_name: str,
         output_model_tags: list[str] | None = None,
         verbose: bool = True,
@@ -56,7 +77,6 @@ class ClearMLRegistryUploaderCallback(Callback):
             output_model_tags = []
 
         self.task = task
-        self.ckpt_callback = ckpt_callback
         self.output_model_name = output_model_name
         self.output_model_tags = output_model_tags
         self.config_dict = config_dict
@@ -66,7 +86,22 @@ class ClearMLRegistryUploaderCallback(Callback):
         self.enable_tag_versioning = enable_tag_versioning
 
         self._output_model: OutputModel | None = None
-        self.
+        self._last_uploaded_model_path: str = ""
+        self._best_model_path: str = ""
+        self._upload_callback: Callable | None = None
+        return
+
+    @property
+    @override
+    def best_model_path(self) -> str:
+        return self._best_model_path
+
+    @best_model_path.setter
+    @override
+    def best_model_path(self, value: str) -> None:
+        self._best_model_path = value
+        if self._upload_callback is not None:
+            self._upload_callback()
         return
 
     def _create_output_model(self, pl_module: "KostylLightningModule") -> OutputModel:
@@ -98,27 +133,29 @@ class ClearMLRegistryUploaderCallback(Callback):
             label_enumeration=self.label_enumeration,
         )
 
+    @override
     def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
-
-
-
-        if self.
-
-
-
-
-
+        if not self._best_model_path or (
+            self._best_model_path == self._last_uploaded_model_path
+        ):
+            if not self._best_model_path:
+                if self.verbose:
+                    logger.info("No best model found yet to upload")
+            elif self._best_model_path == self._last_uploaded_model_path:
+                if self.verbose:
+                    logger.info("Best model unchanged since last upload")
+            self._upload_callback = partial(self._upload_best_checkpoint, pl_module)
             return
+        self._upload_callback = None
 
         if self._output_model is None:
            self._output_model = self._create_output_model(pl_module)
 
         if self.verbose:
-            logger.info(f"Uploading best model from {
+            logger.info(f"Uploading best model from {self._best_model_path}")
 
         self._output_model.update_weights(
-
+            self._best_model_path,
             auto_delete_file=False,
             async_enable=False,
         )
@@ -130,7 +167,7 @@ class ClearMLRegistryUploaderCallback(Callback):
             config = self.config_dict
         self._output_model.update_design(config_dict=config)
 
-        self.
+        self._last_uploaded_model_path = self._best_model_path
         return
 
     @override
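
The interesting piece above is the retry machinery: when `_upload_best_checkpoint` finds nothing new to upload, it parks a `functools.partial` of itself in `_upload_callback`, and the `best_model_path` setter fires that pending upload as soon as the checkpoint callback reports a new best path. A stripped-down, self-contained sketch of the same pattern (names simplified, ClearML calls replaced by a print):

```python
from collections.abc import Callable
from functools import partial


class DeferredUploader:
    """Toy model of the skip-then-retry upload logic in ClearMLRegistryUploaderCallback."""

    def __init__(self) -> None:
        self._best_model_path = ""
        self._last_uploaded_model_path = ""
        self._upload_callback: Callable[[], None] | None = None

    @property
    def best_model_path(self) -> str:
        return self._best_model_path

    @best_model_path.setter
    def best_model_path(self, value: str) -> None:
        self._best_model_path = value
        # If an earlier upload was skipped, re-run it now that the path changed.
        if self._upload_callback is not None:
            self._upload_callback()

    def upload_best_checkpoint(self) -> None:
        if not self._best_model_path or self._best_model_path == self._last_uploaded_model_path:
            # Nothing new to upload: remember to retry once best_model_path changes.
            self._upload_callback = partial(self.upload_best_checkpoint)
            return
        self._upload_callback = None
        print(f"uploading {self._best_model_path}")  # stands in for OutputModel.update_weights(...)
        self._last_uploaded_model_path = self._best_model_path


uploader = DeferredUploader()
uploader.upload_best_checkpoint()          # skipped: no best model yet, retry armed
uploader.best_model_path = "best-v1.ckpt"  # setter fires the pending upload
```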
kostyl/ml/lightning/extenstions/pretrained_model.py

@@ -6,15 +6,7 @@ import torch
 from transformers import PretrainedConfig
 from transformers import PreTrainedModel
 
-
-try:
-    from peft import PeftConfig
-except ImportError:
-    PeftConfig = None  # ty: ignore
-
-from kostyl.utils.logging import log_incompatible_keys
 from kostyl.utils.logging import setup_logger
-from torch import nn
 
 
 logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
@@ -24,12 +16,11 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
     """A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""
 
     @classmethod
-    def
+    def from_lightning_checkpoint[TModelInstance: LightningCheckpointLoaderMixin](  # noqa: C901
         cls: type[TModelInstance],
         checkpoint_path: str | Path,
         config_key: str = "config",
         weights_prefix: str = "model.",
-        should_log_incompatible_keys: bool = True,
         **kwargs: Any,
     ) -> TModelInstance:
         """
@@ -50,8 +41,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
                 Defaults to "config".
             weights_prefix (str, optional): Prefix to strip from state dict keys. Defaults to "model.".
                 If not empty and doesn't end with ".", a "." is appended.
-
-            **kwargs: Additional keyword arguments to pass to the model loading method.
+            kwargs: Additional keyword arguments to pass to the model's `from_pretrained` method.
 
         Returns:
             TModelInstance: The loaded model instance.
@@ -89,114 +79,27 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
             if not hasattr(config, key):
                 kwargs_for_model[key] = value
 
-        with torch.device("meta"):
-            model = cls(config, **kwargs_for_model)
-
-        # PEFT adapters (keeping your logic as is)
-        if "peft_config" in checkpoint_dict:
-            if PeftConfig is None:
-                raise ImportError(
-                    "peft is not installed. Please install it to load PEFT models."
-                )
-            for name, adapter_dict in checkpoint_dict["peft_config"].items():
-                peft_cfg = PeftConfig.from_peft_type(**adapter_dict)
-                model.add_adapter(peft_cfg, adapter_name=name)
-
-        incompatible_keys: dict[str, list[str]] = {}
-
         raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
 
         if weights_prefix:
             if not weights_prefix.endswith("."):
                 weights_prefix = weights_prefix + "."
             state_dict: dict[str, torch.Tensor] = {}
-            mismatched_keys: list[str] = []
 
             for key, value in raw_state_dict.items():
                 if key.startswith(weights_prefix):
                     new_key = key[len(weights_prefix) :]
                     state_dict[new_key] = value
                 else:
-
-
-            if mismatched_keys:
-                incompatible_keys["mismatched_keys"] = mismatched_keys
+                    state_dict[key] = value
         else:
             state_dict = raw_state_dict
 
-
-
-
-
-
-        model_to_load: nn.Module = model
-
-        if base_prefix:
-            prefix_with_dot = base_prefix + "."
-            loaded_keys = list(state_dict.keys())
-            full_model_state = model.state_dict()
-            expected_keys = list(full_model_state.keys())
-
-            has_prefix_module = any(k.startswith(prefix_with_dot) for k in loaded_keys)
-            expects_prefix_module = any(
-                k.startswith(prefix_with_dot) for k in expected_keys
-            )
-
-            # Case 1: loading a base model into a model with a head.
-            # Example: StaticEmbeddingsForSequenceClassification (has .model)
-            # state_dict with keys "embeddings.weight", "token_pos_weights", ...
-            if (
-                hasattr(model, base_prefix)
-                and not has_prefix_module
-                and expects_prefix_module
-            ):
-                # Weights without a prefix -> load only into model.<base_prefix>
-                model_to_load = getattr(model, base_prefix)
-
-            # Case 2: loading a checkpoint of a model with a head into the base model.
-            # Example: BertModel, while the state_dict has keys "bert.encoder.layer.0..."
-            elif (
-                not hasattr(model, base_prefix)
-                and has_prefix_module
-                and not expects_prefix_module
-            ):
-                new_state_dict: dict[str, torch.Tensor] = {}
-                for key, value in state_dict.items():
-                    if key.startswith(prefix_with_dot):
-                        new_key = key[len(prefix_with_dot) :]
-                    else:
-                        new_key = key
-                    new_state_dict[new_key] = value
-                state_dict = new_state_dict
-
-        load_result = model_to_load.load_state_dict(
-            state_dict, strict=False, assign=True
+        model = cls.from_pretrained(
+            pretrained_model_name_or_path=None,
+            config=config,
+            state_dict=state_dict,
+            **kwargs_for_model,
         )
-        missing_keys, unexpected_keys = (
-            load_result.missing_keys,
-            load_result.unexpected_keys,
-        )
-
-        # If we loaded only into the base submodule, extend missing_keys
-        # to the full list (base + head), as older HF versions did.
-        if model_to_load is not model and base_prefix:
-            base_keys = set(model_to_load.state_dict().keys())
-            # Convert the full model's keys to a "prefix-free" form
-            head_like_keys = set()
-            prefix_with_dot = base_prefix + "."
-            for k in model.state_dict().keys():
-                if k.startswith(prefix_with_dot):
-                    # strip "model."
-                    head_like_keys.add(k[len(prefix_with_dot) :])
-                else:
-                    head_like_keys.add(k)
-            extra_missing = sorted(head_like_keys - base_keys)
-            missing_keys = list(missing_keys) + extra_missing
-
-        incompatible_keys["missing_keys"] = missing_keys
-        incompatible_keys["unexpected_keys"] = unexpected_keys
-
-        if should_log_incompatible_keys:
-            log_incompatible_keys(incompatible_keys=incompatible_keys, logger=logger)
 
         return model
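
The rewrite drops the hand-rolled meta-device construction, base-prefix juggling, and incompatible-key bookkeeping: `from_lightning_checkpoint` now just strips the Lightning wrapper prefix from the checkpoint's `state_dict` and delegates to Hugging Face's `from_pretrained`, which performs the actual weight loading and key matching. A standalone sketch of the prefix handling with toy tensors and made-up key names:

```python
import torch

# Toy Lightning state dict: the LightningModule stores the HF model under "model."
raw_state_dict = {
    "model.embeddings.weight": torch.zeros(2, 4),
    "model.classifier.bias": torch.zeros(2),
    "extra_buffer": torch.zeros(1),  # keys without the prefix are kept as-is
}

weights_prefix = "model."  # default value of the weights_prefix argument
state_dict: dict[str, torch.Tensor] = {}
for key, value in raw_state_dict.items():
    if key.startswith(weights_prefix):
        state_dict[key[len(weights_prefix):]] = value
    else:
        state_dict[key] = value

print(sorted(state_dict))
# ['classifier.bias', 'embeddings.weight', 'extra_buffer']
```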
{kostyl_toolkit-0.1.22.dist-info → kostyl_toolkit-0.1.24.dist-info}/RECORD

@@ -3,20 +3,20 @@ kostyl/ml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 kostyl/ml/clearml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 kostyl/ml/clearml/dataset_utils.py,sha256=eij_sr2KDhm8GxEbVbK8aBjPsuVvLl9-PIGGaKVgXLA,1729
 kostyl/ml/clearml/logging_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWaJTyq8,1210
-kostyl/ml/clearml/pulling_utils.py,sha256=
+kostyl/ml/clearml/pulling_utils.py,sha256=cNa_-_5LHjNVYi9btXBrfl5sPvI6BAAlIFidtpKu310,4078
 kostyl/ml/configs/__init__.py,sha256=IetcivbqYGutowLqxdKp7QR4tkXKBr4m8t4Zkk9jHZU,911
 kostyl/ml/configs/base_model.py,sha256=Eofn14J9RsjpVx_J4rp6C19pDDCANU4hr3JtX-d0FpQ,4820
 kostyl/ml/configs/hyperparams.py,sha256=OqN7mEj3zc5MTqBPCZL3Lcd2VCTDLo_K0yvhRWGfhCs,2924
 kostyl/ml/configs/training_settings.py,sha256=Sq2tiRuwkbmi9zKDG2JghZLXo5DDt_eQqN_KYJSdcTY,2509
 kostyl/ml/dist_utils.py,sha256=Onf0KHVLA8oeUgZTcTdmR9qiM22f2uYLoNwgLbMGJWk,3495
 kostyl/ml/lightning/__init__.py,sha256=-F3JAyq8KU1d-nACWryGu8d1CbvWbQ1rXFdeRwfE2X8,175
-kostyl/ml/lightning/callbacks/__init__.py,sha256=
-kostyl/ml/lightning/callbacks/checkpoint.py,sha256=
+kostyl/ml/lightning/callbacks/__init__.py,sha256=enexQt3octktsTiEYHltSF_24CM-NeFEVFimXiavGiY,296
+kostyl/ml/lightning/callbacks/checkpoint.py,sha256=rJ05S7BnPopvQjV5b9CI29-S7-ySllyhG7PIzz34VyY,16793
 kostyl/ml/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
-kostyl/ml/lightning/callbacks/
+kostyl/ml/lightning/callbacks/registry_uploader.py,sha256=ksoh02dzIde4E_GaZykfiOgfSjZti-IJt_i61enem3s,6779
 kostyl/ml/lightning/extenstions/__init__.py,sha256=OY6QGv1agYgqqKf1xJBrxgp_i8FunVfPzYezfaRrGXU,182
 kostyl/ml/lightning/extenstions/custom_module.py,sha256=nB5jW7cqRD1tyh-q5LD2EtiFQwFkLXpnS9Yu6c5xMRg,5987
-kostyl/ml/lightning/extenstions/pretrained_model.py,sha256=
+kostyl/ml/lightning/extenstions/pretrained_model.py,sha256=QJGr2UvYJcU2Gy2w8z_cEvTodjv7hGdd2PPPfdOI-Mw,4017
 kostyl/ml/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
 kostyl/ml/lightning/loggers/tb_logger.py,sha256=j02HK5ue8yzXXV8FWKmmXyHkFlIxgHx-ahHWk_rFCZs,893
 kostyl/ml/lightning/steps_estimation.py,sha256=fTZ0IrUEZV3H6VYlx4GYn56oco56mMiB7FO9F0Z7qc4,1511
@@ -30,6 +30,6 @@ kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
 kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
 kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
 kostyl/utils/logging.py,sha256=Vye0u4-yeOSUc-f03gpQbxSktTbFiilTWLEVr00ZHvc,5796
-kostyl_toolkit-0.1.
-kostyl_toolkit-0.1.
-kostyl_toolkit-0.1.
+kostyl_toolkit-0.1.24.dist-info/WHEEL,sha256=z-mOpxbJHqy3cq6SvUThBZdaLGFZzdZPtgWLcP2NKjQ,79
+kostyl_toolkit-0.1.24.dist-info/METADATA,sha256=uq8MPJ9vJgWsp9Z2c7C9tcbaH29QM9ux7_SyahPSlHE,4269
+kostyl_toolkit-0.1.24.dist-info/RECORD,,

{kostyl_toolkit-0.1.22.dist-info → kostyl_toolkit-0.1.24.dist-info}/WHEEL: File without changes.