kostyl-toolkit 0.1.21__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff shows the changes between the two publicly released versions of the package as they appear in the public registry.
- kostyl/ml/lightning/callbacks/registry_uploading.py +22 -4
- kostyl/ml/lightning/extenstions/pretrained_model.py +93 -16
- {kostyl_toolkit-0.1.21.dist-info → kostyl_toolkit-0.1.22.dist-info}/METADATA +1 -1
- {kostyl_toolkit-0.1.21.dist-info → kostyl_toolkit-0.1.22.dist-info}/RECORD +5 -5
- {kostyl_toolkit-0.1.21.dist-info → kostyl_toolkit-0.1.22.dist-info}/WHEEL +0 -0
kostyl/ml/lightning/callbacks/registry_uploading.py

```diff
@@ -27,6 +27,8 @@ class ClearMLRegistryUploaderCallback(Callback):
         output_model_tags: list[str] | None = None,
         verbose: bool = True,
         enable_tag_versioning: bool = True,
+        label_enumeration: dict[str, int] | None = None,
+        config_dict: dict[str, str] | None = None,
         uploading_frequency: Literal[
             "after-every-eval", "on-train-end"
         ] = "on-train-end",
@@ -40,6 +42,8 @@ class ClearMLRegistryUploaderCallback(Callback):
             output_model_name: Name for the ClearML output model.
             output_model_tags: Tags for the output model.
             verbose: Whether to log messages.
+            label_enumeration: Optional mapping of label names to integer IDs.
+            config_dict: Optional configuration dictionary to associate with the model.
             enable_tag_versioning: Whether to enable versioning in tags. If True,
                 the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".
             uploading_frequency: When to upload:
@@ -55,6 +59,8 @@ class ClearMLRegistryUploaderCallback(Callback):
         self.ckpt_callback = ckpt_callback
         self.output_model_name = output_model_name
         self.output_model_tags = output_model_tags
+        self.config_dict = config_dict
+        self.label_enumeration = label_enumeration
         self.verbose = verbose
         self.uploading_frequency = uploading_frequency
         self.enable_tag_versioning = enable_tag_versioning
@@ -75,16 +81,21 @@ class ClearMLRegistryUploaderCallback(Callback):
 
         if "LightningCheckpoint" not in self.output_model_tags:
             self.output_model_tags.append("LightningCheckpoint")
-
-        if
-        config =
+
+        if self.config_dict is None:
+            config = pl_module.model_config
+            if config is not None:
+                config = config.to_dict()
+        else:
+            config = self.config_dict
 
         return OutputModel(
             task=self.task,
             name=self.output_model_name,
             framework="PyTorch",
             tags=self.output_model_tags,
-            config_dict=
+            config_dict=None,
+            label_enumeration=self.label_enumeration,
         )
 
     def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
@@ -111,6 +122,13 @@ class ClearMLRegistryUploaderCallback(Callback):
             auto_delete_file=False,
             async_enable=False,
         )
+        if self.config_dict is None:
+            config = pl_module.model_config
+            if config is not None:
+                config = config.to_dict()
+        else:
+            config = self.config_dict
+        self._output_model.update_design(config_dict=config)
 
         self._last_best_model_path = current_best
         return
```
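A minimal wiring sketch for the new constructor arguments follows. It assumes the callback is exported from `kostyl.ml.lightning.callbacks` (inferred from the wheel layout), that `ckpt_callback` is the Lightning `ModelCheckpoint` whose best checkpoint gets uploaded, and that any constructor parameters not visible in this diff can be left at their defaults; none of this is confirmed by the diff itself. What the diff does show is the new fallback: when `config_dict` is `None`, the callback takes `pl_module.model_config` and calls `.to_dict()` on it, both when the `OutputModel` is created and when `update_design` is later called.

```python
# Hypothetical usage sketch only; import paths assume Lightning 2.x and the
# kostyl.ml.lightning.callbacks export, neither of which this diff confirms.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

from kostyl.ml.lightning.callbacks import ClearMLRegistryUploaderCallback  # assumed export

ckpt_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1)

uploader = ClearMLRegistryUploaderCallback(
    ckpt_callback=ckpt_callback,
    output_model_name="my-classifier",
    output_model_tags=["prod-candidate"],
    # New in 0.1.22: label mapping stored on the ClearML OutputModel.
    label_enumeration={"negative": 0, "positive": 1},
    # New in 0.1.22: explicit config; if left as None, the callback falls back
    # to pl_module.model_config.to_dict() at upload time.
    config_dict={"hidden_size": "256"},
    uploading_frequency="on-train-end",
)

trainer = Trainer(callbacks=[ckpt_callback, uploader])
```

Note that `config_dict` is annotated as `dict[str, str] | None`, so the values in this sketch are passed as strings.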
kostyl/ml/lightning/extenstions/pretrained_model.py

```diff
@@ -14,6 +14,7 @@ except ImportError:
 
 from kostyl.utils.logging import log_incompatible_keys
 from kostyl.utils.logging import setup_logger
+from torch import nn
 
 
 logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
@@ -67,7 +68,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
             raise ValueError(f"{checkpoint_path} is a directory")
         if not checkpoint_path.exists():
             raise FileNotFoundError(f"{checkpoint_path} does not exist")
-        if
+        if checkpoint_path.suffix != ".ckpt":
             raise ValueError(f"{checkpoint_path} is not a .ckpt file")
 
         checkpoint_dict = torch.load(
@@ -77,19 +78,21 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
             mmap=True,
         )
 
-
+        # 1. Restore the config
+        config_cls = cast(type[PretrainedConfig], cls.config_class)
         config_dict = checkpoint_dict[config_key]
         config_dict.update(kwargs)
         config = config_cls.from_dict(config_dict)
 
-        kwargs_for_model = {}
-        for key in kwargs:
+        kwargs_for_model: dict[str, Any] = {}
+        for key, value in kwargs.items():
             if not hasattr(config, key):
-                kwargs_for_model[key] =
+                kwargs_for_model[key] = value
 
         with torch.device("meta"):
             model = cls(config, **kwargs_for_model)
 
+        # PEFT adapters (keeping this logic as is)
         if "peft_config" in checkpoint_dict:
             if PeftConfig is None:
                 raise ImportError(
@@ -100,26 +103,100 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
             model.add_adapter(peft_cfg, adapter_name=name)
 
         incompatible_keys: dict[str, list[str]] = {}
-
-
-
-
-
-
+
+        raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
+
+        if weights_prefix:
+            if not weights_prefix.endswith("."):
+                weights_prefix = weights_prefix + "."
+            state_dict: dict[str, torch.Tensor] = {}
+            mismatched_keys: list[str] = []
+
+            for key, value in raw_state_dict.items():
                 if key.startswith(weights_prefix):
                     new_key = key[len(weights_prefix) :]
-
+                    state_dict[new_key] = value
                 else:
                     mismatched_keys.append(key)
+
+            if mismatched_keys:
                 incompatible_keys["mismatched_keys"] = mismatched_keys
         else:
-
-
-
-
+            state_dict = raw_state_dict
+
+        # 5. base_model_prefix logic, as in HF:
+        # supports loading a base model <-> a model with a head
+        #
+        # cls.base_model_prefix is usually "model" / "bert" / "encoder" and the like
+        base_prefix: str = getattr(cls, "base_model_prefix", "") or ""
+        model_to_load: nn.Module = model
+
+        if base_prefix:
+            prefix_with_dot = base_prefix + "."
+            loaded_keys = list(state_dict.keys())
+            full_model_state = model.state_dict()
+            expected_keys = list(full_model_state.keys())
+
+            has_prefix_module = any(k.startswith(prefix_with_dot) for k in loaded_keys)
+            expects_prefix_module = any(
+                k.startswith(prefix_with_dot) for k in expected_keys
+            )
+
+            # Case 1: loading a base model into a model with a head.
+            # Example: StaticEmbeddingsForSequenceClassification (has .model),
+            # state_dict with keys "embeddings.weight", "token_pos_weights", ...
+            if (
+                hasattr(model, base_prefix)
+                and not has_prefix_module
+                and expects_prefix_module
+            ):
+                # Weights come without the prefix -> load only into model.<base_prefix>
+                model_to_load = getattr(model, base_prefix)
+
+            # Case 2: loading a checkpoint of a model with a head into a base model.
+            # Example: BertModel, while the state_dict has keys "bert.encoder.layer.0..."
+            elif (
+                not hasattr(model, base_prefix)
+                and has_prefix_module
+                and not expects_prefix_module
+            ):
+                new_state_dict: dict[str, torch.Tensor] = {}
+                for key, value in state_dict.items():
+                    if key.startswith(prefix_with_dot):
+                        new_key = key[len(prefix_with_dot) :]
+                    else:
+                        new_key = key
+                    new_state_dict[new_key] = value
+                state_dict = new_state_dict
+
+        load_result = model_to_load.load_state_dict(
+            state_dict, strict=False, assign=True
+        )
+        missing_keys, unexpected_keys = (
+            load_result.missing_keys,
+            load_result.unexpected_keys,
         )
+
+        # If we loaded only into the base submodule, extend missing_keys
+        # to the full list (base + head), as in older HF versions.
+        if model_to_load is not model and base_prefix:
+            base_keys = set(model_to_load.state_dict().keys())
+            # Reduce the full model's keys to their prefix-free form
+            head_like_keys = set()
+            prefix_with_dot = base_prefix + "."
+            for k in model.state_dict().keys():
+                if k.startswith(prefix_with_dot):
+                    # strip "model."
+                    head_like_keys.add(k[len(prefix_with_dot) :])
+                else:
+                    head_like_keys.add(k)
+            extra_missing = sorted(head_like_keys - base_keys)
+            missing_keys = list(missing_keys) + extra_missing
+
         incompatible_keys["missing_keys"] = missing_keys
         incompatible_keys["unexpected_keys"] = unexpected_keys
+
         if should_log_incompatible_keys:
             log_incompatible_keys(incompatible_keys=incompatible_keys, logger=logger)
+
         return model
```
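The new loading path mirrors the `base_model_prefix` convention from Hugging Face Transformers: a checkpoint of a bare base model can be loaded into a model with a head, and vice versa. The sketch below is an illustrative, standalone restatement of just that two-case decision; the helper name and example prefixes are invented for clarity and are not part of the package.

```python
# Illustrative-only helper (not part of kostyl): the same two-case
# base_model_prefix remapping the new loader applies before load_state_dict.
import torch
from torch import nn


def remap_for_base_prefix(
    model: nn.Module,
    state_dict: dict[str, torch.Tensor],
    base_prefix: str,
) -> tuple[nn.Module, dict[str, torch.Tensor]]:
    """Decide which module to load into and how to rename checkpoint keys."""
    if not base_prefix:
        return model, state_dict

    prefix = base_prefix + "."
    has_prefix = any(k.startswith(prefix) for k in state_dict)
    expects_prefix = any(k.startswith(prefix) for k in model.state_dict())

    # Case 1: bare base-model weights, but the target model wraps the base
    # under model.<base_prefix> -> load into that submodule directly.
    if hasattr(model, base_prefix) and not has_prefix and expects_prefix:
        return getattr(model, base_prefix), state_dict

    # Case 2: checkpoint came from a model with a head, target is the bare
    # base model -> strip the prefix from every matching key.
    if not hasattr(model, base_prefix) and has_prefix and not expects_prefix:
        stripped = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        return model, stripped

    # Otherwise the checkpoint and the model already agree on key layout.
    return model, state_dict
```

The diff then calls `load_state_dict(state_dict, strict=False, assign=True)` on the chosen module, so parameters created under `torch.device("meta")` are replaced by the checkpoint tensors rather than copied into, and anything left unmatched is reported through `missing_keys` / `unexpected_keys` (with `missing_keys` widened to cover the head when only the base submodule was loaded).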
{kostyl_toolkit-0.1.21.dist-info → kostyl_toolkit-0.1.22.dist-info}/RECORD

```diff
@@ -13,10 +13,10 @@ kostyl/ml/lightning/__init__.py,sha256=-F3JAyq8KU1d-nACWryGu8d1CbvWbQ1rXFdeRwfE2
 kostyl/ml/lightning/callbacks/__init__.py,sha256=Vd-rozY4T9Prr3IMqbliXxj6sC6y9XsovHQqRwzc2HI,297
 kostyl/ml/lightning/callbacks/checkpoint.py,sha256=FooGeeUz6TtoXQglpcK16NWAmSX3fbu6wntRtK3a_Io,1936
 kostyl/ml/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
-kostyl/ml/lightning/callbacks/registry_uploading.py,sha256=
+kostyl/ml/lightning/callbacks/registry_uploading.py,sha256=32vhMNNuThtEcvRdS5jh5s-wf7LwZNsCTwZA3emcObs,5449
 kostyl/ml/lightning/extenstions/__init__.py,sha256=OY6QGv1agYgqqKf1xJBrxgp_i8FunVfPzYezfaRrGXU,182
 kostyl/ml/lightning/extenstions/custom_module.py,sha256=nB5jW7cqRD1tyh-q5LD2EtiFQwFkLXpnS9Yu6c5xMRg,5987
-kostyl/ml/lightning/extenstions/pretrained_model.py,sha256=
+kostyl/ml/lightning/extenstions/pretrained_model.py,sha256=x8D2nMDDW8J913qFRSEGKXfQO8ipPJM5SLo4Y5kc3YA,8638
 kostyl/ml/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
 kostyl/ml/lightning/loggers/tb_logger.py,sha256=j02HK5ue8yzXXV8FWKmmXyHkFlIxgHx-ahHWk_rFCZs,893
 kostyl/ml/lightning/steps_estimation.py,sha256=fTZ0IrUEZV3H6VYlx4GYn56oco56mMiB7FO9F0Z7qc4,1511
@@ -30,6 +30,6 @@ kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
 kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
 kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
 kostyl/utils/logging.py,sha256=Vye0u4-yeOSUc-f03gpQbxSktTbFiilTWLEVr00ZHvc,5796
-kostyl_toolkit-0.1.
-kostyl_toolkit-0.1.
-kostyl_toolkit-0.1.
+kostyl_toolkit-0.1.22.dist-info/WHEEL,sha256=z-mOpxbJHqy3cq6SvUThBZdaLGFZzdZPtgWLcP2NKjQ,79
+kostyl_toolkit-0.1.22.dist-info/METADATA,sha256=GweBJ42Dhbl4Y5PNu-jnffXj1CaJ34DTPUcoFEndJ1M,4269
+kostyl_toolkit-0.1.22.dist-info/RECORD,,
```
{kostyl_toolkit-0.1.21.dist-info → kostyl_toolkit-0.1.22.dist-info}/WHEEL: file without changes.