kostyl-toolkit 0.1.22-py3-none-any.whl → 0.1.24-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -95,7 +95,7 @@ def get_model_from_clearml[
             raise ValueError(
                 f"Model class {model.__name__} is not compatible with Lightning checkpoints."
             )
-        model_instance = model.from_lighting_checkpoint(local_path, **kwargs)
+        model_instance = model.from_lightning_checkpoint(local_path, **kwargs)
     else:
         raise ValueError(
             f"Unsupported model format for path: {local_path}. "
@@ -1,6 +1,6 @@
 from .checkpoint import setup_checkpoint_callback
 from .early_stopping import setup_early_stopping_callback
-from .registry_uploading import ClearMLRegistryUploaderCallback
+from .registry_uploader import ClearMLRegistryUploaderCallback
 
 
 __all__ = [
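
Note (not part of the published diff): the rename only touches the internal module file name; the callbacks package still re-exports the class, so existing user code should keep working. A minimal sketch of the unchanged public import, assuming the package layout listed in RECORD below:

    # Existing imports keep working after the registry_uploading -> registry_uploader rename,
    # because the name is still re-exported from the callbacks package __init__.
    from kostyl.ml.lightning.callbacks import ClearMLRegistryUploaderCallback
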
@@ -1,21 +1,289 @@
+from datetime import timedelta
 from pathlib import Path
 from shutil import rmtree
+from typing import Literal
+from typing import cast
 
+from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.callbacks import ModelCheckpoint
 
 from kostyl.ml.configs import CheckpointConfig
 from kostyl.ml.dist_utils import is_main_process
 from kostyl.utils import setup_logger
 
+from .registry_uploader import RegistryUploaderCallback
+
 
 logger = setup_logger("callbacks/checkpoint.py")
 
 
+class CustomModelCheckpoint(ModelCheckpoint):
+    r"""
+    Save the model after every epoch by monitoring a quantity. Every logged metric is passed to the
+    :class:`~lightning.pytorch.loggers.logger.Logger` for the version it gets saved in the same directory as the
+    checkpoint.
+
+    After training finishes, use :attr:`best_model_path` to retrieve the path to the
+    best checkpoint file and :attr:`best_model_score` to get its score.
+
+    .. note::
+        When using manual optimization with ``every_n_train_steps``, you should save the model state
+        in your ``training_step`` before the optimizer step if you want the checkpoint to reflect
+        the pre-optimization state. Example:
+
+        .. code-block:: python
+
+            def training_step(self, batch, batch_idx):
+                # ... forward pass, loss calculation, backward pass ...
+
+                # Save model state before optimization
+                if not hasattr(self, 'saved_models'):
+                    self.saved_models = {}
+                self.saved_models[batch_idx] = {
+                    k: v.detach().clone()
+                    for k, v in self.layer.state_dict().items()
+                }
+
+                # Then perform optimization
+                optimizer.zero_grad()
+                self.manual_backward(loss)
+                optimizer.step()
+
+                # Optional: Clean up old states to save memory
+                if batch_idx > 10:  # Keep last 10 states
+                    del self.saved_models[batch_idx - 10]
+
+    Args:
+        dirpath: Directory to save the model file.
+            Example: ``dirpath='my/path/'``.
+
+            .. warning::
+                In a distributed environment like DDP, it's recommended to provide a `dirpath` to avoid race conditions.
+                When using manual optimization with ``every_n_train_steps``, make sure to save the model state
+                in your training loop as shown in the example above.
+
+            Can be remote file paths such as `s3://mybucket/path/` or 'hdfs://path/'
+            (default: ``None``). If dirpath is ``None``, we only keep the ``k`` best checkpoints
+            in memory, and do not save anything to disk.
+
+        filename: Checkpoint filename. Can contain named formatting options to be auto-filled.
+            If no name is provided, it will be ``None`` and the checkpoint will be saved to
+            ``{epoch}``, and if the Trainer uses a logger, the path will also contain logger name and version.
+
+            Example::
+
+                # save any arbitrary metrics like `val_loss`, etc. in name
+                # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
+                >>> checkpoint_callback = ModelCheckpoint(
+                ...     dirpath='my/path',
+                ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
+                ... )
+
+            By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``, where "epoch" and "step" match
+            the number of finished epoch and optimizer steps respectively.
+        monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch.
+        verbose: verbosity mode. Default: ``False``.
+        save_last: When ``True``, saves a `last.ckpt` copy whenever a checkpoint file gets saved. Can be set to
+            ``'link'`` on a local filesystem to create a symbolic link. This allows accessing the latest checkpoint
+            in a deterministic manner. Default: ``None``.
+        save_top_k: if ``save_top_k == k``,
+            the best k models according to the quantity monitored will be saved.
+            If ``save_top_k == 0``, no models are saved.
+            If ``save_top_k == -1``, all models are saved.
+            Please note that the monitors are checked every ``every_n_epochs`` epochs.
+            If ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, and the filename remains
+            unchanged, the name of the saved file will be appended with a version count starting with ``v1`` to avoid
+            collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the
+            top-k ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric
+            to avoid collisions.
+        save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``False``.
+        mode: one of {min, max}.
+            If ``save_top_k != 0``, the decision to overwrite the current save file is made
+            based on either the maximization or the minimization of the monitored quantity.
+            For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
+        auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name.
+            For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}'`` with epoch ``1`` and acc ``1.12`` will
+            resolve to ``checkpoint_epoch=01-acc=01.ckpt``. It is useful to set it to ``False`` when metric names
+            contain ``/`` as this will result in extra folders.
+            For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False``
+        save_weights_only: if ``True``, then only the model's weights will be
+            saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
+        every_n_train_steps: How many training steps to wait before saving a checkpoint. This does not take into
+            account the steps of the current epoch. If ``every_n_train_steps == None or every_n_train_steps == 0``,
+            no checkpoints will be saved during training.
+            Mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
+
+            .. note::
+                When using with manual optimization, the checkpoint will be saved after the optimizer step by default.
+                To save the model state before the optimizer step, you need to save the model state in your
+                ``training_step`` before calling ``optimizer.step()``. See the class docstring for an example.
+        train_time_interval: Checkpoints are monitored at the specified time interval.
+            For all practical purposes, this cannot be smaller than the amount
+            of time it takes to process a single training batch. This is not
+            guaranteed to execute at the exact time specified, but should be close.
+            This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
+        every_n_epochs: Number of epochs between checkpoints.
+            This value must be ``None`` or non-negative.
+            To disable saving top-k checkpoints, set ``every_n_epochs = 0``.
+            This argument does not impact the saving of ``save_last=True`` checkpoints.
+            If all of ``every_n_epochs``, ``every_n_train_steps`` and
+            ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch
+            (equivalent to ``every_n_epochs = 1``).
+            If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or ``train_time_interval != None``,
+            saving at the end of each epoch is disabled
+            (equivalent to ``every_n_epochs = 0``).
+            This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
+            Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and
+            ``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
+            will only save checkpoints at epochs 0 < E <= N
+            where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
+        save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
+            If ``True``, checkpoints are saved at the end of every training epoch.
+            If ``False``, checkpoints are saved at the end of validation.
+            If ``None`` (default), checkpointing behavior is determined based on training configuration.
+            If ``val_check_interval`` is a str, dict, or `timedelta` (time-based), checkpointing is performed after
+            validation.
+            If ``check_val_every_n_epoch != 1``, checkpointing will not be performed at the end of
+            every training epoch. If there are no validation batches of data, checkpointing will occur at the
+            end of the training epoch. If there is a non-default number of validation runs per training epoch
+            (``val_check_interval != 1``), checkpointing is performed after validation.
+        enable_version_counter: Whether to append a version to the existing file name.
+            If ``False``, then the checkpoint files will be overwritten.
+
+    Note:
+        For extra customization, ModelCheckpoint includes the following attributes:
+
+        - ``CHECKPOINT_JOIN_CHAR = "-"``
+        - ``CHECKPOINT_EQUALS_CHAR = "="``
+        - ``CHECKPOINT_NAME_LAST = "last"``
+        - ``FILE_EXTENSION = ".ckpt"``
+        - ``STARTING_VERSION = 1``
+
+        For example, you can change the default last checkpoint name by doing
+        ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
+
+        If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
+        then you should create multiple ``ModelCheckpoint`` callbacks.
+
+        If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
+        only ``best_model_path`` will be reloaded and a warning will be issued.
+
+        If you provide a ``filename`` on a mounted device where changing permissions is not allowed (causing ``chmod``
+        to raise a ``PermissionError``), install `fsspec>=2025.5.0`. Then the error is caught, the file's permissions
+        remain unchanged, and the checkpoint is still saved. Otherwise, no checkpoint will be saved and training stops.
+
+    Raises:
+        MisconfigurationException:
+            If ``save_top_k`` is smaller than ``-1``,
+            if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or
+            if ``mode`` is none of ``"min"`` or ``"max"``.
+        ValueError:
+            If ``trainer.save_checkpoint`` is ``None``.
+
+    Example::
+
+        >>> from lightning.pytorch import Trainer
+        >>> from lightning.pytorch.callbacks import ModelCheckpoint
+
+        # saves checkpoints to 'my/path/' at every epoch
+        >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
+        >>> trainer = Trainer(callbacks=[checkpoint_callback])
+
+        # save epoch and val_loss in name
+        # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
+        >>> checkpoint_callback = ModelCheckpoint(
+        ...     monitor='val_loss',
+        ...     dirpath='my/path/',
+        ...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
+        ... )
+
+        # save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
+        # or Neptune, due to the presence of characters like '=' or '/')
+        # saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
+        >>> checkpoint_callback = ModelCheckpoint(
+        ...     monitor='val/loss',
+        ...     dirpath='my/path/',
+        ...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
+        ...     auto_insert_metric_name=False
+        ... )
+
+        # retrieve the best checkpoint after training
+        >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
+        >>> trainer = Trainer(callbacks=[checkpoint_callback])
+        >>> model = ...  # doctest: +SKIP
+        >>> trainer.fit(model)  # doctest: +SKIP
+        >>> print(checkpoint_callback.best_model_path)  # doctest: +SKIP
+
+    .. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the
+        following arguments:
+
+        *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval*
+
+        Read more: :ref:`Persisting Callback State <extensions/callbacks_state:save callback state>`
+
+    """  # noqa: D205
+
+    def __init__(  # noqa: D107
+        self,
+        dirpath: _PATH | None = None,
+        filename: str | None = None,
+        monitor: str | None = None,
+        verbose: bool = False,
+        save_last: bool | Literal["link"] | None = None,
+        save_top_k: int = 1,
+        save_on_exception: bool = False,
+        save_weights_only: bool = False,
+        mode: str = "min",
+        auto_insert_metric_name: bool = True,
+        every_n_train_steps: int | None = None,
+        train_time_interval: timedelta | None = None,
+        every_n_epochs: int | None = None,
+        save_on_train_epoch_end: bool | None = None,
+        enable_version_counter: bool = True,
+        registry_uploader_callback: RegistryUploaderCallback | None = None,
+    ) -> None:
+        super().__init__(
+            dirpath=dirpath,
+            filename=filename,
+            monitor=monitor,
+            verbose=verbose,
+            save_last=save_last,
+            save_top_k=save_top_k,
+            save_on_exception=save_on_exception,
+            save_weights_only=save_weights_only,
+            mode=mode,
+            auto_insert_metric_name=auto_insert_metric_name,
+            every_n_train_steps=every_n_train_steps,
+            train_time_interval=train_time_interval,
+            every_n_epochs=every_n_epochs,
+            save_on_train_epoch_end=save_on_train_epoch_end,
+            enable_version_counter=enable_version_counter,
+        )
+        self.registry_uploader_callback = registry_uploader_callback
+        self._custom_best_model_path = cast(str, self.best_model_path)
+        return
+
+    @property
+    def best_model_path(self) -> str:
+        """Best model path."""
+        return self._custom_best_model_path
+
+    @best_model_path.setter
+    def best_model_path(self, value: str) -> None:
+        self._custom_best_model_path = value
+        if self.registry_uploader_callback is not None:
+            self.registry_uploader_callback.best_model_path = value
+        return
+
+
 def setup_checkpoint_callback(
     dirpath: Path,
     ckpt_cfg: CheckpointConfig,
     save_weights_only: bool = True,
-) -> ModelCheckpoint:
+    registry_uploader_callback: RegistryUploaderCallback | None = None,
+) -> CustomModelCheckpoint:
     """
     Sets up a ModelCheckpoint callback for PyTorch Lightning.
 
@@ -30,6 +298,8 @@ def setup_checkpoint_callback(
             settings such as filename, save_top_k, monitor, and mode.
         save_weights_only (bool, optional): Whether to save only the model weights
             or the full model. Defaults to True.
+        registry_uploader_callback (RegistryUploaderCallback | None, optional):
+            An optional callback for uploading checkpoints to a registry. Defaults to None.
 
     Returns:
         ModelCheckpoint: The configured ModelCheckpoint callback instance.
@@ -44,7 +314,7 @@ def setup_checkpoint_callback(
         logger.info(f"Creating checkpoint directory {dirpath}.")
         dirpath.mkdir(parents=True, exist_ok=True)
 
-    checkpoint_callback = ModelCheckpoint(
+    checkpoint_callback = CustomModelCheckpoint(
         dirpath=dirpath,
         filename=ckpt_cfg.filename,
         save_top_k=ckpt_cfg.save_top_k,
@@ -52,5 +322,6 @@ def setup_checkpoint_callback(
         mode=ckpt_cfg.mode,
         verbose=True,
         save_weights_only=save_weights_only,
+        registry_uploader_callback=registry_uploader_callback,
     )
     return checkpoint_callback
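
Note (not part of the published diff): a minimal sketch of how the new `registry_uploader_callback` parameter could be wired into `setup_checkpoint_callback`. The ClearML task, project and model names, and the `CheckpointConfig` values below are hypothetical placeholders, not values taken from this diff:

    from pathlib import Path

    from clearml import Task

    from kostyl.ml.configs import CheckpointConfig
    from kostyl.ml.lightning.callbacks import (
        ClearMLRegistryUploaderCallback,
        setup_checkpoint_callback,
    )

    task = Task.init(project_name="demo", task_name="train")  # hypothetical task
    uploader = ClearMLRegistryUploaderCallback(task=task, output_model_name="demo-model")

    # CustomModelCheckpoint forwards every new best_model_path to the uploader
    # through its overridden best_model_path setter.
    checkpoint_callback = setup_checkpoint_callback(
        dirpath=Path("checkpoints/"),
        ckpt_cfg=CheckpointConfig(...),  # placeholder: filename, save_top_k, monitor, mode
        save_weights_only=True,
        registry_uploader_callback=uploader,
    )
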
@@ -1,3 +1,7 @@
+from abc import ABC
+from abc import abstractmethod
+from collections.abc import Callable
+from functools import partial
 from typing import Literal
 from typing import override
 
@@ -5,7 +9,6 @@ from clearml import OutputModel
 from clearml import Task
 from lightning import Trainer
 from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.callbacks import ModelCheckpoint
 
 from kostyl.ml.clearml.logging_utils import find_version_in_tags
 from kostyl.ml.clearml.logging_utils import increment_version
@@ -16,13 +19,31 @@ from kostyl.utils.logging import setup_logger
 logger = setup_logger()
 
 
-class ClearMLRegistryUploaderCallback(Callback):
+class RegistryUploaderCallback(Callback, ABC):
+    """Abstract Lightning callback responsible for tracking and uploading the best-performing model checkpoint."""
+
+    @property
+    @abstractmethod
+    def best_model_path(self) -> str:
+        """Return the file system path pointing to the best model artifact produced during training."""
+        raise NotImplementedError
+
+    @best_model_path.setter
+    @abstractmethod
+    def best_model_path(self, value: str) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
+        raise NotImplementedError
+
+
+class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
     """PyTorch Lightning callback to upload the best model checkpoint to ClearML."""
 
     def __init__(
         self,
         task: Task,
-        ckpt_callback: ModelCheckpoint,
         output_model_name: str,
         output_model_tags: list[str] | None = None,
         verbose: bool = True,
@@ -56,7 +77,6 @@ class ClearMLRegistryUploaderCallback(Callback):
             output_model_tags = []
 
         self.task = task
-        self.ckpt_callback = ckpt_callback
         self.output_model_name = output_model_name
         self.output_model_tags = output_model_tags
         self.config_dict = config_dict
@@ -66,7 +86,22 @@ class ClearMLRegistryUploaderCallback(Callback):
         self.enable_tag_versioning = enable_tag_versioning
 
         self._output_model: OutputModel | None = None
-        self._last_best_model_path: str = ""
+        self._last_uploaded_model_path: str = ""
+        self._best_model_path: str = ""
+        self._upload_callback: Callable | None = None
+        return
+
+    @property
+    @override
+    def best_model_path(self) -> str:
+        return self._best_model_path
+
+    @best_model_path.setter
+    @override
+    def best_model_path(self, value: str) -> None:
+        self._best_model_path = value
+        if self._upload_callback is not None:
+            self._upload_callback()
         return
 
     def _create_output_model(self, pl_module: "KostylLightningModule") -> OutputModel:
@@ -98,27 +133,29 @@ class ClearMLRegistryUploaderCallback(Callback):
             label_enumeration=self.label_enumeration,
         )
 
+    @override
     def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
-        current_best = self.ckpt_callback.best_model_path
-
-        if not current_best:
-            if self.verbose:
-                logger.info("No best model found yet to upload")
-            return
-
-        if current_best == self._last_best_model_path:
-            if self.verbose:
-                logger.info("Best model unchanged since last upload")
+        if not self._best_model_path or (
+            self._best_model_path == self._last_uploaded_model_path
+        ):
+            if not self._best_model_path:
+                if self.verbose:
+                    logger.info("No best model found yet to upload")
+            elif self._best_model_path == self._last_uploaded_model_path:
+                if self.verbose:
+                    logger.info("Best model unchanged since last upload")
+            self._upload_callback = partial(self._upload_best_checkpoint, pl_module)
             return
+        self._upload_callback = None
 
         if self._output_model is None:
            self._output_model = self._create_output_model(pl_module)
 
         if self.verbose:
-            logger.info(f"Uploading best model from {current_best}")
+            logger.info(f"Uploading best model from {self._best_model_path}")
 
         self._output_model.update_weights(
-            current_best,
+            self._best_model_path,
             auto_delete_file=False,
             async_enable=False,
         )
@@ -130,7 +167,7 @@ class ClearMLRegistryUploaderCallback(Callback):
             config = self.config_dict
         self._output_model.update_design(config_dict=config)
 
-        self._last_best_model_path = current_best
+        self._last_uploaded_model_path = self._best_model_path
         return
 
     @override
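
Note (not part of the published diff): a sketch of what the new `RegistryUploaderCallback` ABC asks of a subclass, namely a `best_model_path` property/setter pair plus an `_upload_best_checkpoint` hook. `CustomModelCheckpoint` pushes each new best path through the setter, and the ClearML implementation above defers an upload by stashing a `functools.partial` and replaying it from its setter once the best path changes. The `LocalCopyUploaderCallback` class below is purely illustrative and not part of the package:

    import shutil

    from kostyl.ml.lightning.callbacks.registry_uploader import RegistryUploaderCallback


    class LocalCopyUploaderCallback(RegistryUploaderCallback):
        """Illustrative subclass: 'uploads' by copying the best checkpoint to a local directory."""

        def __init__(self, target_dir: str) -> None:
            super().__init__()
            self.target_dir = target_dir
            self._best_model_path = ""

        @property
        def best_model_path(self) -> str:
            return self._best_model_path

        @best_model_path.setter
        def best_model_path(self, value: str) -> None:
            # CustomModelCheckpoint assigns here whenever a new best checkpoint is written.
            self._best_model_path = value

        def _upload_best_checkpoint(self, pl_module) -> None:
            if self._best_model_path:
                shutil.copy(self._best_model_path, self.target_dir)
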
@@ -6,15 +6,7 @@ import torch
 from transformers import PretrainedConfig
 from transformers import PreTrainedModel
 
-
-try:
-    from peft import PeftConfig
-except ImportError:
-    PeftConfig = None  # ty: ignore
-
-from kostyl.utils.logging import log_incompatible_keys
 from kostyl.utils.logging import setup_logger
-from torch import nn
 
 
 logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
@@ -24,12 +16,11 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
     """A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""
 
     @classmethod
-    def from_lighting_checkpoint[TModelInstance: LightningCheckpointLoaderMixin](  # noqa: C901
+    def from_lightning_checkpoint[TModelInstance: LightningCheckpointLoaderMixin](  # noqa: C901
         cls: type[TModelInstance],
         checkpoint_path: str | Path,
        config_key: str = "config",
        weights_prefix: str = "model.",
-        should_log_incompatible_keys: bool = True,
        **kwargs: Any,
    ) -> TModelInstance:
        """
@@ -50,8 +41,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
                Defaults to "config".
            weights_prefix (str, optional): Prefix to strip from state dict keys. Defaults to "model.".
                If not empty and doesn't end with ".", a "." is appended.
-            should_log_incompatible_keys (bool, optional): Whether to log incompatible keys. Defaults to True.
-            **kwargs: Additional keyword arguments to pass to the model loading method.
+            kwargs: Additional keyword arguments to pass to the model's `from_pretrained` method.
 
        Returns:
            TModelInstance: The loaded model instance.
@@ -89,114 +79,27 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
            if not hasattr(config, key):
                kwargs_for_model[key] = value
 
-        with torch.device("meta"):
-            model = cls(config, **kwargs_for_model)
-
-        # PEFT adapters (keeping your logic as is)
-        if "peft_config" in checkpoint_dict:
-            if PeftConfig is None:
-                raise ImportError(
-                    "peft is not installed. Please install it to load PEFT models."
-                )
-            for name, adapter_dict in checkpoint_dict["peft_config"].items():
-                peft_cfg = PeftConfig.from_peft_type(**adapter_dict)
-                model.add_adapter(peft_cfg, adapter_name=name)
-
-        incompatible_keys: dict[str, list[str]] = {}
-
        raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
 
        if weights_prefix:
            if not weights_prefix.endswith("."):
                weights_prefix = weights_prefix + "."
            state_dict: dict[str, torch.Tensor] = {}
-            mismatched_keys: list[str] = []
 
            for key, value in raw_state_dict.items():
                if key.startswith(weights_prefix):
                    new_key = key[len(weights_prefix) :]
                    state_dict[new_key] = value
                else:
-                    mismatched_keys.append(key)
-
-            if mismatched_keys:
-                incompatible_keys["mismatched_keys"] = mismatched_keys
+                    state_dict[key] = value
        else:
            state_dict = raw_state_dict
 
-        # 5. base_model_prefix logic, as in HF:
-        #    support loading a base model <-> a model with a head
-        #
-        #    cls.base_model_prefix is usually "model" / "bert" / "encoder" and the like
-        base_prefix: str = getattr(cls, "base_model_prefix", "") or ""
-        model_to_load: nn.Module = model
-
-        if base_prefix:
-            prefix_with_dot = base_prefix + "."
-            loaded_keys = list(state_dict.keys())
-            full_model_state = model.state_dict()
-            expected_keys = list(full_model_state.keys())
-
-            has_prefix_module = any(k.startswith(prefix_with_dot) for k in loaded_keys)
-            expects_prefix_module = any(
-                k.startswith(prefix_with_dot) for k in expected_keys
-            )
-
-            # Case 1: loading a base model into a model with a head.
-            # Example: StaticEmbeddingsForSequenceClassification (has .model)
-            # state_dict with keys "embeddings.weight", "token_pos_weights", ...
-            if (
-                hasattr(model, base_prefix)
-                and not has_prefix_module
-                and expects_prefix_module
-            ):
-                # Weights without the prefix -> load only into model.<base_prefix>
-                model_to_load = getattr(model, base_prefix)
-
-            # Case 2: loading a checkpoint of a model with a head into a base model.
-            # Example: BertModel, while the state_dict keys are "bert.encoder.layer.0..."
-            elif (
-                not hasattr(model, base_prefix)
-                and has_prefix_module
-                and not expects_prefix_module
-            ):
-                new_state_dict: dict[str, torch.Tensor] = {}
-                for key, value in state_dict.items():
-                    if key.startswith(prefix_with_dot):
-                        new_key = key[len(prefix_with_dot) :]
-                    else:
-                        new_key = key
-                    new_state_dict[new_key] = value
-                state_dict = new_state_dict
-
-        load_result = model_to_load.load_state_dict(
-            state_dict, strict=False, assign=True
+        model = cls.from_pretrained(
+            pretrained_model_name_or_path=None,
+            config=config,
+            state_dict=state_dict,
+            **kwargs_for_model,
        )
-        missing_keys, unexpected_keys = (
-            load_result.missing_keys,
-            load_result.unexpected_keys,
-        )
-
-        # If we only loaded into the base submodule, extend missing_keys
-        # to the full list (base + head), as in older HF versions.
-        if model_to_load is not model and base_prefix:
-            base_keys = set(model_to_load.state_dict().keys())
-            # Normalize the full model's keys to their prefix-free form
-            head_like_keys = set()
-            prefix_with_dot = base_prefix + "."
-            for k in model.state_dict().keys():
-                if k.startswith(prefix_with_dot):
-                    # strip the "model." prefix
-                    head_like_keys.add(k[len(prefix_with_dot) :])
-                else:
-                    head_like_keys.add(k)
-            extra_missing = sorted(head_like_keys - base_keys)
-            missing_keys = list(missing_keys) + extra_missing
-
-        incompatible_keys["missing_keys"] = missing_keys
-        incompatible_keys["unexpected_keys"] = unexpected_keys
-
-        if should_log_incompatible_keys:
-            log_incompatible_keys(incompatible_keys=incompatible_keys, logger=logger)
 
        return model
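
Note (not part of the published diff): after this simplification the loader strips `weights_prefix` from the Lightning state dict and delegates weight loading to Hugging Face's `from_pretrained(config=..., state_dict=...)` instead of the hand-rolled base-prefix matching it replaces. A usage sketch, with a hypothetical model class and checkpoint path:

    from transformers import PretrainedConfig, PreTrainedModel

    from kostyl.ml.lightning.extenstions.pretrained_model import LightningCheckpointLoaderMixin


    class MyModel(LightningCheckpointLoaderMixin, PreTrainedModel):  # hypothetical model class
        config_class = PretrainedConfig
        ...

    model = MyModel.from_lightning_checkpoint(
        "checkpoints/best.ckpt",  # placeholder path to a Lightning .ckpt file
        config_key="config",      # checkpoint key holding the serialized model config
        weights_prefix="model.",  # prefix to strip from state-dict keys (the default)
    )
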
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: kostyl-toolkit
-Version: 0.1.22
+Version: 0.1.24
 Summary: Kickass Orchestration System for Training, Yielding & Logging
 Requires-Dist: case-converter>=1.2.0
 Requires-Dist: loguru>=0.7.3
@@ -3,20 +3,20 @@ kostyl/ml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 kostyl/ml/clearml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 kostyl/ml/clearml/dataset_utils.py,sha256=eij_sr2KDhm8GxEbVbK8aBjPsuVvLl9-PIGGaKVgXLA,1729
 kostyl/ml/clearml/logging_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWaJTyq8,1210
-kostyl/ml/clearml/pulling_utils.py,sha256=07bb7ZYlZy-qoZLn7uWZCtz02eX2idgk3JA-PPooS9E,4077
+kostyl/ml/clearml/pulling_utils.py,sha256=cNa_-_5LHjNVYi9btXBrfl5sPvI6BAAlIFidtpKu310,4078
 kostyl/ml/configs/__init__.py,sha256=IetcivbqYGutowLqxdKp7QR4tkXKBr4m8t4Zkk9jHZU,911
 kostyl/ml/configs/base_model.py,sha256=Eofn14J9RsjpVx_J4rp6C19pDDCANU4hr3JtX-d0FpQ,4820
 kostyl/ml/configs/hyperparams.py,sha256=OqN7mEj3zc5MTqBPCZL3Lcd2VCTDLo_K0yvhRWGfhCs,2924
 kostyl/ml/configs/training_settings.py,sha256=Sq2tiRuwkbmi9zKDG2JghZLXo5DDt_eQqN_KYJSdcTY,2509
 kostyl/ml/dist_utils.py,sha256=Onf0KHVLA8oeUgZTcTdmR9qiM22f2uYLoNwgLbMGJWk,3495
 kostyl/ml/lightning/__init__.py,sha256=-F3JAyq8KU1d-nACWryGu8d1CbvWbQ1rXFdeRwfE2X8,175
-kostyl/ml/lightning/callbacks/__init__.py,sha256=Vd-rozY4T9Prr3IMqbliXxj6sC6y9XsovHQqRwzc2HI,297
-kostyl/ml/lightning/callbacks/checkpoint.py,sha256=FooGeeUz6TtoXQglpcK16NWAmSX3fbu6wntRtK3a_Io,1936
+kostyl/ml/lightning/callbacks/__init__.py,sha256=enexQt3octktsTiEYHltSF_24CM-NeFEVFimXiavGiY,296
+kostyl/ml/lightning/callbacks/checkpoint.py,sha256=rJ05S7BnPopvQjV5b9CI29-S7-ySllyhG7PIzz34VyY,16793
 kostyl/ml/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
-kostyl/ml/lightning/callbacks/registry_uploading.py,sha256=32vhMNNuThtEcvRdS5jh5s-wf7LwZNsCTwZA3emcObs,5449
+kostyl/ml/lightning/callbacks/registry_uploader.py,sha256=ksoh02dzIde4E_GaZykfiOgfSjZti-IJt_i61enem3s,6779
 kostyl/ml/lightning/extenstions/__init__.py,sha256=OY6QGv1agYgqqKf1xJBrxgp_i8FunVfPzYezfaRrGXU,182
 kostyl/ml/lightning/extenstions/custom_module.py,sha256=nB5jW7cqRD1tyh-q5LD2EtiFQwFkLXpnS9Yu6c5xMRg,5987
-kostyl/ml/lightning/extenstions/pretrained_model.py,sha256=x8D2nMDDW8J913qFRSEGKXfQO8ipPJM5SLo4Y5kc3YA,8638
+kostyl/ml/lightning/extenstions/pretrained_model.py,sha256=QJGr2UvYJcU2Gy2w8z_cEvTodjv7hGdd2PPPfdOI-Mw,4017
 kostyl/ml/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
 kostyl/ml/lightning/loggers/tb_logger.py,sha256=j02HK5ue8yzXXV8FWKmmXyHkFlIxgHx-ahHWk_rFCZs,893
 kostyl/ml/lightning/steps_estimation.py,sha256=fTZ0IrUEZV3H6VYlx4GYn56oco56mMiB7FO9F0Z7qc4,1511
@@ -30,6 +30,6 @@ kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
 kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
 kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
 kostyl/utils/logging.py,sha256=Vye0u4-yeOSUc-f03gpQbxSktTbFiilTWLEVr00ZHvc,5796
-kostyl_toolkit-0.1.22.dist-info/WHEEL,sha256=z-mOpxbJHqy3cq6SvUThBZdaLGFZzdZPtgWLcP2NKjQ,79
-kostyl_toolkit-0.1.22.dist-info/METADATA,sha256=GweBJ42Dhbl4Y5PNu-jnffXj1CaJ34DTPUcoFEndJ1M,4269
-kostyl_toolkit-0.1.22.dist-info/RECORD,,
+kostyl_toolkit-0.1.24.dist-info/WHEEL,sha256=z-mOpxbJHqy3cq6SvUThBZdaLGFZzdZPtgWLcP2NKjQ,79
+kostyl_toolkit-0.1.24.dist-info/METADATA,sha256=uq8MPJ9vJgWsp9Z2c7C9tcbaH29QM9ux7_SyahPSlHE,4269
+kostyl_toolkit-0.1.24.dist-info/RECORD,,