kostyl-toolkit 0.1.23__py3-none-any.whl → 0.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kostyl/ml/clearml/pulling_utils.py +1 -1
- kostyl/ml/lightning/extenstions/pretrained_model.py +8 -105
- {kostyl_toolkit-0.1.23.dist-info → kostyl_toolkit-0.1.24.dist-info}/METADATA +1 -1
- {kostyl_toolkit-0.1.23.dist-info → kostyl_toolkit-0.1.24.dist-info}/RECORD +5 -5
- {kostyl_toolkit-0.1.23.dist-info → kostyl_toolkit-0.1.24.dist-info}/WHEEL +0 -0
|
@@ -95,7 +95,7 @@ def get_model_from_clearml[
|
|
|
95
95
|
raise ValueError(
|
|
96
96
|
f"Model class {model.__name__} is not compatible with Lightning checkpoints."
|
|
97
97
|
)
|
|
98
|
-
model_instance = model.
|
|
98
|
+
model_instance = model.from_lightning_checkpoint(local_path, **kwargs)
|
|
99
99
|
else:
|
|
100
100
|
raise ValueError(
|
|
101
101
|
f"Unsupported model format for path: {local_path}. "
|
|
@@ -6,15 +6,7 @@ import torch
|
|
|
6
6
|
from transformers import PretrainedConfig
|
|
7
7
|
from transformers import PreTrainedModel
|
|
8
8
|
|
|
9
|
-
|
|
10
|
-
try:
|
|
11
|
-
from peft import PeftConfig
|
|
12
|
-
except ImportError:
|
|
13
|
-
PeftConfig = None # ty: ignore
|
|
14
|
-
|
|
15
|
-
from kostyl.utils.logging import log_incompatible_keys
|
|
16
9
|
from kostyl.utils.logging import setup_logger
|
|
17
|
-
from torch import nn
|
|
18
10
|
|
|
19
11
|
|
|
20
12
|
logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
|
|
@@ -24,12 +16,11 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
|
|
|
24
16
|
"""A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""
|
|
25
17
|
|
|
26
18
|
@classmethod
|
|
27
|
-
def
|
|
19
|
+
def from_lightning_checkpoint[TModelInstance: LightningCheckpointLoaderMixin]( # noqa: C901
|
|
28
20
|
cls: type[TModelInstance],
|
|
29
21
|
checkpoint_path: str | Path,
|
|
30
22
|
config_key: str = "config",
|
|
31
23
|
weights_prefix: str = "model.",
|
|
32
|
-
should_log_incompatible_keys: bool = True,
|
|
33
24
|
**kwargs: Any,
|
|
34
25
|
) -> TModelInstance:
|
|
35
26
|
"""
|
|
@@ -50,8 +41,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
|
|
|
50
41
|
Defaults to "config".
|
|
51
42
|
weights_prefix (str, optional): Prefix to strip from state dict keys. Defaults to "model.".
|
|
52
43
|
If not empty and doesn't end with ".", a "." is appended.
|
|
53
|
-
|
|
54
|
-
**kwargs: Additional keyword arguments to pass to the model loading method.
|
|
44
|
+
kwargs: Additional keyword arguments to pass to the model's `from_pretrained` method.
|
|
55
45
|
|
|
56
46
|
Returns:
|
|
57
47
|
TModelInstance: The loaded model instance.
|
|
@@ -89,114 +79,27 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
|
|
|
89
79
|
if not hasattr(config, key):
|
|
90
80
|
kwargs_for_model[key] = value
|
|
91
81
|
|
|
92
|
-
with torch.device("meta"):
|
|
93
|
-
model = cls(config, **kwargs_for_model)
|
|
94
|
-
|
|
95
|
-
# PEFT-адаптеры (оставляю твою логику как есть)
|
|
96
|
-
if "peft_config" in checkpoint_dict:
|
|
97
|
-
if PeftConfig is None:
|
|
98
|
-
raise ImportError(
|
|
99
|
-
"peft is not installed. Please install it to load PEFT models."
|
|
100
|
-
)
|
|
101
|
-
for name, adapter_dict in checkpoint_dict["peft_config"].items():
|
|
102
|
-
peft_cfg = PeftConfig.from_peft_type(**adapter_dict)
|
|
103
|
-
model.add_adapter(peft_cfg, adapter_name=name)
|
|
104
|
-
|
|
105
|
-
incompatible_keys: dict[str, list[str]] = {}
|
|
106
|
-
|
|
107
82
|
raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
|
|
108
83
|
|
|
109
84
|
if weights_prefix:
|
|
110
85
|
if not weights_prefix.endswith("."):
|
|
111
86
|
weights_prefix = weights_prefix + "."
|
|
112
87
|
state_dict: dict[str, torch.Tensor] = {}
|
|
113
|
-
mismatched_keys: list[str] = []
|
|
114
88
|
|
|
115
89
|
for key, value in raw_state_dict.items():
|
|
116
90
|
if key.startswith(weights_prefix):
|
|
117
91
|
new_key = key[len(weights_prefix) :]
|
|
118
92
|
state_dict[new_key] = value
|
|
119
93
|
else:
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
if mismatched_keys:
|
|
123
|
-
incompatible_keys["mismatched_keys"] = mismatched_keys
|
|
94
|
+
state_dict[key] = value
|
|
124
95
|
else:
|
|
125
96
|
state_dict = raw_state_dict
|
|
126
97
|
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
model_to_load: nn.Module = model
|
|
133
|
-
|
|
134
|
-
if base_prefix:
|
|
135
|
-
prefix_with_dot = base_prefix + "."
|
|
136
|
-
loaded_keys = list(state_dict.keys())
|
|
137
|
-
full_model_state = model.state_dict()
|
|
138
|
-
expected_keys = list(full_model_state.keys())
|
|
139
|
-
|
|
140
|
-
has_prefix_module = any(k.startswith(prefix_with_dot) for k in loaded_keys)
|
|
141
|
-
expects_prefix_module = any(
|
|
142
|
-
k.startswith(prefix_with_dot) for k in expected_keys
|
|
143
|
-
)
|
|
144
|
-
|
|
145
|
-
# Кейc 1: загружаем базовую модель в модель с головой.
|
|
146
|
-
# Пример: StaticEmbeddingsForSequenceClassification (имеет .model)
|
|
147
|
-
# state_dict с ключами "embeddings.weight", "token_pos_weights", ...
|
|
148
|
-
if (
|
|
149
|
-
hasattr(model, base_prefix)
|
|
150
|
-
and not has_prefix_module
|
|
151
|
-
and expects_prefix_module
|
|
152
|
-
):
|
|
153
|
-
# Веса без префикса -> грузим только в model.<base_prefix>
|
|
154
|
-
model_to_load = getattr(model, base_prefix)
|
|
155
|
-
|
|
156
|
-
# Кейc 2: загружаем чекпоинт модели с головой в базовую модель.
|
|
157
|
-
# Пример: BertModel, а в state_dict ключи "bert.encoder.layer.0..."
|
|
158
|
-
elif (
|
|
159
|
-
not hasattr(model, base_prefix)
|
|
160
|
-
and has_prefix_module
|
|
161
|
-
and not expects_prefix_module
|
|
162
|
-
):
|
|
163
|
-
new_state_dict: dict[str, torch.Tensor] = {}
|
|
164
|
-
for key, value in state_dict.items():
|
|
165
|
-
if key.startswith(prefix_with_dot):
|
|
166
|
-
new_key = key[len(prefix_with_dot) :]
|
|
167
|
-
else:
|
|
168
|
-
new_key = key
|
|
169
|
-
new_state_dict[new_key] = value
|
|
170
|
-
state_dict = new_state_dict
|
|
171
|
-
|
|
172
|
-
load_result = model_to_load.load_state_dict(
|
|
173
|
-
state_dict, strict=False, assign=True
|
|
98
|
+
model = cls.from_pretrained(
|
|
99
|
+
pretrained_model_name_or_path=None,
|
|
100
|
+
config=config,
|
|
101
|
+
state_dict=state_dict,
|
|
102
|
+
**kwargs_for_model,
|
|
174
103
|
)
|
|
175
|
-
missing_keys, unexpected_keys = (
|
|
176
|
-
load_result.missing_keys,
|
|
177
|
-
load_result.unexpected_keys,
|
|
178
|
-
)
|
|
179
|
-
|
|
180
|
-
# Если мы грузили только в base-подмодуль, расширим missing_keys
|
|
181
|
-
# до полного списка (base + голова), как в старых версиях HF.
|
|
182
|
-
if model_to_load is not model and base_prefix:
|
|
183
|
-
base_keys = set(model_to_load.state_dict().keys())
|
|
184
|
-
# Приводим ключи полной модели к "безпрефиксному" виду
|
|
185
|
-
head_like_keys = set()
|
|
186
|
-
prefix_with_dot = base_prefix + "."
|
|
187
|
-
for k in model.state_dict().keys():
|
|
188
|
-
if k.startswith(prefix_with_dot):
|
|
189
|
-
# отрезаем "model."
|
|
190
|
-
head_like_keys.add(k[len(prefix_with_dot) :])
|
|
191
|
-
else:
|
|
192
|
-
head_like_keys.add(k)
|
|
193
|
-
extra_missing = sorted(head_like_keys - base_keys)
|
|
194
|
-
missing_keys = list(missing_keys) + extra_missing
|
|
195
|
-
|
|
196
|
-
incompatible_keys["missing_keys"] = missing_keys
|
|
197
|
-
incompatible_keys["unexpected_keys"] = unexpected_keys
|
|
198
|
-
|
|
199
|
-
if should_log_incompatible_keys:
|
|
200
|
-
log_incompatible_keys(incompatible_keys=incompatible_keys, logger=logger)
|
|
201
104
|
|
|
202
105
|
return model
|
|
@@ -3,7 +3,7 @@ kostyl/ml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
|
3
3
|
kostyl/ml/clearml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
4
|
kostyl/ml/clearml/dataset_utils.py,sha256=eij_sr2KDhm8GxEbVbK8aBjPsuVvLl9-PIGGaKVgXLA,1729
|
|
5
5
|
kostyl/ml/clearml/logging_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWaJTyq8,1210
|
|
6
|
-
kostyl/ml/clearml/pulling_utils.py,sha256=
|
|
6
|
+
kostyl/ml/clearml/pulling_utils.py,sha256=cNa_-_5LHjNVYi9btXBrfl5sPvI6BAAlIFidtpKu310,4078
|
|
7
7
|
kostyl/ml/configs/__init__.py,sha256=IetcivbqYGutowLqxdKp7QR4tkXKBr4m8t4Zkk9jHZU,911
|
|
8
8
|
kostyl/ml/configs/base_model.py,sha256=Eofn14J9RsjpVx_J4rp6C19pDDCANU4hr3JtX-d0FpQ,4820
|
|
9
9
|
kostyl/ml/configs/hyperparams.py,sha256=OqN7mEj3zc5MTqBPCZL3Lcd2VCTDLo_K0yvhRWGfhCs,2924
|
|
@@ -16,7 +16,7 @@ kostyl/ml/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jOR
|
|
|
16
16
|
kostyl/ml/lightning/callbacks/registry_uploader.py,sha256=ksoh02dzIde4E_GaZykfiOgfSjZti-IJt_i61enem3s,6779
|
|
17
17
|
kostyl/ml/lightning/extenstions/__init__.py,sha256=OY6QGv1agYgqqKf1xJBrxgp_i8FunVfPzYezfaRrGXU,182
|
|
18
18
|
kostyl/ml/lightning/extenstions/custom_module.py,sha256=nB5jW7cqRD1tyh-q5LD2EtiFQwFkLXpnS9Yu6c5xMRg,5987
|
|
19
|
-
kostyl/ml/lightning/extenstions/pretrained_model.py,sha256=
|
|
19
|
+
kostyl/ml/lightning/extenstions/pretrained_model.py,sha256=QJGr2UvYJcU2Gy2w8z_cEvTodjv7hGdd2PPPfdOI-Mw,4017
|
|
20
20
|
kostyl/ml/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
|
|
21
21
|
kostyl/ml/lightning/loggers/tb_logger.py,sha256=j02HK5ue8yzXXV8FWKmmXyHkFlIxgHx-ahHWk_rFCZs,893
|
|
22
22
|
kostyl/ml/lightning/steps_estimation.py,sha256=fTZ0IrUEZV3H6VYlx4GYn56oco56mMiB7FO9F0Z7qc4,1511
|
|
@@ -30,6 +30,6 @@ kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
|
|
|
30
30
|
kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
|
|
31
31
|
kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
|
|
32
32
|
kostyl/utils/logging.py,sha256=Vye0u4-yeOSUc-f03gpQbxSktTbFiilTWLEVr00ZHvc,5796
|
|
33
|
-
kostyl_toolkit-0.1.
|
|
34
|
-
kostyl_toolkit-0.1.
|
|
35
|
-
kostyl_toolkit-0.1.
|
|
33
|
+
kostyl_toolkit-0.1.24.dist-info/WHEEL,sha256=z-mOpxbJHqy3cq6SvUThBZdaLGFZzdZPtgWLcP2NKjQ,79
|
|
34
|
+
kostyl_toolkit-0.1.24.dist-info/METADATA,sha256=uq8MPJ9vJgWsp9Z2c7C9tcbaH29QM9ux7_SyahPSlHE,4269
|
|
35
|
+
kostyl_toolkit-0.1.24.dist-info/RECORD,,
|
|
File without changes
|