kostyl-toolkit 0.1.34__py3-none-any.whl → 0.1.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kostyl/ml/lightning/extensions/pretrained_model.py +27 -5
- {kostyl_toolkit-0.1.34.dist-info → kostyl_toolkit-0.1.35.dist-info}/METADATA +1 -1
- {kostyl_toolkit-0.1.34.dist-info → kostyl_toolkit-0.1.35.dist-info}/RECORD +4 -4
- {kostyl_toolkit-0.1.34.dist-info → kostyl_toolkit-0.1.35.dist-info}/WHEEL +0 -0
|
@@ -20,7 +20,8 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
|
|
|
20
20
|
cls: type[TModelInstance],
|
|
21
21
|
checkpoint_path: str | Path,
|
|
22
22
|
config_key: str = "config",
|
|
23
|
-
weights_prefix: str = "model.",
|
|
23
|
+
weights_prefix: str | None = "model.",
|
|
24
|
+
strict_prefix: bool = False,
|
|
24
25
|
**kwargs: Any,
|
|
25
26
|
) -> TModelInstance:
|
|
26
27
|
"""
|
|
@@ -39,8 +40,10 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
|
|
|
39
40
|
checkpoint_path (str | Path): Path to the checkpoint file. Must be a .ckpt file.
|
|
40
41
|
config_key (str, optional): Key in the checkpoint dictionary where the config is stored.
|
|
41
42
|
Defaults to "config".
|
|
42
|
-
weights_prefix (str, optional): Prefix to strip from state dict keys. Defaults to "model.".
|
|
43
|
-
If not empty and doesn't end with ".", a "." is appended.
|
|
43
|
+
weights_prefix (str | None, optional): Prefix to strip from state dict keys. Defaults to "model.".
|
|
44
|
+
If not empty and doesn't end with ".", a "." is appended. If empty or None, prefix stripping is skipped.
|
|
45
|
+
strict_prefix (bool, optional): If True, drop tensors whose keys do not start with the
|
|
46
|
+
specified prefix. Defaults to False.
|
|
44
47
|
kwargs: Additional keyword arguments to pass to the model's `from_pretrained` method.
|
|
45
48
|
|
|
46
49
|
Returns:
|
|
@@ -53,6 +56,13 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
|
|
|
53
56
|
"""
|
|
54
57
|
if isinstance(checkpoint_path, str):
|
|
55
58
|
checkpoint_path = Path(checkpoint_path)
|
|
59
|
+
if weights_prefix is None:
|
|
60
|
+
weights_prefix = ""
|
|
61
|
+
weights_prefix = cast(str, weights_prefix)
|
|
62
|
+
if weights_prefix == "" and strict_prefix:
|
|
63
|
+
logger.warning(
|
|
64
|
+
"strict_prefix=True has no effect when weights_prefix is empty or None."
|
|
65
|
+
)
|
|
56
66
|
|
|
57
67
|
if checkpoint_path.is_dir():
|
|
58
68
|
raise ValueError(f"{checkpoint_path} is a directory")
|
|
@@ -85,13 +95,25 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
|
|
|
85
95
|
if not weights_prefix.endswith("."):
|
|
86
96
|
weights_prefix = weights_prefix + "."
|
|
87
97
|
state_dict: dict[str, torch.Tensor] = {}
|
|
88
|
-
|
|
98
|
+
matched_keys_counter = 0
|
|
89
99
|
for key, value in raw_state_dict.items():
|
|
90
100
|
if key.startswith(weights_prefix):
|
|
91
101
|
new_key = key[len(weights_prefix) :]
|
|
92
102
|
state_dict[new_key] = value
|
|
93
|
-
|
|
103
|
+
matched_keys_counter += 1
|
|
104
|
+
elif not strict_prefix:
|
|
94
105
|
state_dict[key] = value
|
|
106
|
+
|
|
107
|
+
if matched_keys_counter == 0:
|
|
108
|
+
if strict_prefix:
|
|
109
|
+
raise ValueError(
|
|
110
|
+
f"No keys in the checkpoint start with the specified prefix '{weights_prefix}'. "
|
|
111
|
+
"Try to load with `strict_prefix=False` or verify the prefix."
|
|
112
|
+
)
|
|
113
|
+
else:
|
|
114
|
+
logger.warning(
|
|
115
|
+
f"No keys in the checkpoint start with the specified prefix '{weights_prefix}'. "
|
|
116
|
+
)
|
|
95
117
|
else:
|
|
96
118
|
state_dict = raw_state_dict
|
|
97
119
|
|
|
@@ -16,7 +16,7 @@ kostyl/ml/lightning/callbacks/checkpoint.py,sha256=sZ9OqudO-gXp7FqtWaOH46TXVpeCJ
|
|
|
16
16
|
kostyl/ml/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
|
|
17
17
|
kostyl/ml/lightning/extensions/__init__.py,sha256=OY6QGv1agYgqqKf1xJBrxgp_i8FunVfPzYezfaRrGXU,182
|
|
18
18
|
kostyl/ml/lightning/extensions/custom_module.py,sha256=iQrnPz-WTmRfvLo94C5fQc2Qwa1IpHtUh1sCpVwTSFM,6602
|
|
19
|
-
kostyl/ml/lightning/extensions/pretrained_model.py,sha256=
|
|
19
|
+
kostyl/ml/lightning/extensions/pretrained_model.py,sha256=eRfQBzAjVernHl9A4PP5uTLvjjmcNKPdTu7ABFLq7HI,5196
|
|
20
20
|
kostyl/ml/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
|
|
21
21
|
kostyl/ml/lightning/loggers/tb_logger.py,sha256=j02HK5ue8yzXXV8FWKmmXyHkFlIxgHx-ahHWk_rFCZs,893
|
|
22
22
|
kostyl/ml/lightning/training_utils.py,sha256=u7X9ysF9Gqy8CdwacdcDlNQNsbagYAhslbv-1WLJ45k,9052
|
|
@@ -32,6 +32,6 @@ kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
|
|
|
32
32
|
kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
|
|
33
33
|
kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
|
|
34
34
|
kostyl/utils/logging.py,sha256=Vye0u4-yeOSUc-f03gpQbxSktTbFiilTWLEVr00ZHvc,5796
|
|
35
|
-
kostyl_toolkit-0.1.
|
|
36
|
-
kostyl_toolkit-0.1.
|
|
37
|
-
kostyl_toolkit-0.1.
|
|
35
|
+
kostyl_toolkit-0.1.35.dist-info/WHEEL,sha256=ZyFSCYkV2BrxH6-HRVRg3R9Fo7MALzer9KiPYqNxSbo,79
|
|
36
|
+
kostyl_toolkit-0.1.35.dist-info/METADATA,sha256=KL4-Z421DpchI6KUZ6tVATy99urk1OP2OY4Uf5r9R3U,4269
|
|
37
|
+
kostyl_toolkit-0.1.35.dist-info/RECORD,,
|
|
File without changes
|