kostyl-toolkit 0.1.22__tar.gz → 0.1.23__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/PKG-INFO +1 -1
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/callbacks/__init__.py +1 -1
- kostyl_toolkit-0.1.23/kostyl/ml/lightning/callbacks/checkpoint.py +327 -0
- kostyl_toolkit-0.1.22/kostyl/ml/lightning/callbacks/registry_uploading.py → kostyl_toolkit-0.1.23/kostyl/ml/lightning/callbacks/registry_uploader.py +55 -18
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/pyproject.toml +1 -1
- kostyl_toolkit-0.1.22/kostyl/ml/lightning/callbacks/checkpoint.py +0 -56
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/README.md +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/__init__.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/__init__.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/clearml/__init__.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/clearml/dataset_utils.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/clearml/logging_utils.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/clearml/pulling_utils.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/configs/__init__.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/configs/base_model.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/configs/hyperparams.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/configs/training_settings.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/dist_utils.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/__init__.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/extenstions/__init__.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/extenstions/custom_module.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/extenstions/pretrained_model.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/loggers/__init__.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/loggers/tb_logger.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/steps_estimation.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/metrics_formatting.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/params_groups.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/schedulers/__init__.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/schedulers/base.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/schedulers/composite.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/schedulers/cosine.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/utils/__init__.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/utils/dict_manipulations.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/utils/fs.py +0 -0
- {kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/utils/logging.py +0 -0
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
from datetime import timedelta
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from shutil import rmtree
|
|
4
|
+
from typing import Literal
|
|
5
|
+
from typing import cast
|
|
6
|
+
|
|
7
|
+
from lightning.fabric.utilities.types import _PATH
|
|
8
|
+
from lightning.pytorch.callbacks import ModelCheckpoint
|
|
9
|
+
|
|
10
|
+
from kostyl.ml.configs import CheckpointConfig
|
|
11
|
+
from kostyl.ml.dist_utils import is_main_process
|
|
12
|
+
from kostyl.utils import setup_logger
|
|
13
|
+
|
|
14
|
+
from .registry_uploader import RegistryUploaderCallback
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
logger = setup_logger("callbacks/checkpoint.py")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class CustomModelCheckpoint(ModelCheckpoint):
    r"""
    :class:`~lightning.pytorch.callbacks.ModelCheckpoint` variant that keeps an
    optional :class:`RegistryUploaderCallback` in sync with the best checkpoint.

    ``best_model_path`` is re-exposed as a property backed by
    ``_custom_best_model_path``. Whenever the parent class assigns a new best
    path (after saving a better checkpoint), the property setter forwards the
    value to ``registry_uploader_callback`` so the uploader always knows which
    checkpoint is currently the best one.

    All constructor arguments except ``registry_uploader_callback`` are
    forwarded unchanged to :class:`~lightning.pytorch.callbacks.ModelCheckpoint`;
    see the upstream Lightning documentation for their exact semantics.

    Args:
        dirpath: Directory to save the model file (see upstream docs).
        filename: Checkpoint filename; can contain formatting options.
        monitor: Quantity to monitor; ``None`` saves only the last epoch.
        verbose: Verbosity mode. Default: ``False``.
        save_last: When ``True``, also saves a ``last.ckpt`` copy.
        save_top_k: Number of best models to keep (``-1`` keeps all).
        save_on_exception: Whether to save a checkpoint on exceptions.
        save_weights_only: If ``True``, save only the model's weights.
        mode: ``"min"`` or ``"max"`` — direction of the monitored quantity.
        auto_insert_metric_name: Include the metric name in the filename.
        every_n_train_steps: Save every N training steps.
        train_time_interval: Save at the given wall-clock interval.
        every_n_epochs: Number of epochs between checkpoints.
        save_on_train_epoch_end: Whether to checkpoint at train-epoch end.
        enable_version_counter: Append a version to colliding file names.
        registry_uploader_callback: Optional uploader that is notified every
            time the best checkpoint path changes. Defaults to ``None``.

    """  # noqa: D205

    def __init__(  # noqa: D107
        self,
        dirpath: _PATH | None = None,
        filename: str | None = None,
        monitor: str | None = None,
        verbose: bool = False,
        save_last: bool | Literal["link"] | None = None,
        save_top_k: int = 1,
        save_on_exception: bool = False,
        save_weights_only: bool = False,
        mode: str = "min",
        auto_insert_metric_name: bool = True,
        every_n_train_steps: int | None = None,
        train_time_interval: timedelta | None = None,
        every_n_epochs: int | None = None,
        save_on_train_epoch_end: bool | None = None,
        enable_version_counter: bool = True,
        registry_uploader_callback: RegistryUploaderCallback | None = None,
    ) -> None:
        # NOTE(bugfix): these two attributes MUST exist before the parent
        # constructor runs. ``ModelCheckpoint.__init__`` assigns
        # ``self.best_model_path``, which routes through the property setter
        # below; with the original ordering (attributes assigned after
        # ``super().__init__``) the setter read
        # ``self.registry_uploader_callback`` before it existed and raised
        # AttributeError at construction time.
        self.registry_uploader_callback = registry_uploader_callback
        self._custom_best_model_path: str = ""
        super().__init__(
            dirpath=dirpath,
            filename=filename,
            monitor=monitor,
            verbose=verbose,
            save_last=save_last,
            save_top_k=save_top_k,
            save_on_exception=save_on_exception,
            save_weights_only=save_weights_only,
            mode=mode,
            auto_insert_metric_name=auto_insert_metric_name,
            every_n_train_steps=every_n_train_steps,
            train_time_interval=train_time_interval,
            every_n_epochs=every_n_epochs,
            save_on_train_epoch_end=save_on_train_epoch_end,
            enable_version_counter=enable_version_counter,
        )
        return

    @property
    def best_model_path(self) -> str:
        """Path of the best checkpoint saved so far (empty string if none)."""
        return self._custom_best_model_path

    @best_model_path.setter
    def best_model_path(self, value: str) -> None:
        self._custom_best_model_path = value
        # Mirror the new best path into the uploader so it can (re-)upload
        # the current best model when it decides to.
        if self.registry_uploader_callback is not None:
            self.registry_uploader_callback.best_model_path = value
        return
|
|
281
|
+
def setup_checkpoint_callback(
    dirpath: Path,
    ckpt_cfg: CheckpointConfig,
    save_weights_only: bool = True,
    registry_uploader_callback: RegistryUploaderCallback | None = None,
) -> CustomModelCheckpoint:
    """
    Build a :class:`CustomModelCheckpoint` callback over a clean directory.

    Prepares the checkpoint directory and configures the callback from
    ``ckpt_cfg``. A pre-existing directory is wiped (by the main process only)
    so training starts from a clean slate; a missing one is created.

    Args:
        dirpath (Path): The path to the directory where checkpoints will be saved.
        ckpt_cfg (CheckpointConfig): Configuration object containing checkpoint
            settings such as filename, save_top_k, monitor, and mode.
        save_weights_only (bool, optional): Whether to save only the model weights
            or the full model. Defaults to True.
        registry_uploader_callback (RegistryUploaderCallback | None, optional):
            An optional callback for uploading checkpoints to a registry. Defaults to None.

    Returns:
        ModelCheckpoint: The configured ModelCheckpoint callback instance.

    """
    if not dirpath.exists():
        logger.info(f"Creating checkpoint directory {dirpath}.")
        dirpath.mkdir(parents=True, exist_ok=True)
    elif is_main_process():
        # Only the main process removes the stale directory to avoid
        # rank-race conditions in distributed runs.
        logger.warning(f"Checkpoint directory {dirpath} already exists.")
        rmtree(dirpath)
        logger.warning(f"Removed existing checkpoint directory {dirpath}.")

    callback = CustomModelCheckpoint(
        dirpath=dirpath,
        filename=ckpt_cfg.filename,
        save_top_k=ckpt_cfg.save_top_k,
        monitor=ckpt_cfg.monitor,
        mode=ckpt_cfg.mode,
        verbose=True,
        save_weights_only=save_weights_only,
        registry_uploader_callback=registry_uploader_callback,
    )
    return callback
|
@@ -1,3 +1,7 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from functools import partial
|
|
1
5
|
from typing import Literal
|
|
2
6
|
from typing import override
|
|
3
7
|
|
|
@@ -5,7 +9,6 @@ from clearml import OutputModel
|
|
|
5
9
|
from clearml import Task
|
|
6
10
|
from lightning import Trainer
|
|
7
11
|
from lightning.pytorch.callbacks import Callback
|
|
8
|
-
from lightning.pytorch.callbacks import ModelCheckpoint
|
|
9
12
|
|
|
10
13
|
from kostyl.ml.clearml.logging_utils import find_version_in_tags
|
|
11
14
|
from kostyl.ml.clearml.logging_utils import increment_version
|
|
@@ -16,13 +19,31 @@ from kostyl.utils.logging import setup_logger
|
|
|
16
19
|
logger = setup_logger()
|
|
17
20
|
|
|
18
21
|
|
|
19
|
-
class
|
|
22
|
+
class RegistryUploaderCallback(Callback, ABC):
    """Abstract Lightning callback responsible for tracking and uploading the best-performing model checkpoint."""

    @property
    @abstractmethod
    def best_model_path(self) -> str:
        """Return the file system path pointing to the best model artifact produced during training."""
        raise NotImplementedError

    @best_model_path.setter
    @abstractmethod
    def best_model_path(self, value: str) -> None:
        """Record ``value`` as the current best checkpoint path.

        Implementations may react to the update (e.g. schedule or trigger
        an upload of the new best checkpoint).
        """
        raise NotImplementedError

    @abstractmethod
    def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
        """Upload the checkpoint at ``best_model_path`` to the model registry.

        Args:
            pl_module: The Lightning module being trained; presumably used by
                implementations to derive model metadata when registering —
                confirm against concrete subclasses.

        """
        raise NotImplementedError
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
|
|
20
42
|
"""PyTorch Lightning callback to upload the best model checkpoint to ClearML."""
|
|
21
43
|
|
|
22
44
|
def __init__(
|
|
23
45
|
self,
|
|
24
46
|
task: Task,
|
|
25
|
-
ckpt_callback: ModelCheckpoint,
|
|
26
47
|
output_model_name: str,
|
|
27
48
|
output_model_tags: list[str] | None = None,
|
|
28
49
|
verbose: bool = True,
|
|
@@ -56,7 +77,6 @@ class ClearMLRegistryUploaderCallback(Callback):
|
|
|
56
77
|
output_model_tags = []
|
|
57
78
|
|
|
58
79
|
self.task = task
|
|
59
|
-
self.ckpt_callback = ckpt_callback
|
|
60
80
|
self.output_model_name = output_model_name
|
|
61
81
|
self.output_model_tags = output_model_tags
|
|
62
82
|
self.config_dict = config_dict
|
|
@@ -66,7 +86,22 @@ class ClearMLRegistryUploaderCallback(Callback):
|
|
|
66
86
|
self.enable_tag_versioning = enable_tag_versioning
|
|
67
87
|
|
|
68
88
|
self._output_model: OutputModel | None = None
|
|
69
|
-
self.
|
|
89
|
+
self._last_uploaded_model_path: str = ""
|
|
90
|
+
self._best_model_path: str = ""
|
|
91
|
+
self._upload_callback: Callable | None = None
|
|
92
|
+
return
|
|
93
|
+
|
|
94
|
+
@property
|
|
95
|
+
@override
|
|
96
|
+
def best_model_path(self) -> str:
|
|
97
|
+
return self._best_model_path
|
|
98
|
+
|
|
99
|
+
@best_model_path.setter
|
|
100
|
+
@override
|
|
101
|
+
def best_model_path(self, value: str) -> None:
|
|
102
|
+
self._best_model_path = value
|
|
103
|
+
if self._upload_callback is not None:
|
|
104
|
+
self._upload_callback()
|
|
70
105
|
return
|
|
71
106
|
|
|
72
107
|
def _create_output_model(self, pl_module: "KostylLightningModule") -> OutputModel:
|
|
@@ -98,27 +133,29 @@ class ClearMLRegistryUploaderCallback(Callback):
|
|
|
98
133
|
label_enumeration=self.label_enumeration,
|
|
99
134
|
)
|
|
100
135
|
|
|
136
|
+
@override
|
|
101
137
|
def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
if self.
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
138
|
+
if not self._best_model_path or (
|
|
139
|
+
self._best_model_path == self._last_uploaded_model_path
|
|
140
|
+
):
|
|
141
|
+
if not self._best_model_path:
|
|
142
|
+
if self.verbose:
|
|
143
|
+
logger.info("No best model found yet to upload")
|
|
144
|
+
elif self._best_model_path == self._last_uploaded_model_path:
|
|
145
|
+
if self.verbose:
|
|
146
|
+
logger.info("Best model unchanged since last upload")
|
|
147
|
+
self._upload_callback = partial(self._upload_best_checkpoint, pl_module)
|
|
112
148
|
return
|
|
149
|
+
self._upload_callback = None
|
|
113
150
|
|
|
114
151
|
if self._output_model is None:
|
|
115
152
|
self._output_model = self._create_output_model(pl_module)
|
|
116
153
|
|
|
117
154
|
if self.verbose:
|
|
118
|
-
logger.info(f"Uploading best model from {
|
|
155
|
+
logger.info(f"Uploading best model from {self._best_model_path}")
|
|
119
156
|
|
|
120
157
|
self._output_model.update_weights(
|
|
121
|
-
|
|
158
|
+
self._best_model_path,
|
|
122
159
|
auto_delete_file=False,
|
|
123
160
|
async_enable=False,
|
|
124
161
|
)
|
|
@@ -130,7 +167,7 @@ class ClearMLRegistryUploaderCallback(Callback):
|
|
|
130
167
|
config = self.config_dict
|
|
131
168
|
self._output_model.update_design(config_dict=config)
|
|
132
169
|
|
|
133
|
-
self.
|
|
170
|
+
self._last_uploaded_model_path = self._best_model_path
|
|
134
171
|
return
|
|
135
172
|
|
|
136
173
|
@override
|
|
@@ -1,56 +0,0 @@
|
|
|
1
|
-
from pathlib import Path
|
|
2
|
-
from shutil import rmtree
|
|
3
|
-
|
|
4
|
-
from lightning.pytorch.callbacks import ModelCheckpoint
|
|
5
|
-
|
|
6
|
-
from kostyl.ml.configs import CheckpointConfig
|
|
7
|
-
from kostyl.ml.dist_utils import is_main_process
|
|
8
|
-
from kostyl.utils import setup_logger
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
logger = setup_logger("callbacks/checkpoint.py")
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def setup_checkpoint_callback(
|
|
15
|
-
dirpath: Path,
|
|
16
|
-
ckpt_cfg: CheckpointConfig,
|
|
17
|
-
save_weights_only: bool = True,
|
|
18
|
-
) -> ModelCheckpoint:
|
|
19
|
-
"""
|
|
20
|
-
Sets up a ModelCheckpoint callback for PyTorch Lightning.
|
|
21
|
-
|
|
22
|
-
This function prepares a checkpoint directory and configures a ModelCheckpoint
|
|
23
|
-
callback based on the provided configuration. If the directory already exists,
|
|
24
|
-
it is removed (only by the main process) to ensure a clean start. Otherwise,
|
|
25
|
-
the directory is created.
|
|
26
|
-
|
|
27
|
-
Args:
|
|
28
|
-
dirpath (Path): The path to the directory where checkpoints will be saved.
|
|
29
|
-
ckpt_cfg (CheckpointConfig): Configuration object containing checkpoint
|
|
30
|
-
settings such as filename, save_top_k, monitor, and mode.
|
|
31
|
-
save_weights_only (bool, optional): Whether to save only the model weights
|
|
32
|
-
or the full model. Defaults to True.
|
|
33
|
-
|
|
34
|
-
Returns:
|
|
35
|
-
ModelCheckpoint: The configured ModelCheckpoint callback instance.
|
|
36
|
-
|
|
37
|
-
"""
|
|
38
|
-
if dirpath.exists():
|
|
39
|
-
if is_main_process():
|
|
40
|
-
logger.warning(f"Checkpoint directory {dirpath} already exists.")
|
|
41
|
-
rmtree(dirpath)
|
|
42
|
-
logger.warning(f"Removed existing checkpoint directory {dirpath}.")
|
|
43
|
-
else:
|
|
44
|
-
logger.info(f"Creating checkpoint directory {dirpath}.")
|
|
45
|
-
dirpath.mkdir(parents=True, exist_ok=True)
|
|
46
|
-
|
|
47
|
-
checkpoint_callback = ModelCheckpoint(
|
|
48
|
-
dirpath=dirpath,
|
|
49
|
-
filename=ckpt_cfg.filename,
|
|
50
|
-
save_top_k=ckpt_cfg.save_top_k,
|
|
51
|
-
monitor=ckpt_cfg.monitor,
|
|
52
|
-
mode=ckpt_cfg.mode,
|
|
53
|
-
verbose=True,
|
|
54
|
-
save_weights_only=save_weights_only,
|
|
55
|
-
)
|
|
56
|
-
return checkpoint_callback
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/callbacks/early_stopping.py
RENAMED
|
File without changes
|
|
File without changes
|
{kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/extenstions/custom_module.py
RENAMED
|
File without changes
|
{kostyl_toolkit-0.1.22 → kostyl_toolkit-0.1.23}/kostyl/ml/lightning/extenstions/pretrained_model.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|