kostyl-toolkit 0.1.20__tar.gz → 0.1.22__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/PKG-INFO +1 -1
  2. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/configs/hyperparams.py +1 -3
  3. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/configs/training_settings.py +1 -3
  4. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/lightning/callbacks/registry_uploading.py +22 -4
  5. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/lightning/extenstions/pretrained_model.py +93 -16
  6. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/utils/logging.py +1 -1
  7. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/pyproject.toml +1 -1
  8. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/README.md +0 -0
  9. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/__init__.py +0 -0
  10. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/__init__.py +0 -0
  11. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/clearml/__init__.py +0 -0
  12. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/clearml/dataset_utils.py +0 -0
  13. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/clearml/logging_utils.py +0 -0
  14. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/clearml/pulling_utils.py +0 -0
  15. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/configs/__init__.py +0 -0
  16. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/configs/base_model.py +0 -0
  17. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/dist_utils.py +0 -0
  18. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/lightning/__init__.py +0 -0
  19. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/lightning/callbacks/__init__.py +0 -0
  20. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/lightning/callbacks/checkpoint.py +0 -0
  21. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
  22. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/lightning/extenstions/__init__.py +0 -0
  23. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/lightning/extenstions/custom_module.py +0 -0
  24. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/lightning/loggers/__init__.py +0 -0
  25. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/lightning/loggers/tb_logger.py +0 -0
  26. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/lightning/steps_estimation.py +0 -0
  27. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/metrics_formatting.py +0 -0
  28. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/params_groups.py +0 -0
  29. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/schedulers/__init__.py +0 -0
  30. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/schedulers/base.py +0 -0
  31. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/schedulers/composite.py +0 -0
  32. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/schedulers/cosine.py +0 -0
  33. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/utils/__init__.py +0 -0
  34. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/utils/dict_manipulations.py +0 -0
  35. {kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/utils/fs.py +0 -0
{kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: kostyl-toolkit
-Version: 0.1.20
+Version: 0.1.22
 Summary: Kickass Orchestration System for Training, Yielding & Logging
 Requires-Dist: case-converter>=1.2.0
 Requires-Dist: loguru>=0.7.3
{kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/configs/hyperparams.py
@@ -4,8 +4,6 @@ from pydantic import model_validator
 
 from kostyl.utils.logging import setup_logger
 
-from .base_model import KostylBaseModel
-
 
 logger = setup_logger(fmt="only_message")
 
@@ -75,7 +73,7 @@ class WeightDecay(BaseModel):
         return self
 
 
-class HyperparamsConfig(KostylBaseModel):
+class HyperparamsConfig(BaseModel):
     """Model training hyperparameters configuration."""
 
     grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
{kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/configs/training_settings.py
@@ -5,8 +5,6 @@ from pydantic import Field
 
 from kostyl.utils.logging import setup_logger
 
-from .base_model import KostylBaseModel
-
 
 logger = setup_logger(fmt="only_message")
 
@@ -95,7 +93,7 @@ class DataConfig(BaseModel):
     data_columns: list[str]
 
 
-class TrainingSettings(KostylBaseModel):
+class TrainingSettings(BaseModel):
    """Training parameters configuration."""
 
     trainer: LightningTrainerParameters
{kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/lightning/callbacks/registry_uploading.py
@@ -27,6 +27,8 @@ class ClearMLRegistryUploaderCallback(Callback):
         output_model_tags: list[str] | None = None,
         verbose: bool = True,
         enable_tag_versioning: bool = True,
+        label_enumeration: dict[str, int] | None = None,
+        config_dict: dict[str, str] | None = None,
         uploading_frequency: Literal[
             "after-every-eval", "on-train-end"
         ] = "on-train-end",
@@ -40,6 +42,8 @@ class ClearMLRegistryUploaderCallback(Callback):
             output_model_name: Name for the ClearML output model.
             output_model_tags: Tags for the output model.
             verbose: Whether to log messages.
+            label_enumeration: Optional mapping of label names to integer IDs.
+            config_dict: Optional configuration dictionary to associate with the model.
             enable_tag_versioning: Whether to enable versioning in tags. If True,
                 the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".
             uploading_frequency: When to upload:
@@ -55,6 +59,8 @@ class ClearMLRegistryUploaderCallback(Callback):
         self.ckpt_callback = ckpt_callback
         self.output_model_name = output_model_name
         self.output_model_tags = output_model_tags
+        self.config_dict = config_dict
+        self.label_enumeration = label_enumeration
         self.verbose = verbose
         self.uploading_frequency = uploading_frequency
         self.enable_tag_versioning = enable_tag_versioning
@@ -75,16 +81,21 @@ class ClearMLRegistryUploaderCallback(Callback):
 
         if "LightningCheckpoint" not in self.output_model_tags:
             self.output_model_tags.append("LightningCheckpoint")
-        config = pl_module.model_config
-        if config is not None:
-            config = config.to_dict()
+
+        if self.config_dict is None:
+            config = pl_module.model_config
+            if config is not None:
+                config = config.to_dict()
+        else:
+            config = self.config_dict
 
         return OutputModel(
             task=self.task,
             name=self.output_model_name,
             framework="PyTorch",
             tags=self.output_model_tags,
-            config_dict=config,
+            config_dict=None,
+            label_enumeration=self.label_enumeration,
         )
 
     def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
@@ -111,6 +122,13 @@ class ClearMLRegistryUploaderCallback(Callback):
                 auto_delete_file=False,
                 async_enable=False,
             )
+            if self.config_dict is None:
+                config = pl_module.model_config
+                if config is not None:
+                    config = config.to_dict()
+            else:
+                config = self.config_dict
+            self._output_model.update_design(config_dict=config)
 
             self._last_best_model_path = current_best
             return
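For orientation, a minimal usage sketch of the extended callback under Lightning 2.x import paths. Only the parameter names visible in the hunks above are taken from the package; the checkpoint settings, model name, tags, and label map are illustrative, and any constructor arguments not shown in this diff (e.g., how the ClearML task is attached) are assumed.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

from kostyl.ml.lightning.callbacks.registry_uploading import ClearMLRegistryUploaderCallback

# Hypothetical values; parameter names match the diff above.
ckpt_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1)
uploader = ClearMLRegistryUploaderCallback(
    ckpt_callback=ckpt_callback,
    output_model_name="text-classifier",
    output_model_tags=["baseline"],
    label_enumeration={"negative": 0, "positive": 1},   # new in 0.1.22
    config_dict={"architecture": "static-embeddings"},  # new in 0.1.22; overrides pl_module.model_config
    uploading_frequency="on-train-end",
)
trainer = Trainer(callbacks=[ckpt_callback, uploader])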
{kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/ml/lightning/extenstions/pretrained_model.py
@@ -14,6 +14,7 @@ except ImportError:
 
 from kostyl.utils.logging import log_incompatible_keys
 from kostyl.utils.logging import setup_logger
+from torch import nn
 
 
 logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
@@ -67,7 +68,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
             raise ValueError(f"{checkpoint_path} is a directory")
         if not checkpoint_path.exists():
             raise FileNotFoundError(f"{checkpoint_path} does not exist")
-        if not checkpoint_path.suffix == ".ckpt":
+        if checkpoint_path.suffix != ".ckpt":
             raise ValueError(f"{checkpoint_path} is not a .ckpt file")
 
         checkpoint_dict = torch.load(
@@ -77,19 +78,21 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
             mmap=True,
         )
 
-        config_cls = cast(PretrainedConfig, type(cls.config_class))
+        # 1. Restore the config
+        config_cls = cast(type[PretrainedConfig], cls.config_class)
         config_dict = checkpoint_dict[config_key]
         config_dict.update(kwargs)
         config = config_cls.from_dict(config_dict)
 
-        kwargs_for_model = {}
-        for key in kwargs:
+        kwargs_for_model: dict[str, Any] = {}
+        for key, value in kwargs.items():
             if not hasattr(config, key):
-                kwargs_for_model[key] = kwargs[key]
+                kwargs_for_model[key] = value
 
         with torch.device("meta"):
             model = cls(config, **kwargs_for_model)
 
+        # PEFT adapters (keeping this logic as is)
         if "peft_config" in checkpoint_dict:
             if PeftConfig is None:
                 raise ImportError(
@@ -100,26 +103,100 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
                 model.add_adapter(peft_cfg, adapter_name=name)
 
         incompatible_keys: dict[str, list[str]] = {}
-        if weights_prefix != "":
-            if weights_prefix[-1] != ".":
-                weights_prefix += "."
-            model_state_dict = {}
-            mismatched_keys = []
-            for key, value in checkpoint_dict["state_dict"].items():
+
+        raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
+
+        if weights_prefix:
+            if not weights_prefix.endswith("."):
+                weights_prefix = weights_prefix + "."
+            state_dict: dict[str, torch.Tensor] = {}
+            mismatched_keys: list[str] = []
+
+            for key, value in raw_state_dict.items():
                 if key.startswith(weights_prefix):
                     new_key = key[len(weights_prefix) :]
-                    model_state_dict[new_key] = value
+                    state_dict[new_key] = value
                 else:
                     mismatched_keys.append(key)
+
+            if mismatched_keys:
                 incompatible_keys["mismatched_keys"] = mismatched_keys
         else:
-            model_state_dict = checkpoint_dict["state_dict"]
-
-        missing_keys, unexpected_keys = model.load_state_dict(
-            model_state_dict, strict=False, assign=True
+            state_dict = raw_state_dict
+
+        # 5. base_model_prefix logic as in HF:
+        # supports loading a base model <-> a model with a head
+        #
+        # cls.base_model_prefix is usually "model" / "bert" / "encoder", etc.
+        base_prefix: str = getattr(cls, "base_model_prefix", "") or ""
+        model_to_load: nn.Module = model
+
+        if base_prefix:
+            prefix_with_dot = base_prefix + "."
+            loaded_keys = list(state_dict.keys())
+            full_model_state = model.state_dict()
+            expected_keys = list(full_model_state.keys())
+
+            has_prefix_module = any(k.startswith(prefix_with_dot) for k in loaded_keys)
+            expects_prefix_module = any(
+                k.startswith(prefix_with_dot) for k in expected_keys
+            )
+
+            # Case 1: loading a base model into a model with a head.
+            # Example: StaticEmbeddingsForSequenceClassification (has .model)
+            # with a state_dict keyed "embeddings.weight", "token_pos_weights", ...
+            if (
+                hasattr(model, base_prefix)
+                and not has_prefix_module
+                and expects_prefix_module
+            ):
+                # Weights carry no prefix -> load only into model.<base_prefix>
+                model_to_load = getattr(model, base_prefix)
+
+            # Case 2: loading a checkpoint of a model with a head into a base model.
+            # Example: BertModel, while the state_dict keys are "bert.encoder.layer.0..."
+            elif (
+                not hasattr(model, base_prefix)
+                and has_prefix_module
+                and not expects_prefix_module
+            ):
+                new_state_dict: dict[str, torch.Tensor] = {}
+                for key, value in state_dict.items():
+                    if key.startswith(prefix_with_dot):
+                        new_key = key[len(prefix_with_dot) :]
+                    else:
+                        new_key = key
+                    new_state_dict[new_key] = value
+                state_dict = new_state_dict
+
+        load_result = model_to_load.load_state_dict(
+            state_dict, strict=False, assign=True
+        )
+        missing_keys, unexpected_keys = (
+            load_result.missing_keys,
+            load_result.unexpected_keys,
         )
+
+        # If we loaded only into the base submodule, extend missing_keys
+        # to the full list (base + head), as in older HF versions.
+        if model_to_load is not model and base_prefix:
+            base_keys = set(model_to_load.state_dict().keys())
+            # Normalize full-model keys to their prefix-free form
+            head_like_keys = set()
+            prefix_with_dot = base_prefix + "."
+            for k in model.state_dict().keys():
+                if k.startswith(prefix_with_dot):
+                    # strip "model."
+                    head_like_keys.add(k[len(prefix_with_dot) :])
+                else:
+                    head_like_keys.add(k)
+            extra_missing = sorted(head_like_keys - base_keys)
+            missing_keys = list(missing_keys) + extra_missing
+
         incompatible_keys["missing_keys"] = missing_keys
         incompatible_keys["unexpected_keys"] = unexpected_keys
+
         if should_log_incompatible_keys:
             log_incompatible_keys(incompatible_keys=incompatible_keys, logger=logger)
+
         return model
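The bulk of this release is the HF-style base_model_prefix handling above: a backbone-only checkpoint can now be loaded into a model that wraps the backbone plus a head (case 1), and a head-model checkpoint can be loaded into a bare backbone after stripping the prefix (case 2). A self-contained sketch of the key-remapping idea with toy torch modules; all class names here are hypothetical and stand in for real HF classes.

import torch
from torch import nn

class Backbone(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embeddings = nn.Embedding(10, 4)  # keys: "embeddings.weight"

class WithHead(nn.Module):
    base_model_prefix = "model"
    def __init__(self) -> None:
        super().__init__()
        self.model = Backbone()          # keys: "model.embeddings.weight"
        self.head = nn.Linear(4, 2)      # keys: "head.weight", "head.bias"

# Case 1: backbone checkpoint -> model with a head.
# The keys carry no prefix, so load into the base submodule only.
full = WithHead()
backbone_sd = Backbone().state_dict()
full.model.load_state_dict(backbone_sd, strict=False)

# Case 2: head-model checkpoint -> bare backbone.
# Strip the "model." prefix before loading.
bare = Backbone()
head_sd = WithHead().state_dict()
stripped = {
    k.removeprefix("model."): v
    for k, v in head_sd.items()
    if k.startswith("model.")
}
bare.load_state_dict(stripped, strict=False)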
{kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/kostyl/utils/logging.py
@@ -118,7 +118,7 @@ def setup_logger(
     else:
         base = name
 
-    if add_rank is None:
+    if (add_rank is None) or add_rank:
        try:
            add_rank = dist.is_available() and dist.is_initialized()
        except Exception:
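The setup_logger change broadens the rank-suffix check: previously torch.distributed was probed only when add_rank was left at None; now an explicit add_rank=True is also verified against the distributed state, so a rank suffix is only added when a process group is actually initialized. A sketch of just the amended condition; the helper name is hypothetical, not part of the package.

import torch.distributed as dist

def resolve_add_rank(add_rank: bool | None) -> bool:
    # add_rank=None  -> auto-detect (as before)
    # add_rank=True  -> now re-checked against dist state
    # add_rank=False -> never add the suffix
    if (add_rank is None) or add_rank:
        try:
            add_rank = dist.is_available() and dist.is_initialized()
        except Exception:
            add_rank = False
    return bool(add_rank)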
{kostyl_toolkit-0.1.20 → kostyl_toolkit-0.1.22}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "kostyl-toolkit"
-version = "0.1.20"
+version = "0.1.22"
 description = "Kickass Orchestration System for Training, Yielding & Logging "
 readme = "README.md"
 requires-python = ">=3.12"