kostyl-toolkit 0.1.23__tar.gz → 0.1.24__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/PKG-INFO +1 -1
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/clearml/pulling_utils.py +1 -1
- kostyl_toolkit-0.1.24/kostyl/ml/lightning/extenstions/pretrained_model.py +105 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/pyproject.toml +1 -1
- kostyl_toolkit-0.1.23/kostyl/ml/lightning/extenstions/pretrained_model.py +0 -202
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/README.md +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/__init__.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/__init__.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/clearml/__init__.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/clearml/dataset_utils.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/clearml/logging_utils.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/configs/__init__.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/configs/base_model.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/configs/hyperparams.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/configs/training_settings.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/dist_utils.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/__init__.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/callbacks/__init__.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/callbacks/checkpoint.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/callbacks/registry_uploader.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/extenstions/__init__.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/extenstions/custom_module.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/loggers/__init__.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/loggers/tb_logger.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/steps_estimation.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/metrics_formatting.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/params_groups.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/schedulers/__init__.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/schedulers/base.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/schedulers/composite.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/schedulers/cosine.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/utils/__init__.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/utils/dict_manipulations.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/utils/fs.py +0 -0
- {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/utils/logging.py +0 -0
{kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/clearml/pulling_utils.py

@@ -95,7 +95,7 @@ def get_model_from_clearml[
             raise ValueError(
                 f"Model class {model.__name__} is not compatible with Lightning checkpoints."
             )
-        model_instance = model.
+        model_instance = model.from_lightning_checkpoint(local_path, **kwargs)
     else:
         raise ValueError(
             f"Unsupported model format for path: {local_path}. "
kostyl_toolkit-0.1.24/kostyl/ml/lightning/extenstions/pretrained_model.py

@@ -0,0 +1,105 @@
+from pathlib import Path
+from typing import Any
+from typing import cast
+
+import torch
+from transformers import PretrainedConfig
+from transformers import PreTrainedModel
+
+from kostyl.utils.logging import setup_logger
+
+
+logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
+
+
+class LightningCheckpointLoaderMixin(PreTrainedModel):
+    """A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""
+
+    @classmethod
+    def from_lightning_checkpoint[TModelInstance: LightningCheckpointLoaderMixin](  # noqa: C901
+        cls: type[TModelInstance],
+        checkpoint_path: str | Path,
+        config_key: str = "config",
+        weights_prefix: str = "model.",
+        **kwargs: Any,
+    ) -> TModelInstance:
+        """
+        Load a model from a Lightning checkpoint file.
+
+        This class method loads a pretrained model from a PyTorch Lightning checkpoint file (.ckpt).
+        It extracts the model configuration from the checkpoint, instantiates the model, and loads
+        the state dictionary, handling any incompatible keys.
+
+        Note:
+            The method uses `torch.load` with `weights_only=False` and `mmap=True` for loading.
+            Incompatible keys (missing, unexpected, mismatched) are collected and optionally logged.
+
+        Args:
+            cls (type["LightningPretrainedModelMixin"]): The class of the model to instantiate.
+            checkpoint_path (str | Path): Path to the checkpoint file. Must be a .ckpt file.
+            config_key (str, optional): Key in the checkpoint dictionary where the config is stored.
+                Defaults to "config".
+            weights_prefix (str, optional): Prefix to strip from state dict keys. Defaults to "model.".
+                If not empty and doesn't end with ".", a "." is appended.
+            kwargs: Additional keyword arguments to pass to the model's `from_pretrained` method.
+
+        Returns:
+            TModelInstance: The loaded model instance.
+
+        Raises:
+            ValueError: If checkpoint_path is a directory, not a .ckpt file, or invalid.
+            FileNotFoundError: If the checkpoint file does not exist.
+
+        """
+        if isinstance(checkpoint_path, str):
+            checkpoint_path = Path(checkpoint_path)
+
+        if checkpoint_path.is_dir():
+            raise ValueError(f"{checkpoint_path} is a directory")
+        if not checkpoint_path.exists():
+            raise FileNotFoundError(f"{checkpoint_path} does not exist")
+        if checkpoint_path.suffix != ".ckpt":
+            raise ValueError(f"{checkpoint_path} is not a .ckpt file")
+
+        checkpoint_dict = torch.load(
+            checkpoint_path,
+            map_location="cpu",
+            weights_only=False,
+            mmap=True,
+        )
+
+        # 1. Restore the config
+        config_cls = cast(type[PretrainedConfig], cls.config_class)
+        config_dict = checkpoint_dict[config_key]
+        config_dict.update(kwargs)
+        config = config_cls.from_dict(config_dict)
+
+        kwargs_for_model: dict[str, Any] = {}
+        for key, value in kwargs.items():
+            if not hasattr(config, key):
+                kwargs_for_model[key] = value
+
+        raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
+
+        if weights_prefix:
+            if not weights_prefix.endswith("."):
+                weights_prefix = weights_prefix + "."
+            state_dict: dict[str, torch.Tensor] = {}
+
+            for key, value in raw_state_dict.items():
+                if key.startswith(weights_prefix):
+                    new_key = key[len(weights_prefix) :]
+                    state_dict[new_key] = value
+                else:
+                    state_dict[key] = value
+        else:
+            state_dict = raw_state_dict
+
+        model = cls.from_pretrained(
+            pretrained_model_name_or_path=None,
+            config=config,
+            state_dict=state_dict,
+            **kwargs_for_model,
+        )
+
+        return model
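For orientation, here is a minimal sketch of how the new mixin is intended to be used. Only `LightningCheckpointLoaderMixin`, its `from_lightning_checkpoint` signature, and the expected checkpoint layout (a config dict under `config_key` plus a `state_dict` whose keys carry the `weights_prefix`) come from the file above; `ToyConfig`, `ToyModel`, and the hand-built checkpoint are hypothetical:

```python
import torch
from transformers import PretrainedConfig

from kostyl.ml.lightning.extenstions.pretrained_model import LightningCheckpointLoaderMixin


class ToyConfig(PretrainedConfig):
    """Hypothetical config with a single hyperparameter."""

    model_type = "toy"

    def __init__(self, hidden_size: int = 16, **kwargs):
        self.hidden_size = hidden_size
        super().__init__(**kwargs)


class ToyModel(LightningCheckpointLoaderMixin):
    """Hypothetical model; the mixin already subclasses PreTrainedModel."""

    config_class = ToyConfig

    def __init__(self, config: ToyConfig):
        super().__init__(config)
        self.proj = torch.nn.Linear(config.hidden_size, config.hidden_size)


# A Lightning checkpoint is assumed to store the HF config under "config" and to
# wrap the HF model as `self.model` inside the LightningModule, so state_dict keys
# carry the "model." prefix that from_lightning_checkpoint strips by default.
ckpt = {
    "config": ToyConfig(hidden_size=16).to_dict(),
    "state_dict": {f"model.{k}": v for k, v in ToyModel(ToyConfig()).state_dict().items()},
}
torch.save(ckpt, "toy.ckpt")

model = ToyModel.from_lightning_checkpoint("toy.ckpt")
print(type(model).__name__, model.config.hidden_size)
```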
kostyl_toolkit-0.1.23/kostyl/ml/lightning/extenstions/pretrained_model.py

@@ -1,202 +0,0 @@
-from pathlib import Path
-from typing import Any
-from typing import cast
-
-import torch
-from transformers import PretrainedConfig
-from transformers import PreTrainedModel
-
-
-try:
-    from peft import PeftConfig
-except ImportError:
-    PeftConfig = None  # ty: ignore
-
-from kostyl.utils.logging import log_incompatible_keys
-from kostyl.utils.logging import setup_logger
-from torch import nn
-
-
-logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
-
-
-class LightningCheckpointLoaderMixin(PreTrainedModel):
-    """A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""
-
-    @classmethod
-    def from_lighting_checkpoint[TModelInstance: LightningCheckpointLoaderMixin](  # noqa: C901
-        cls: type[TModelInstance],
-        checkpoint_path: str | Path,
-        config_key: str = "config",
-        weights_prefix: str = "model.",
-        should_log_incompatible_keys: bool = True,
-        **kwargs: Any,
-    ) -> TModelInstance:
-        """
-        Load a model from a Lightning checkpoint file.
-
-        This class method loads a pretrained model from a PyTorch Lightning checkpoint file (.ckpt).
-        It extracts the model configuration from the checkpoint, instantiates the model, and loads
-        the state dictionary, handling any incompatible keys.
-
-        Note:
-            The method uses `torch.load` with `weights_only=False` and `mmap=True` for loading.
-            Incompatible keys (missing, unexpected, mismatched) are collected and optionally logged.
-
-        Args:
-            cls (type["LightningPretrainedModelMixin"]): The class of the model to instantiate.
-            checkpoint_path (str | Path): Path to the checkpoint file. Must be a .ckpt file.
-            config_key (str, optional): Key in the checkpoint dictionary where the config is stored.
-                Defaults to "config".
-            weights_prefix (str, optional): Prefix to strip from state dict keys. Defaults to "model.".
-                If not empty and doesn't end with ".", a "." is appended.
-            should_log_incompatible_keys (bool, optional): Whether to log incompatible keys. Defaults to True.
-            **kwargs: Additional keyword arguments to pass to the model loading method.
-
-        Returns:
-            TModelInstance: The loaded model instance.
-
-        Raises:
-            ValueError: If checkpoint_path is a directory, not a .ckpt file, or invalid.
-            FileNotFoundError: If the checkpoint file does not exist.
-
-        """
-        if isinstance(checkpoint_path, str):
-            checkpoint_path = Path(checkpoint_path)
-
-        if checkpoint_path.is_dir():
-            raise ValueError(f"{checkpoint_path} is a directory")
-        if not checkpoint_path.exists():
-            raise FileNotFoundError(f"{checkpoint_path} does not exist")
-        if checkpoint_path.suffix != ".ckpt":
-            raise ValueError(f"{checkpoint_path} is not a .ckpt file")
-
-        checkpoint_dict = torch.load(
-            checkpoint_path,
-            map_location="cpu",
-            weights_only=False,
-            mmap=True,
-        )
-
-        # 1. Restore the config
-        config_cls = cast(type[PretrainedConfig], cls.config_class)
-        config_dict = checkpoint_dict[config_key]
-        config_dict.update(kwargs)
-        config = config_cls.from_dict(config_dict)
-
-        kwargs_for_model: dict[str, Any] = {}
-        for key, value in kwargs.items():
-            if not hasattr(config, key):
-                kwargs_for_model[key] = value
-
-        with torch.device("meta"):
-            model = cls(config, **kwargs_for_model)
-
-        # PEFT adapters (keeping your logic as-is)
-        if "peft_config" in checkpoint_dict:
-            if PeftConfig is None:
-                raise ImportError(
-                    "peft is not installed. Please install it to load PEFT models."
-                )
-            for name, adapter_dict in checkpoint_dict["peft_config"].items():
-                peft_cfg = PeftConfig.from_peft_type(**adapter_dict)
-                model.add_adapter(peft_cfg, adapter_name=name)
-
-        incompatible_keys: dict[str, list[str]] = {}
-
-        raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
-
-        if weights_prefix:
-            if not weights_prefix.endswith("."):
-                weights_prefix = weights_prefix + "."
-            state_dict: dict[str, torch.Tensor] = {}
-            mismatched_keys: list[str] = []
-
-            for key, value in raw_state_dict.items():
-                if key.startswith(weights_prefix):
-                    new_key = key[len(weights_prefix) :]
-                    state_dict[new_key] = value
-                else:
-                    mismatched_keys.append(key)
-
-            if mismatched_keys:
-                incompatible_keys["mismatched_keys"] = mismatched_keys
-        else:
-            state_dict = raw_state_dict
-
-        # 5. base_model_prefix logic, as in HF:
-        #    support loading a base model <-> a model with a head
-        #
-        # cls.base_model_prefix is usually "model" / "bert" / "encoder", etc.
-        base_prefix: str = getattr(cls, "base_model_prefix", "") or ""
-        model_to_load: nn.Module = model
-
-        if base_prefix:
-            prefix_with_dot = base_prefix + "."
-            loaded_keys = list(state_dict.keys())
-            full_model_state = model.state_dict()
-            expected_keys = list(full_model_state.keys())
-
-            has_prefix_module = any(k.startswith(prefix_with_dot) for k in loaded_keys)
-            expects_prefix_module = any(
-                k.startswith(prefix_with_dot) for k in expected_keys
-            )
-
-            # Case 1: loading a base model into a model with a head.
-            # Example: StaticEmbeddingsForSequenceClassification (has .model)
-            # state_dict with keys "embeddings.weight", "token_pos_weights", ...
-            if (
-                hasattr(model, base_prefix)
-                and not has_prefix_module
-                and expects_prefix_module
-            ):
-                # Weights without the prefix -> load only into model.<base_prefix>
-                model_to_load = getattr(model, base_prefix)
-
-            # Case 2: loading a checkpoint of a model with a head into a base model.
-            # Example: BertModel, while the state_dict has keys "bert.encoder.layer.0..."
-            elif (
-                not hasattr(model, base_prefix)
-                and has_prefix_module
-                and not expects_prefix_module
-            ):
-                new_state_dict: dict[str, torch.Tensor] = {}
-                for key, value in state_dict.items():
-                    if key.startswith(prefix_with_dot):
-                        new_key = key[len(prefix_with_dot) :]
-                    else:
-                        new_key = key
-                    new_state_dict[new_key] = value
-                state_dict = new_state_dict
-
-        load_result = model_to_load.load_state_dict(
-            state_dict, strict=False, assign=True
-        )
-        missing_keys, unexpected_keys = (
-            load_result.missing_keys,
-            load_result.unexpected_keys,
-        )
-
-        # If we loaded only into the base submodule, extend missing_keys
-        # to the full list (base + head), as in older HF versions.
-        if model_to_load is not model and base_prefix:
-            base_keys = set(model_to_load.state_dict().keys())
-            # Map the full model's keys to their prefix-free form
-            head_like_keys = set()
-            prefix_with_dot = base_prefix + "."
-            for k in model.state_dict().keys():
-                if k.startswith(prefix_with_dot):
-                    # strip "model."
-                    head_like_keys.add(k[len(prefix_with_dot) :])
-                else:
-                    head_like_keys.add(k)
-            extra_missing = sorted(head_like_keys - base_keys)
-            missing_keys = list(missing_keys) + extra_missing
-
-        incompatible_keys["missing_keys"] = missing_keys
-        incompatible_keys["unexpected_keys"] = unexpected_keys
-
-        if should_log_incompatible_keys:
-            log_incompatible_keys(incompatible_keys=incompatible_keys, logger=logger)
-
-        return model