kostyl-toolkit 0.1.23__tar.gz → 0.1.24__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/PKG-INFO +1 -1
  2. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/clearml/pulling_utils.py +1 -1
  3. kostyl_toolkit-0.1.24/kostyl/ml/lightning/extenstions/pretrained_model.py +105 -0
  4. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/pyproject.toml +1 -1
  5. kostyl_toolkit-0.1.23/kostyl/ml/lightning/extenstions/pretrained_model.py +0 -202
  6. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/README.md +0 -0
  7. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/__init__.py +0 -0
  8. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/__init__.py +0 -0
  9. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/clearml/__init__.py +0 -0
  10. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/clearml/dataset_utils.py +0 -0
  11. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/clearml/logging_utils.py +0 -0
  12. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/configs/__init__.py +0 -0
  13. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/configs/base_model.py +0 -0
  14. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/configs/hyperparams.py +0 -0
  15. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/configs/training_settings.py +0 -0
  16. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/dist_utils.py +0 -0
  17. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/__init__.py +0 -0
  18. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/callbacks/__init__.py +0 -0
  19. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/callbacks/checkpoint.py +0 -0
  20. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
  21. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/callbacks/registry_uploader.py +0 -0
  22. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/extenstions/__init__.py +0 -0
  23. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/extenstions/custom_module.py +0 -0
  24. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/loggers/__init__.py +0 -0
  25. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/loggers/tb_logger.py +0 -0
  26. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/lightning/steps_estimation.py +0 -0
  27. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/metrics_formatting.py +0 -0
  28. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/params_groups.py +0 -0
  29. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/schedulers/__init__.py +0 -0
  30. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/schedulers/base.py +0 -0
  31. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/schedulers/composite.py +0 -0
  32. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/ml/schedulers/cosine.py +0 -0
  33. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/utils/__init__.py +0 -0
  34. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/utils/dict_manipulations.py +0 -0
  35. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/utils/fs.py +0 -0
  36. {kostyl_toolkit-0.1.23 → kostyl_toolkit-0.1.24}/kostyl/utils/logging.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: kostyl-toolkit
- Version: 0.1.23
+ Version: 0.1.24
  Summary: Kickass Orchestration System for Training, Yielding & Logging
  Requires-Dist: case-converter>=1.2.0
  Requires-Dist: loguru>=0.7.3
@@ -95,7 +95,7 @@ def get_model_from_clearml[
  raise ValueError(
  f"Model class {model.__name__} is not compatible with Lightning checkpoints."
  )
- model_instance = model.from_lighting_checkpoint(local_path, **kwargs)
+ model_instance = model.from_lightning_checkpoint(local_path, **kwargs)
  else:
  raise ValueError(
  f"Unsupported model format for path: {local_path}. "
@@ -0,0 +1,105 @@
+ from pathlib import Path
+ from typing import Any
+ from typing import cast
+
+ import torch
+ from transformers import PretrainedConfig
+ from transformers import PreTrainedModel
+
+ from kostyl.utils.logging import setup_logger
+
+
+ logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
+
+
+ class LightningCheckpointLoaderMixin(PreTrainedModel):
+     """A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""
+
+     @classmethod
+     def from_lightning_checkpoint[TModelInstance: LightningCheckpointLoaderMixin](  # noqa: C901
+         cls: type[TModelInstance],
+         checkpoint_path: str | Path,
+         config_key: str = "config",
+         weights_prefix: str = "model.",
+         **kwargs: Any,
+     ) -> TModelInstance:
+         """
+         Load a model from a Lightning checkpoint file.
+
+         This class method loads a pretrained model from a PyTorch Lightning checkpoint file (.ckpt).
+         It extracts the model configuration from the checkpoint, instantiates the model, and loads
+         the state dictionary, handling any incompatible keys.
+
+         Note:
+             The method uses `torch.load` with `weights_only=False` and `mmap=True` for loading.
+             Incompatible keys (missing, unexpected, mismatched) are collected and optionally logged.
+
+         Args:
+             cls (type["LightningPretrainedModelMixin"]): The class of the model to instantiate.
+             checkpoint_path (str | Path): Path to the checkpoint file. Must be a .ckpt file.
+             config_key (str, optional): Key in the checkpoint dictionary where the config is stored.
+                 Defaults to "config".
+             weights_prefix (str, optional): Prefix to strip from state dict keys. Defaults to "model.".
+                 If not empty and doesn't end with ".", a "." is appended.
+             kwargs: Additional keyword arguments to pass to the model's `from_pretrained` method.
+
+         Returns:
+             TModelInstance: The loaded model instance.
+
+         Raises:
+             ValueError: If checkpoint_path is a directory, not a .ckpt file, or invalid.
+             FileNotFoundError: If the checkpoint file does not exist.
+
+         """
+         if isinstance(checkpoint_path, str):
+             checkpoint_path = Path(checkpoint_path)
+
+         if checkpoint_path.is_dir():
+             raise ValueError(f"{checkpoint_path} is a directory")
+         if not checkpoint_path.exists():
+             raise FileNotFoundError(f"{checkpoint_path} does not exist")
+         if checkpoint_path.suffix != ".ckpt":
+             raise ValueError(f"{checkpoint_path} is not a .ckpt file")
+
+         checkpoint_dict = torch.load(
+             checkpoint_path,
+             map_location="cpu",
+             weights_only=False,
+             mmap=True,
+         )
+
+         # 1. Restore the config
+         config_cls = cast(type[PretrainedConfig], cls.config_class)
+         config_dict = checkpoint_dict[config_key]
+         config_dict.update(kwargs)
+         config = config_cls.from_dict(config_dict)
+
+         kwargs_for_model: dict[str, Any] = {}
+         for key, value in kwargs.items():
+             if not hasattr(config, key):
+                 kwargs_for_model[key] = value
+
+         raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
+
+         if weights_prefix:
+             if not weights_prefix.endswith("."):
+                 weights_prefix = weights_prefix + "."
+             state_dict: dict[str, torch.Tensor] = {}
+
+             for key, value in raw_state_dict.items():
+                 if key.startswith(weights_prefix):
+                     new_key = key[len(weights_prefix) :]
+                     state_dict[new_key] = value
+                 else:
+                     state_dict[key] = value
+         else:
+             state_dict = raw_state_dict
+
+         model = cls.from_pretrained(
+             pretrained_model_name_or_path=None,
+             config=config,
+             state_dict=state_dict,
+             **kwargs_for_model,
+         )
+
+         return model
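
For context, a minimal usage sketch of the loader added above. MyConfig, MyModel, and the checkpoint path are illustrative placeholders, not part of the package; a real subclass would define its layers in __init__ like any PreTrainedModel.

from transformers import PretrainedConfig

from kostyl.ml.lightning.extenstions.pretrained_model import LightningCheckpointLoaderMixin


class MyConfig(PretrainedConfig):  # hypothetical config class
    model_type = "my-model"


class MyModel(LightningCheckpointLoaderMixin):  # hypothetical model class
    config_class = MyConfig
    # ... layers would be defined in __init__ as usual for a PreTrainedModel


# Reads the HF config stored under checkpoint["config"], strips the "model."
# prefix that Lightning adds to state_dict keys, then defers to
# from_pretrained with the resulting state_dict.
model = MyModel.from_lightning_checkpoint(
    "checkpoints/last.ckpt",  # placeholder path
    config_key="config",
    weights_prefix="model.",
)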
@@ -1,6 +1,6 @@
  [project]
  name = "kostyl-toolkit"
- version = "0.1.23"
+ version = "0.1.24"
  description = "Kickass Orchestration System for Training, Yielding & Logging "
  readme = "README.md"
  requires-python = ">=3.12"
@@ -1,202 +0,0 @@
- from pathlib import Path
- from typing import Any
- from typing import cast
-
- import torch
- from transformers import PretrainedConfig
- from transformers import PreTrainedModel
-
-
- try:
-     from peft import PeftConfig
- except ImportError:
-     PeftConfig = None  # ty: ignore
-
- from kostyl.utils.logging import log_incompatible_keys
- from kostyl.utils.logging import setup_logger
- from torch import nn
-
-
- logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
-
-
- class LightningCheckpointLoaderMixin(PreTrainedModel):
-     """A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""
-
-     @classmethod
-     def from_lighting_checkpoint[TModelInstance: LightningCheckpointLoaderMixin](  # noqa: C901
-         cls: type[TModelInstance],
-         checkpoint_path: str | Path,
-         config_key: str = "config",
-         weights_prefix: str = "model.",
-         should_log_incompatible_keys: bool = True,
-         **kwargs: Any,
-     ) -> TModelInstance:
-         """
-         Load a model from a Lightning checkpoint file.
-
-         This class method loads a pretrained model from a PyTorch Lightning checkpoint file (.ckpt).
-         It extracts the model configuration from the checkpoint, instantiates the model, and loads
-         the state dictionary, handling any incompatible keys.
-
-         Note:
-             The method uses `torch.load` with `weights_only=False` and `mmap=True` for loading.
-             Incompatible keys (missing, unexpected, mismatched) are collected and optionally logged.
-
-         Args:
-             cls (type["LightningPretrainedModelMixin"]): The class of the model to instantiate.
-             checkpoint_path (str | Path): Path to the checkpoint file. Must be a .ckpt file.
-             config_key (str, optional): Key in the checkpoint dictionary where the config is stored.
-                 Defaults to "config".
-             weights_prefix (str, optional): Prefix to strip from state dict keys. Defaults to "model.".
-                 If not empty and doesn't end with ".", a "." is appended.
-             should_log_incompatible_keys (bool, optional): Whether to log incompatible keys. Defaults to True.
-             **kwargs: Additional keyword arguments to pass to the model loading method.
-
-         Returns:
-             TModelInstance: The loaded model instance.
-
-         Raises:
-             ValueError: If checkpoint_path is a directory, not a .ckpt file, or invalid.
-             FileNotFoundError: If the checkpoint file does not exist.
-
-         """
-         if isinstance(checkpoint_path, str):
-             checkpoint_path = Path(checkpoint_path)
-
-         if checkpoint_path.is_dir():
-             raise ValueError(f"{checkpoint_path} is a directory")
-         if not checkpoint_path.exists():
-             raise FileNotFoundError(f"{checkpoint_path} does not exist")
-         if checkpoint_path.suffix != ".ckpt":
-             raise ValueError(f"{checkpoint_path} is not a .ckpt file")
-
-         checkpoint_dict = torch.load(
-             checkpoint_path,
-             map_location="cpu",
-             weights_only=False,
-             mmap=True,
-         )
-
-         # 1. Restore the config
-         config_cls = cast(type[PretrainedConfig], cls.config_class)
-         config_dict = checkpoint_dict[config_key]
-         config_dict.update(kwargs)
-         config = config_cls.from_dict(config_dict)
-
-         kwargs_for_model: dict[str, Any] = {}
-         for key, value in kwargs.items():
-             if not hasattr(config, key):
-                 kwargs_for_model[key] = value
-
-         with torch.device("meta"):
-             model = cls(config, **kwargs_for_model)
-
-         # PEFT adapters (keeping this logic as is)
-         if "peft_config" in checkpoint_dict:
-             if PeftConfig is None:
-                 raise ImportError(
-                     "peft is not installed. Please install it to load PEFT models."
-                 )
-             for name, adapter_dict in checkpoint_dict["peft_config"].items():
-                 peft_cfg = PeftConfig.from_peft_type(**adapter_dict)
-                 model.add_adapter(peft_cfg, adapter_name=name)
-
-         incompatible_keys: dict[str, list[str]] = {}
-
-         raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
-
-         if weights_prefix:
-             if not weights_prefix.endswith("."):
-                 weights_prefix = weights_prefix + "."
-             state_dict: dict[str, torch.Tensor] = {}
-             mismatched_keys: list[str] = []
-
-             for key, value in raw_state_dict.items():
-                 if key.startswith(weights_prefix):
-                     new_key = key[len(weights_prefix) :]
-                     state_dict[new_key] = value
-                 else:
-                     mismatched_keys.append(key)
-
-             if mismatched_keys:
-                 incompatible_keys["mismatched_keys"] = mismatched_keys
-         else:
-             state_dict = raw_state_dict
-
-         # 5. base_model_prefix logic, as in HF:
-         #    support loading a base model <-> a model with a head
-         #
-         #    cls.base_model_prefix is usually "model" / "bert" / "encoder", etc.
-         base_prefix: str = getattr(cls, "base_model_prefix", "") or ""
-         model_to_load: nn.Module = model
-
-         if base_prefix:
-             prefix_with_dot = base_prefix + "."
-             loaded_keys = list(state_dict.keys())
-             full_model_state = model.state_dict()
-             expected_keys = list(full_model_state.keys())
-
-             has_prefix_module = any(k.startswith(prefix_with_dot) for k in loaded_keys)
-             expects_prefix_module = any(
-                 k.startswith(prefix_with_dot) for k in expected_keys
-             )
-
-             # Case 1: loading a base model into a model with a head.
-             # Example: StaticEmbeddingsForSequenceClassification (has .model),
-             # state_dict with keys "embeddings.weight", "token_pos_weights", ...
-             if (
-                 hasattr(model, base_prefix)
-                 and not has_prefix_module
-                 and expects_prefix_module
-             ):
-                 # Weights have no prefix -> load only into model.<base_prefix>
-                 model_to_load = getattr(model, base_prefix)
-
-             # Case 2: loading a checkpoint of a model with a head into a base model.
-             # Example: BertModel, with state_dict keys "bert.encoder.layer.0..."
-             elif (
-                 not hasattr(model, base_prefix)
-                 and has_prefix_module
-                 and not expects_prefix_module
-             ):
-                 new_state_dict: dict[str, torch.Tensor] = {}
-                 for key, value in state_dict.items():
-                     if key.startswith(prefix_with_dot):
-                         new_key = key[len(prefix_with_dot) :]
-                     else:
-                         new_key = key
-                     new_state_dict[new_key] = value
-                 state_dict = new_state_dict
-
-         load_result = model_to_load.load_state_dict(
-             state_dict, strict=False, assign=True
-         )
-         missing_keys, unexpected_keys = (
-             load_result.missing_keys,
-             load_result.unexpected_keys,
-         )
-
-         # If we loaded only into the base submodule, extend missing_keys
-         # to the full list (base + head), as in older HF versions.
-         if model_to_load is not model and base_prefix:
-             base_keys = set(model_to_load.state_dict().keys())
-             # Reduce the full model's keys to their prefix-free form
-             head_like_keys = set()
-             prefix_with_dot = base_prefix + "."
-             for k in model.state_dict().keys():
-                 if k.startswith(prefix_with_dot):
-                     # strip "model."
-                     head_like_keys.add(k[len(prefix_with_dot) :])
-                 else:
-                     head_like_keys.add(k)
-             extra_missing = sorted(head_like_keys - base_keys)
-             missing_keys = list(missing_keys) + extra_missing
-
-         incompatible_keys["missing_keys"] = missing_keys
-         incompatible_keys["unexpected_keys"] = unexpected_keys
-
-         if should_log_incompatible_keys:
-             log_incompatible_keys(incompatible_keys=incompatible_keys, logger=logger)
-
-         return model
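
For reference, the base_model_prefix handling in the removed implementation reduces to a key remap. A standalone sketch of the "Case 2" direction (the function name and example keys are illustrative, not from the package):

import torch


def strip_base_prefix(
    state_dict: dict[str, torch.Tensor], base_prefix: str
) -> dict[str, torch.Tensor]:
    # The checkpoint was saved from a model with a head (keys like
    # "bert.encoder.layer.0...."), but the target is the bare base model,
    # so "<base_prefix>." is dropped from every matching key.
    prefix = base_prefix + "."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# strip_base_prefix({"bert.embeddings.weight": t}, "bert")
# -> {"embeddings.weight": t}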