kostyl-toolkit 0.1.23__py3-none-any.whl → 0.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -95,7 +95,7 @@ def get_model_from_clearml[
             raise ValueError(
                 f"Model class {model.__name__} is not compatible with Lightning checkpoints."
             )
-        model_instance = model.from_lighting_checkpoint(local_path, **kwargs)
+        model_instance = model.from_lightning_checkpoint(local_path, **kwargs)
     else:
         raise ValueError(
             f"Unsupported model format for path: {local_path}. "
@@ -6,15 +6,7 @@ import torch
 from transformers import PretrainedConfig
 from transformers import PreTrainedModel
 
-
-try:
-    from peft import PeftConfig
-except ImportError:
-    PeftConfig = None  # ty: ignore
-
-from kostyl.utils.logging import log_incompatible_keys
 from kostyl.utils.logging import setup_logger
-from torch import nn
 
 
 logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
@@ -24,12 +16,11 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
     """A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""
 
     @classmethod
-    def from_lighting_checkpoint[TModelInstance: LightningCheckpointLoaderMixin](  # noqa: C901
+    def from_lightning_checkpoint[TModelInstance: LightningCheckpointLoaderMixin](  # noqa: C901
         cls: type[TModelInstance],
         checkpoint_path: str | Path,
         config_key: str = "config",
         weights_prefix: str = "model.",
-        should_log_incompatible_keys: bool = True,
         **kwargs: Any,
     ) -> TModelInstance:
         """
@@ -50,8 +41,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
                 Defaults to "config".
             weights_prefix (str, optional): Prefix to strip from state dict keys. Defaults to "model.".
                 If not empty and doesn't end with ".", a "." is appended.
-            should_log_incompatible_keys (bool, optional): Whether to log incompatible keys. Defaults to True.
-            **kwargs: Additional keyword arguments to pass to the model loading method.
+            kwargs: Additional keyword arguments to pass to the model's `from_pretrained` method.
 
         Returns:
             TModelInstance: The loaded model instance.
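The weights_prefix behavior documented above can be shown with a small standalone sketch; the keys and tensors are illustrative, not from the package:

    import torch

    raw_state_dict = {
        "model.encoder.weight": torch.zeros(4),  # prefixed: "model." is stripped
        "head.bias": torch.zeros(4),             # unprefixed: kept as-is in 0.1.24
    }
    weights_prefix = "model"        # no trailing "." ...
    if not weights_prefix.endswith("."):
        weights_prefix += "."       # ... so one is appended, per the docstring
    state_dict = {
        (k[len(weights_prefix):] if k.startswith(weights_prefix) else k): v
        for k, v in raw_state_dict.items()
    }
    assert set(state_dict) == {"encoder.weight", "head.bias"}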
@@ -89,114 +79,27 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
             if not hasattr(config, key):
                 kwargs_for_model[key] = value
 
-        with torch.device("meta"):
-            model = cls(config, **kwargs_for_model)
-
-        # PEFT adapters (keeping your logic as is)
-        if "peft_config" in checkpoint_dict:
-            if PeftConfig is None:
-                raise ImportError(
-                    "peft is not installed. Please install it to load PEFT models."
-                )
-            for name, adapter_dict in checkpoint_dict["peft_config"].items():
-                peft_cfg = PeftConfig.from_peft_type(**adapter_dict)
-                model.add_adapter(peft_cfg, adapter_name=name)
-
-        incompatible_keys: dict[str, list[str]] = {}
-
         raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
 
         if weights_prefix:
             if not weights_prefix.endswith("."):
                 weights_prefix = weights_prefix + "."
             state_dict: dict[str, torch.Tensor] = {}
-            mismatched_keys: list[str] = []
 
             for key, value in raw_state_dict.items():
                 if key.startswith(weights_prefix):
                     new_key = key[len(weights_prefix) :]
                     state_dict[new_key] = value
                 else:
-                    mismatched_keys.append(key)
-
-            if mismatched_keys:
-                incompatible_keys["mismatched_keys"] = mismatched_keys
+                    state_dict[key] = value
         else:
             state_dict = raw_state_dict
 
-        # 5. base_model_prefix logic, as in HF:
-        # supports loading a base model <-> a model with a head
-        #
-        # cls.base_model_prefix is usually "model" / "bert" / "encoder", etc.
-        base_prefix: str = getattr(cls, "base_model_prefix", "") or ""
-        model_to_load: nn.Module = model
-
-        if base_prefix:
-            prefix_with_dot = base_prefix + "."
-            loaded_keys = list(state_dict.keys())
-            full_model_state = model.state_dict()
-            expected_keys = list(full_model_state.keys())
-
-            has_prefix_module = any(k.startswith(prefix_with_dot) for k in loaded_keys)
-            expects_prefix_module = any(
-                k.startswith(prefix_with_dot) for k in expected_keys
-            )
-
-            # Case 1: loading a base model into a model with a head.
-            # Example: StaticEmbeddingsForSequenceClassification (has .model),
-            # state_dict with keys "embeddings.weight", "token_pos_weights", ...
-            if (
-                hasattr(model, base_prefix)
-                and not has_prefix_module
-                and expects_prefix_module
-            ):
-                # Weights without the prefix -> load only into model.<base_prefix>
-                model_to_load = getattr(model, base_prefix)
-
-            # Case 2: loading a checkpoint of a model with a head into a base model.
-            # Example: BertModel, while the state_dict keys are "bert.encoder.layer.0..."
-            elif (
-                not hasattr(model, base_prefix)
-                and has_prefix_module
-                and not expects_prefix_module
-            ):
-                new_state_dict: dict[str, torch.Tensor] = {}
-                for key, value in state_dict.items():
-                    if key.startswith(prefix_with_dot):
-                        new_key = key[len(prefix_with_dot) :]
-                    else:
-                        new_key = key
-                    new_state_dict[new_key] = value
-                state_dict = new_state_dict
-
-        load_result = model_to_load.load_state_dict(
-            state_dict, strict=False, assign=True
+        model = cls.from_pretrained(
+            pretrained_model_name_or_path=None,
+            config=config,
+            state_dict=state_dict,
+            **kwargs_for_model,
         )
-        missing_keys, unexpected_keys = (
-            load_result.missing_keys,
-            load_result.unexpected_keys,
-        )
-
-        # If we loaded only into the base submodule, extend missing_keys
-        # to the full list (base + head), as in older HF versions.
-        if model_to_load is not model and base_prefix:
-            base_keys = set(model_to_load.state_dict().keys())
-            # Normalize the full model's keys to a prefix-free form
-            head_like_keys = set()
-            prefix_with_dot = base_prefix + "."
-            for k in model.state_dict().keys():
-                if k.startswith(prefix_with_dot):
-                    # strip "model."
-                    head_like_keys.add(k[len(prefix_with_dot) :])
-                else:
-                    head_like_keys.add(k)
-            extra_missing = sorted(head_like_keys - base_keys)
-            missing_keys = list(missing_keys) + extra_missing
-
-        incompatible_keys["missing_keys"] = missing_keys
-        incompatible_keys["unexpected_keys"] = unexpected_keys
-
-        if should_log_incompatible_keys:
-            log_incompatible_keys(incompatible_keys=incompatible_keys, logger=logger)
 
         return model
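The refactor above drops the hand-rolled meta-device init, PEFT adapter restoration, base_model_prefix remapping, and incompatible-key logging in favor of a single delegation to Hugging Face's from_pretrained, which handles prefix remapping and missing/unexpected-key reporting internally. A rough standalone sketch of that call pattern, using an illustrative BERT model rather than anything from kostyl:

    import torch
    from transformers import BertConfig, BertModel

    # Passing pretrained_model_name_or_path=None together with an explicit
    # config and in-memory state_dict makes from_pretrained initialize the
    # model from those weights instead of resolving a checkpoint on disk
    # or on the Hub (behavior may vary across transformers versions).
    config = BertConfig(num_hidden_layers=2)
    state_dict = BertModel(config).state_dict()  # stand-in for Lightning weights
    model = BertModel.from_pretrained(
        pretrained_model_name_or_path=None,
        config=config,
        state_dict=state_dict,
    )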
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: kostyl-toolkit
-Version: 0.1.23
+Version: 0.1.24
 Summary: Kickass Orchestration System for Training, Yielding & Logging
 Requires-Dist: case-converter>=1.2.0
 Requires-Dist: loguru>=0.7.3
@@ -3,7 +3,7 @@ kostyl/ml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 kostyl/ml/clearml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 kostyl/ml/clearml/dataset_utils.py,sha256=eij_sr2KDhm8GxEbVbK8aBjPsuVvLl9-PIGGaKVgXLA,1729
 kostyl/ml/clearml/logging_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWaJTyq8,1210
-kostyl/ml/clearml/pulling_utils.py,sha256=07bb7ZYlZy-qoZLn7uWZCtz02eX2idgk3JA-PPooS9E,4077
+kostyl/ml/clearml/pulling_utils.py,sha256=cNa_-_5LHjNVYi9btXBrfl5sPvI6BAAlIFidtpKu310,4078
 kostyl/ml/configs/__init__.py,sha256=IetcivbqYGutowLqxdKp7QR4tkXKBr4m8t4Zkk9jHZU,911
 kostyl/ml/configs/base_model.py,sha256=Eofn14J9RsjpVx_J4rp6C19pDDCANU4hr3JtX-d0FpQ,4820
 kostyl/ml/configs/hyperparams.py,sha256=OqN7mEj3zc5MTqBPCZL3Lcd2VCTDLo_K0yvhRWGfhCs,2924
@@ -16,7 +16,7 @@ kostyl/ml/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jOR
 kostyl/ml/lightning/callbacks/registry_uploader.py,sha256=ksoh02dzIde4E_GaZykfiOgfSjZti-IJt_i61enem3s,6779
 kostyl/ml/lightning/extenstions/__init__.py,sha256=OY6QGv1agYgqqKf1xJBrxgp_i8FunVfPzYezfaRrGXU,182
 kostyl/ml/lightning/extenstions/custom_module.py,sha256=nB5jW7cqRD1tyh-q5LD2EtiFQwFkLXpnS9Yu6c5xMRg,5987
-kostyl/ml/lightning/extenstions/pretrained_model.py,sha256=x8D2nMDDW8J913qFRSEGKXfQO8ipPJM5SLo4Y5kc3YA,8638
+kostyl/ml/lightning/extenstions/pretrained_model.py,sha256=QJGr2UvYJcU2Gy2w8z_cEvTodjv7hGdd2PPPfdOI-Mw,4017
 kostyl/ml/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
 kostyl/ml/lightning/loggers/tb_logger.py,sha256=j02HK5ue8yzXXV8FWKmmXyHkFlIxgHx-ahHWk_rFCZs,893
 kostyl/ml/lightning/steps_estimation.py,sha256=fTZ0IrUEZV3H6VYlx4GYn56oco56mMiB7FO9F0Z7qc4,1511
@@ -30,6 +30,6 @@ kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
 kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
 kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
 kostyl/utils/logging.py,sha256=Vye0u4-yeOSUc-f03gpQbxSktTbFiilTWLEVr00ZHvc,5796
-kostyl_toolkit-0.1.23.dist-info/WHEEL,sha256=z-mOpxbJHqy3cq6SvUThBZdaLGFZzdZPtgWLcP2NKjQ,79
-kostyl_toolkit-0.1.23.dist-info/METADATA,sha256=8af_sRkZy9w8chOp4NLvercyB57df6FXAvpLsWKPqro,4269
-kostyl_toolkit-0.1.23.dist-info/RECORD,,
+kostyl_toolkit-0.1.24.dist-info/WHEEL,sha256=z-mOpxbJHqy3cq6SvUThBZdaLGFZzdZPtgWLcP2NKjQ,79
+kostyl_toolkit-0.1.24.dist-info/METADATA,sha256=uq8MPJ9vJgWsp9Z2c7C9tcbaH29QM9ux7_SyahPSlHE,4269
+kostyl_toolkit-0.1.24.dist-info/RECORD,,