kostyl-toolkit 0.1.39__py3-none-any.whl → 0.1.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,10 @@
1
- from .hyperparams import OPTIMIZER
1
+ from .hyperparams import OPTIMIZER_CONFIG
2
+ from .hyperparams import AdamConfig
3
+ from .hyperparams import AdamWithPrecisionConfig
2
4
  from .hyperparams import HyperparamsConfig
3
5
  from .hyperparams import Lr
6
+ from .hyperparams import MuonConfig
7
+ from .hyperparams import ScheduledParamConfig
4
8
  from .hyperparams import WeightDecay
5
9
  from .mixins import ConfigLoadingMixin
6
10
  from .training_settings import CheckpointConfig
@@ -14,7 +18,9 @@ from .training_settings import TrainingSettings
14
18
 
15
19
 
16
20
  __all__ = [
17
- "OPTIMIZER",
21
+ "OPTIMIZER_CONFIG",
22
+ "AdamConfig",
23
+ "AdamWithPrecisionConfig",
18
24
  "CheckpointConfig",
19
25
  "ConfigLoadingMixin",
20
26
  "DDPStrategyConfig",
@@ -24,6 +30,8 @@ __all__ = [
24
30
  "HyperparamsConfig",
25
31
  "LightningTrainerParameters",
26
32
  "Lr",
33
+ "MuonConfig",
34
+ "ScheduledParamConfig",
27
35
  "SingleDeviceStrategyConfig",
28
36
  "TrainingSettings",
29
37
  "WeightDecay",
@@ -21,6 +21,7 @@ class MuonConfig(BaseModel):
21
21
  """Muon optimizer hyperparameters configuration."""
22
22
 
23
23
  type: Literal["Muon"]
24
+ momentum: float = 0.95
24
25
  nesterov: bool = True
25
26
  ns_coefficients: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
26
27
  ns_steps: int = 5
@@ -29,14 +30,15 @@ class MuonConfig(BaseModel):
29
30
  class AdamWithPrecisionConfig(BaseModel):
30
31
  """Adam optimizer with low-precision hyperparameters configuration."""
31
32
 
32
- type: Literal["Adam8bit", "Adam4bit", "AdamFp8"]
33
+ type: Literal[
34
+ "Adam8bit", "Adam4bit", "AdamFp8", "AdamW8bit", "AdamW4bit", "AdamWFp8"
35
+ ]
33
36
  betas: tuple[float, float] = (0.9, 0.999)
34
37
  block_size: int
35
38
  bf16_stochastic_round: bool = False
36
- is_adamw: bool = True
37
39
 
38
40
 
39
- OPTIMIZER = AdamConfig | AdamWithPrecisionConfig | MuonConfig
41
+ OPTIMIZER_CONFIG = AdamConfig | AdamWithPrecisionConfig | MuonConfig
40
42
  SCHEDULER = Literal[
41
43
  "linear",
42
44
  "cosine",
@@ -45,8 +47,8 @@ SCHEDULER = Literal[
45
47
  ]
46
48
 
47
49
 
48
- class Lr(BaseModel):
49
- """Learning rate hyperparameters configuration."""
50
+ class ScheduledParamConfig(BaseModel):
51
+ """Base configuration for a scheduled hyperparameter."""
50
52
 
51
53
  scheduler_type: SCHEDULER | None = None
52
54
 
@@ -60,14 +62,14 @@ class Lr(BaseModel):
60
62
  )
61
63
 
62
64
  @model_validator(mode="after")
63
- def _validate_freeze_ratio(self) -> "Lr":
65
+ def _validate_freeze_ratio(self) -> "ScheduledParamConfig":
64
66
  if self.scheduler_type is None and self.freeze_ratio is not None:
65
67
  logger.warning("use_scheduler is False, freeze_ratio will be ignored.")
66
68
  self.freeze_ratio = None
67
69
  return self
68
70
 
69
71
  @model_validator(mode="after")
70
- def _validate_warmup(self) -> "Lr":
72
+ def _validate_warmup(self) -> "ScheduledParamConfig":
71
73
  if ((self.warmup_value is not None) or (self.warmup_ratio is not None)) and self.scheduler_type is None: # fmt: skip
72
74
  logger.warning(
73
75
  "scheduler_type is None, warmup_value and warmup_ratio will be ignored."
@@ -81,7 +83,7 @@ class Lr(BaseModel):
81
83
  return self
82
84
 
83
85
  @model_validator(mode="after")
84
- def _validate_final_value(self) -> "Lr":
86
+ def _validate_final_value(self) -> "ScheduledParamConfig":
85
87
  if (self.scheduler_type in {"linear"}) and (self.final_value is not None):
86
88
  raise ValueError("If scheduler_type is 'linear', final_value must be None.")
87
89
  if (self.scheduler_type is None) and (self.final_value is not None):
@@ -90,7 +92,7 @@ class Lr(BaseModel):
90
92
  return self
91
93
 
92
94
  @model_validator(mode="after")
93
- def _validate_plateau_ratio(self) -> "Lr":
95
+ def _validate_plateau_ratio(self) -> "ScheduledParamConfig":
94
96
  if self.scheduler_type is not None:
95
97
  if self.scheduler_type.startswith("plateau") and self.plateau_ratio is None:
96
98
  raise ValueError(
@@ -107,7 +109,11 @@ class Lr(BaseModel):
107
109
  return self
108
110
 
109
111
 
110
- class WeightDecay(Lr):
112
+ class Lr(ScheduledParamConfig):
113
+ """Learning rate hyperparameters configuration."""
114
+
115
+
116
+ class WeightDecay(ScheduledParamConfig):
111
117
  """Weight decay hyperparameters configuration."""
112
118
 
113
119
 
@@ -115,6 +121,6 @@ class HyperparamsConfig(BaseModel):
115
121
  """Model training hyperparameters configuration."""
116
122
 
117
123
  grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
118
- optimizer: OPTIMIZER
124
+ optimizer: OPTIMIZER_CONFIG
119
125
  lr: Lr
120
126
  weight_decay: WeightDecay
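
Taken together, the hyperparams changes extract a `ScheduledParamConfig` base class that `Lr` and `WeightDecay` now inherit, add a `momentum` field to `MuonConfig`, replace the `is_adamw` flag with explicit `AdamW*` type literals, and retype `HyperparamsConfig.optimizer` as `OPTIMIZER_CONFIG`. A hedged construction sketch; the `base_value`, `final_value`, and `warmup_ratio` field names are inferred from the validators and the factory code later in this diff, so treat them and any defaults as assumptions:

```python
# Sketch only: field names beyond those visible in the hunks above are
# inferred from the factory code added in this release; verify against 0.1.41.
from kostyl.ml.configs import (
    AdamWithPrecisionConfig,
    HyperparamsConfig,
    Lr,
    MuonConfig,
    WeightDecay,
)

hparams = HyperparamsConfig(
    grad_clip_val=1.0,
    # 0.1.39: AdamWithPrecisionConfig(type="Adam8bit", is_adamw=True, ...)
    # 0.1.41: the AdamW variants are separate type literals instead of a flag.
    optimizer=AdamWithPrecisionConfig(type="AdamW8bit", block_size=256),
    lr=Lr(scheduler_type="cosine", base_value=3e-4, final_value=1e-5, warmup_ratio=0.05),
    weight_decay=WeightDecay(scheduler_type=None, base_value=0.1),
)

# MuonConfig is also a member of the OPTIMIZER_CONFIG union; momentum is new in 0.1.41.
muon = MuonConfig(type="Muon", momentum=0.95)
```
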
@@ -5,3 +5,25 @@ except ImportError as e:
5
5
  "ClearML integration requires the 'clearml' package. "
6
6
  "Please install it via 'pip install clearml'."
7
7
  ) from e
8
+ from .checkpoint_uploader import ClearMLCheckpointUploader
9
+ from .config_mixin import ConfigSyncingClearmlMixin
10
+ from .dataset_utils import collect_clearml_datasets
11
+ from .dataset_utils import download_clearml_datasets
12
+ from .dataset_utils import get_datasets_paths
13
+ from .loading_utils import load_model_from_clearml
14
+ from .loading_utils import load_tokenizer_from_clearml
15
+ from .version_utils import find_version_in_tags
16
+ from .version_utils import increment_version
17
+
18
+
19
+ __all__ = [
20
+ "ClearMLCheckpointUploader",
21
+ "ConfigSyncingClearmlMixin",
22
+ "collect_clearml_datasets",
23
+ "download_clearml_datasets",
24
+ "find_version_in_tags",
25
+ "get_datasets_paths",
26
+ "increment_version",
27
+ "load_model_from_clearml",
28
+ "load_tokenizer_from_clearml",
29
+ ]
@@ -1,5 +1,5 @@
1
1
  from collections.abc import Callable
2
- from functools import partial
2
+ from datetime import datetime
3
3
  from pathlib import Path
4
4
  from typing import override
5
5
 
@@ -24,7 +24,7 @@ class ClearMLCheckpointUploader(ModelCheckpointUploader):
24
24
  comment: str | None = None,
25
25
  framework: str | None = None,
26
26
  base_model_id: str | None = None,
27
- new_model_per_upload: bool = True,
27
+ upload_as_new_model: bool = True,
28
28
  verbose: bool = True,
29
29
  ) -> None:
30
30
  """
@@ -38,20 +38,22 @@ class ClearMLCheckpointUploader(ModelCheckpointUploader):
38
38
  comment: A comment / description for the model.
39
39
  framework: The framework of the model (e.g., "PyTorch", "TensorFlow").
40
40
  base_model_id: Optional ClearML model ID to use as a base for the new model
41
- new_model_per_upload: Whether to create a new ClearML model
42
- for every upload or update weights of the same model. When updating weights,
43
- the last uploaded checkpoint will be replaced (and deleted).
41
+ upload_as_new_model: Whether to create a new ClearML model
42
+ for every upload or to update the weights of a single model. When True,
43
+ each checkpoint is uploaded as a separate model with a timestamp appended to its name.
44
+ When False, the weights of the same model are updated in place.
44
45
  verbose: Whether to log messages during upload.
45
46
 
46
47
  """
47
48
  super().__init__()
48
- if base_model_id is not None and new_model_per_upload:
49
+ if base_model_id is not None and upload_as_new_model:
49
50
  raise ValueError(
50
- "Cannot set base_model_id when new_model_per_upload is True."
51
+ "Cannot set base_model_id when upload_as_new_model is True."
51
52
  )
52
53
 
53
54
  self.verbose = verbose
54
- self.new_model_per_upload = new_model_per_upload
55
+ self.upload_as_new_model = upload_as_new_model
56
+ self.model_name = model_name
55
57
  self.best_model_path: str = ""
56
58
  self.config_dict = config_dict
57
59
  self._output_model: OutputModel | None = None
@@ -59,15 +61,13 @@ class ClearMLCheckpointUploader(ModelCheckpointUploader):
59
61
  self._upload_callback: Callable | None = None
60
62
 
61
63
  self._validate_tags(tags)
62
- self.model_fabric = partial(
63
- OutputModel,
64
- name=model_name,
65
- label_enumeration=label_enumeration,
66
- tags=tags,
67
- comment=comment,
68
- framework=framework,
69
- base_model_id=base_model_id,
70
- )
64
+ self.model_fabric_kwargs = {
65
+ "label_enumeration": label_enumeration,
66
+ "tags": tags,
67
+ "comment": comment,
68
+ "framework": framework,
69
+ "base_model_id": base_model_id,
70
+ }
71
71
  return
72
72
 
73
73
  @staticmethod
@@ -78,16 +78,22 @@ class ClearMLCheckpointUploader(ModelCheckpointUploader):
78
78
  tags.append("LightningCheckpoint")
79
79
  return None
80
80
 
81
- @property
82
- def output_model_(self) -> OutputModel:
83
- """Returns the OutputModel instance based on `new_model_per_upload` setting."""
84
- if self.new_model_per_upload:
85
- model = self.model_fabric()
86
- self._output_model = self.model_fabric()
87
- else:
88
- if self._output_model is None:
89
- self._output_model = self.model_fabric()
90
- model = self._output_model
81
+ def _create_new_model(self) -> OutputModel:
82
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
83
+ model_name_with_timestamp = f"{self.model_name}_{timestamp}"
84
+ model = OutputModel(
85
+ name=model_name_with_timestamp,
86
+ **self.model_fabric_kwargs,
87
+ )
88
+ return model
89
+
90
+ def _get_output_model(self) -> OutputModel:
91
+ if self._output_model is None:
92
+ self._output_model = OutputModel(
93
+ name=self.model_name,
94
+ **self.model_fabric_kwargs,
95
+ )
96
+ model = self._output_model
91
97
  return model
92
98
 
93
99
  @override
@@ -105,12 +111,17 @@ class ClearMLCheckpointUploader(ModelCheckpointUploader):
105
111
  if self.verbose:
106
112
  logger.info(f"Uploading model from {path}")
107
113
 
108
- self.output_model_.update_weights(
114
+ if self.upload_as_new_model:
115
+ output_model = self._create_new_model()
116
+ else:
117
+ output_model = self._get_output_model()
118
+
119
+ output_model.update_weights(
109
120
  path,
110
121
  auto_delete_file=False,
111
122
  async_enable=False,
112
123
  )
113
- self.output_model_.update_design(config_dict=self.config_dict)
124
+ output_model.update_design(config_dict=self.config_dict)
114
125
 
115
126
  self._last_uploaded_model_path = path
116
127
  return
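
For callers of `ClearMLCheckpointUploader`, the visible changes are the rename of `new_model_per_upload` to `upload_as_new_model` and the timestamped model names created per upload. A hedged construction sketch, with keyword names taken from the `__init__` shown above and placeholder values everywhere else:

```python
# Sketch only: keyword names come from the __init__ in this diff;
# config_dict and tags hold placeholder values.
from kostyl.ml.integrations.clearml import ClearMLCheckpointUploader

uploader = ClearMLCheckpointUploader(
    model_name="my-encoder",           # 0.1.41 appends a %Y%m%d_%H%M%S timestamp per upload
    config_dict={"hidden_size": 768},  # hypothetical design payload
    tags=["experiment-42"],
    framework="PyTorch",
    upload_as_new_model=True,          # renamed from new_model_per_upload
    verbose=True,
)
# base_model_id must stay None while upload_as_new_model=True,
# otherwise __init__ raises ValueError.
```
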
@@ -10,7 +10,7 @@ from kostyl.utils.dict_manipulations import flattened_dict_to_nested
10
10
  from kostyl.utils.fs import load_config
11
11
 
12
12
 
13
- class BaseModelWithClearmlSyncing[TConfig: ConfigLoadingMixin]:
13
+ class ConfigSyncingClearmlMixin[TConfig: ConfigLoadingMixin]:
14
14
  """Mixin providing ClearML task configuration syncing functionality for Pydantic models."""
15
15
 
16
16
  @classmethod
@@ -31,7 +31,7 @@ except ImportError:
31
31
  LIGHTING_MIXIN_AVAILABLE = False
32
32
 
33
33
 
34
- def get_tokenizer_from_clearml(
34
+ def load_tokenizer_from_clearml(
35
35
  model_id: str,
36
36
  task: Task | None = None,
37
37
  ignore_remote_overrides: bool = True,
@@ -66,7 +66,7 @@ def get_tokenizer_from_clearml(
66
66
  return tokenizer, clearml_tokenizer
67
67
 
68
68
 
69
- def get_model_from_clearml[
69
+ def load_model_from_clearml[
70
70
  TModel: PreTrainedModel | LightningCheckpointLoaderMixin | AutoModel
71
71
  ](
72
72
  model_id: str,
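
The ClearML loading helpers are renamed: `get_tokenizer_from_clearml` becomes `load_tokenizer_from_clearml` and `get_model_from_clearml` becomes `load_model_from_clearml`, both re-exported from `kostyl.ml.integrations.clearml`. A hedged call sketch using only the parameters visible in the signatures above; the model id is a placeholder:

```python
# Sketch only: the function names, the model_id parameter, and the
# (tokenizer, clearml_tokenizer) return shape are taken from this diff.
from kostyl.ml.integrations.clearml import (
    load_model_from_clearml,      # was get_model_from_clearml (imported to show the rename)
    load_tokenizer_from_clearml,  # was get_tokenizer_from_clearml
)

tokenizer, clearml_tokenizer = load_tokenizer_from_clearml(
    model_id="0123456789abcdef",  # placeholder ClearML model id
)
```
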
@@ -15,7 +15,7 @@ from kostyl.ml.dist_utils import is_local_zero_rank
15
15
  from kostyl.utils import setup_logger
16
16
 
17
17
 
18
- logger = setup_logger("callbacks/checkpoint.py")
18
+ logger = setup_logger()
19
19
 
20
20
 
21
21
  class ModelCheckpointWithCheckpointUploader(ModelCheckpoint):
@@ -278,6 +278,10 @@ class ModelCheckpointWithCheckpointUploader(ModelCheckpoint):
278
278
  case "only-best":
279
279
  if filepath == self.best_model_path:
280
280
  self.registry_uploader.upload_checkpoint(filepath)
281
+ case _:
282
+ logger.warning_once(
283
+ "Unknown upload strategy for checkpoint uploader. Skipping upload."
284
+ )
281
285
  return
282
286
 
283
287
 
@@ -14,7 +14,7 @@ from transformers import PretrainedConfig
14
14
  from transformers import PreTrainedModel
15
15
 
16
16
  from kostyl.ml.integrations.lightning.metrics_formatting import apply_suffix
17
- from kostyl.ml.schedulers.base import BaseScheduler
17
+ from kostyl.ml.optim.schedulers import BaseScheduler
18
18
  from kostyl.utils import setup_logger
19
19
 
20
20
 
@@ -0,0 +1,8 @@
1
+ from .factory import create_optimizer
2
+ from .factory import create_scheduler
3
+
4
+
5
+ __all__ = [
6
+ "create_optimizer",
7
+ "create_scheduler",
8
+ ]
@@ -0,0 +1,257 @@
1
+ from typing import Any
2
+
3
+ from torch.optim import Optimizer
4
+
5
+ from kostyl.ml.configs import OPTIMIZER_CONFIG
6
+ from kostyl.ml.configs import AdamConfig
7
+ from kostyl.ml.configs import AdamWithPrecisionConfig
8
+ from kostyl.ml.configs import MuonConfig
9
+ from kostyl.ml.configs import ScheduledParamConfig
10
+ from kostyl.utils import setup_logger
11
+
12
+ from .schedulers import SCHEDULER_MAPPING
13
+ from .schedulers import CosineScheduler
14
+ from .schedulers import LinearScheduler
15
+ from .schedulers import PlateauWithAnnealingScheduler
16
+
17
+
18
+ logger = setup_logger(fmt="only_message")
19
+
20
+
21
+ def create_scheduler(
22
+ config: ScheduledParamConfig,
23
+ param_group_field: str,
24
+ num_iters: int,
25
+ optim: Optimizer,
26
+ multiplier_field: str | None = None,
27
+ skip_if_zero: bool = False,
28
+ apply_if_field: str | None = None,
29
+ ignore_if_field: str | None = None,
30
+ ) -> LinearScheduler | CosineScheduler | PlateauWithAnnealingScheduler:
31
+ """
32
+ Converts a ScheduledParamConfig to a scheduler instance.
33
+
34
+ Args:
35
+ config: Configuration object for the scheduler.
36
+ param_group_field: The field name in the optimizer's param groups to schedule.
37
+ num_iters: Total number of iterations.
38
+ optim: The optimizer instance.
39
+ multiplier_field: Optional per-group field name that contains a multiplier applied to the scheduled value. If None, no multiplier is applied.
40
+ skip_if_zero: Leave groups untouched when their target field equals zero.
41
+ Default is False.
42
+ apply_if_field: Require this key to be present in a param group before updating.
43
+ ignore_if_field: Skip groups that declare this key in their dictionaries.
44
+
45
+ Returns:
46
+ A scheduler instance based on the configuration.
47
+
48
+ """
49
+ if config.scheduler_type is None:
50
+ raise ValueError("scheduler_type must be specified in the config.")
51
+
52
+ if "plateau" in config.scheduler_type:
53
+ scheduler_type = "plateau"
54
+ else:
55
+ scheduler_type = config.scheduler_type
56
+ scheduler_cls = SCHEDULER_MAPPING[scheduler_type] # type: ignore
57
+
58
+ if issubclass(scheduler_cls, PlateauWithAnnealingScheduler):
59
+ if "cosine" in config.scheduler_type:
60
+ annealing_type = "cosine"
61
+ elif "linear" in config.scheduler_type:
62
+ annealing_type = "linear"
63
+ else:
64
+ raise ValueError(f"Unknown annealing_type: {config.scheduler_type}")
65
+ scheduler = scheduler_cls(
66
+ optimizer=optim,
67
+ param_group_field=param_group_field,
68
+ num_iters=num_iters,
69
+ plateau_value=config.base_value,
70
+ final_value=config.final_value, # type: ignore
71
+ warmup_ratio=config.warmup_ratio,
72
+ warmup_value=config.warmup_value,
73
+ freeze_ratio=config.freeze_ratio,
74
+ plateau_ratio=config.plateau_ratio, # type: ignore
75
+ annealing_type=annealing_type,
76
+ multiplier_field=multiplier_field,
77
+ skip_if_zero=skip_if_zero,
78
+ apply_if_field=apply_if_field,
79
+ ignore_if_field=ignore_if_field,
80
+ )
81
+ elif issubclass(scheduler_cls, LinearScheduler):
82
+ scheduler = scheduler_cls(
83
+ optimizer=optim,
84
+ param_group_field=param_group_field,
85
+ num_iters=num_iters,
86
+ initial_value=config.base_value,
87
+ final_value=config.final_value, # type: ignore
88
+ multiplier_field=multiplier_field,
89
+ skip_if_zero=skip_if_zero,
90
+ apply_if_field=apply_if_field,
91
+ ignore_if_field=ignore_if_field,
92
+ )
93
+ elif issubclass(scheduler_cls, CosineScheduler):
94
+ scheduler = scheduler_cls(
95
+ optimizer=optim,
96
+ param_group_field=param_group_field,
97
+ num_iters=num_iters,
98
+ base_value=config.base_value,
99
+ final_value=config.final_value, # type: ignore
100
+ warmup_ratio=config.warmup_ratio,
101
+ warmup_value=config.warmup_value,
102
+ freeze_ratio=config.freeze_ratio,
103
+ multiplier_field=multiplier_field,
104
+ skip_if_zero=skip_if_zero,
105
+ apply_if_field=apply_if_field,
106
+ ignore_if_field=ignore_if_field,
107
+ )
108
+ else:
109
+ raise ValueError(f"Unsupported scheduler type: {config.scheduler_type}")
110
+ return scheduler
111
+
112
+
113
+ def create_optimizer( # noqa: C901
114
+ parameters_groups: dict[str, Any],
115
+ optimizer_config: OPTIMIZER_CONFIG,
116
+ lr: float,
117
+ weight_decay: float,
118
+ ) -> Optimizer:
119
+ """
120
+ Creates an optimizer based on the configuration.
121
+
122
+ Args:
123
+ parameters_groups: Dictionary containing model parameters
124
+ (key "params" and per-group options, i.e. "lr", "weight_decay" and etc.).
125
+ optimizer_config: Configuration for the optimizer.
126
+ lr: Learning rate.
127
+ weight_decay: Weight decay.
128
+
129
+ Returns:
130
+ An instantiated optimizer.
131
+
132
+ """
133
+ if isinstance(optimizer_config, AdamConfig):
134
+ match optimizer_config.type:
135
+ case "Adam":
136
+ from torch.optim import Adam
137
+
138
+ optimizer = Adam(
139
+ params=parameters_groups["params"],
140
+ lr=lr,
141
+ weight_decay=weight_decay,
142
+ betas=optimizer_config.betas,
143
+ )
144
+
145
+ case "AdamW":
146
+ from torch.optim import AdamW
147
+
148
+ optimizer = AdamW(
149
+ params=parameters_groups["params"],
150
+ lr=lr,
151
+ weight_decay=weight_decay,
152
+ betas=optimizer_config.betas,
153
+ )
154
+ return optimizer
155
+ case _:
156
+ raise ValueError(f"Unsupported optimizer type: {optimizer_config.type}")
157
+ elif isinstance(optimizer_config, MuonConfig):
158
+ from torch.optim import Muon
159
+
160
+ optimizer = Muon(
161
+ params=parameters_groups["params"],
162
+ lr=lr,
163
+ weight_decay=weight_decay,
164
+ momentum=optimizer_config.momentum,
165
+ nesterov=optimizer_config.nesterov,
166
+ ns_coefficients=optimizer_config.ns_coefficients,
167
+ ns_steps=optimizer_config.ns_steps,
168
+ )
169
+ elif isinstance(optimizer_config, AdamWithPrecisionConfig):
170
+ try:
171
+ import torchao # noqa: F401
172
+ except ImportError as e:
173
+ raise ImportError(
174
+ "torchao is required for low-precision Adam optimizers. "
175
+ "Please install it via 'pip install torchao'."
176
+ ) from e
177
+ match optimizer_config.type:
178
+ case "Adam8bit":
179
+ from torchao.optim import Adam8bit
180
+
181
+ logger.warning(
182
+ "Ignoring weight_decay for Adam8bit optimizer as it is not supported."
183
+ )
184
+
185
+ optimizer = Adam8bit(
186
+ params=parameters_groups["params"],
187
+ lr=lr,
188
+ betas=optimizer_config.betas,
189
+ block_size=optimizer_config.block_size,
190
+ bf16_stochastic_round=optimizer_config.bf16_stochastic_round,
191
+ )
192
+ case "Adam4bit":
193
+ from torchao.optim import Adam4bit
194
+
195
+ logger.warning(
196
+ "Ignoring weight_decay for Adam4bit optimizer as it is not supported."
197
+ )
198
+
199
+ optimizer = Adam4bit(
200
+ params=parameters_groups["params"],
201
+ lr=lr,
202
+ betas=optimizer_config.betas,
203
+ block_size=optimizer_config.block_size,
204
+ bf16_stochastic_round=optimizer_config.bf16_stochastic_round,
205
+ )
206
+ case "AdamFp8":
207
+ from torchao.optim import AdamFp8
208
+
209
+ logger.warning(
210
+ "Ignoring weight_decay for AdamFp8 optimizer as it is not supported."
211
+ )
212
+
213
+ optimizer = AdamFp8(
214
+ params=parameters_groups["params"],
215
+ lr=lr,
216
+ betas=optimizer_config.betas,
217
+ block_size=optimizer_config.block_size,
218
+ bf16_stochastic_round=optimizer_config.bf16_stochastic_round,
219
+ )
220
+ case "AdamW8bit":
221
+ from torchao.optim import AdamW8bit
222
+
223
+ optimizer = AdamW8bit(
224
+ params=parameters_groups["params"],
225
+ lr=lr,
226
+ weight_decay=weight_decay,
227
+ betas=optimizer_config.betas,
228
+ block_size=optimizer_config.block_size,
229
+ bf16_stochastic_round=optimizer_config.bf16_stochastic_round,
230
+ )
231
+ case "AdamW4bit":
232
+ from torchao.optim import AdamW4bit
233
+
234
+ optimizer = AdamW4bit(
235
+ params=parameters_groups["params"],
236
+ lr=lr,
237
+ weight_decay=weight_decay,
238
+ betas=optimizer_config.betas,
239
+ block_size=optimizer_config.block_size,
240
+ bf16_stochastic_round=optimizer_config.bf16_stochastic_round,
241
+ )
242
+ case "AdamWFp8":
243
+ from torchao.optim import AdamWFp8
244
+
245
+ optimizer = AdamWFp8(
246
+ params=parameters_groups["params"],
247
+ lr=lr,
248
+ weight_decay=weight_decay,
249
+ betas=optimizer_config.betas,
250
+ block_size=optimizer_config.block_size,
251
+ bf16_stochastic_round=optimizer_config.bf16_stochastic_round,
252
+ )
253
+ case _:
254
+ raise ValueError(f"Unsupported optimizer type: {optimizer_config.type}")
255
+ else:
256
+ raise ValueError("Unsupported optimizer configuration type.")
257
+ return optimizer
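
The new `kostyl.ml.optim` package wires these configs to concrete optimizers and schedulers through `create_optimizer` and `create_scheduler`. A hedged end-to-end sketch; the single parameter group, the `AdamConfig` defaults, and the `Lr` field values are assumptions layered on top of the signatures shown above:

```python
# Sketch only: call signatures match the factory.py added in this diff;
# the parameter-group layout and the config field values are illustrative.
import torch

from kostyl.ml.configs import AdamConfig, Lr
from kostyl.ml.optim import create_optimizer, create_scheduler

model = torch.nn.Linear(16, 16)

optimizer = create_optimizer(
    parameters_groups={"params": model.parameters()},
    optimizer_config=AdamConfig(type="AdamW"),  # assumes betas keeps a (0.9, 0.999) default
    lr=3e-4,
    weight_decay=0.01,
)

lr_scheduler = create_scheduler(
    config=Lr(scheduler_type="cosine", base_value=3e-4, final_value=1e-5, warmup_ratio=0.05),
    param_group_field="lr",
    num_iters=10_000,
    optim=optimizer,
)
```
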
@@ -0,0 +1,56 @@
1
+ from typing import TypedDict
2
+
3
+ from .base import BaseScheduler
4
+ from .composite import CompositeScheduler
5
+ from .cosine import CosineParamScheduler
6
+ from .cosine import CosineScheduler
7
+ from .linear import LinearParamScheduler
8
+ from .linear import LinearScheduler
9
+ from .plateau import PlateauWithAnnealingParamScheduler
10
+ from .plateau import PlateauWithAnnealingScheduler
11
+
12
+
13
+ class SchedulerMapping(TypedDict):
14
+ """Map names to scheduler classes."""
15
+
16
+ linear: type[LinearScheduler]
17
+ cosine: type[CosineScheduler]
18
+ plateau: type[PlateauWithAnnealingScheduler]
19
+ composite: type[CompositeScheduler]
20
+
21
+
22
+ class ParamSchedulerMapping(TypedDict):
23
+ """Map names to scheduler classes."""
24
+
25
+ linear: type[LinearParamScheduler]
26
+ cosine: type[CosineParamScheduler]
27
+ plateau: type[PlateauWithAnnealingParamScheduler]
28
+
29
+
30
+ SCHEDULER_MAPPING: SchedulerMapping = {
31
+ "linear": LinearScheduler,
32
+ "cosine": CosineScheduler,
33
+ "plateau": PlateauWithAnnealingScheduler,
34
+ "composite": CompositeScheduler,
35
+ }
36
+
37
+
38
+ PARAM_SCHEDULER_MAPPING: ParamSchedulerMapping = {
39
+ "linear": LinearParamScheduler,
40
+ "cosine": CosineParamScheduler,
41
+ "plateau": PlateauWithAnnealingParamScheduler,
42
+ }
43
+
44
+
45
+ __all__ = [
46
+ "PARAM_SCHEDULER_MAPPING",
47
+ "SCHEDULER_MAPPING",
48
+ "BaseScheduler",
49
+ "CompositeScheduler",
50
+ "CosineParamScheduler",
51
+ "CosineScheduler",
52
+ "LinearParamScheduler",
53
+ "LinearScheduler",
54
+ "PlateauWithAnnealingParamScheduler",
55
+ "PlateauWithAnnealingScheduler",
56
+ ]
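
The scheduler modules move from `kostyl.ml.schedulers` to `kostyl.ml.optim.schedulers`, which now also exports `BaseScheduler` and the TypedDict-backed name-to-class mappings used by the factory. A small lookup sketch based on the mappings above:

```python
# Sketch only: the mappings and the new import path are taken from
# kostyl/ml/optim/schedulers/__init__.py in this diff.
from kostyl.ml.optim.schedulers import (  # previously kostyl.ml.schedulers
    PARAM_SCHEDULER_MAPPING,
    SCHEDULER_MAPPING,
)

scheduler_cls = SCHEDULER_MAPPING["cosine"]              # CosineScheduler
param_scheduler_cls = PARAM_SCHEDULER_MAPPING["cosine"]  # CosineParamScheduler
print(scheduler_cls.__name__, param_scheduler_cls.__name__)
```
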
@@ -145,8 +145,8 @@ class CosineScheduler(_CosineSchedulerCore):
145
145
  freeze_ratio: Optional fraction of iterations to keep the value frozen at zero at the beginning.
146
146
  multiplier_field: Optional per-group multiplier applied to the scheduled value.
147
147
  skip_if_zero: Leave groups untouched when their target field equals zero.
148
- apply_if_field: Require this flag to be present in a param group before updating.
149
- ignore_if_field: Skip groups that declare this flag.
148
+ apply_if_field: Require this key to be present in a param group before updating.
149
+ ignore_if_field: Skip groups that declare this key in their dictionaries.
150
150
 
151
151
  """
152
152
  self.apply_if_field = apply_if_field
@@ -13,21 +13,21 @@ class _LinearScheduleBase(BaseScheduler):
13
13
  self,
14
14
  param_name: str,
15
15
  num_iters: int,
16
- base_value: float,
16
+ initial_value: float,
17
17
  final_value: float,
18
18
  ) -> None:
19
19
  self.param_name = param_name
20
20
  self.num_iters = num_iters
21
- self.base_value = base_value
21
+ self.initial_value = initial_value
22
22
  self.final_value = final_value
23
23
 
24
24
  self.scheduled_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
25
- self.current_value_ = self.base_value
25
+ self.current_value_ = self.initial_value
26
26
  return
27
27
 
28
28
  def _create_scheduler(self) -> None:
29
29
  self.scheduled_values = np.linspace(
30
- self.base_value, self.final_value, num=self.num_iters, dtype=np.float64
30
+ self.initial_value, self.final_value, num=self.num_iters, dtype=np.float64
31
31
  )
32
32
  self._verify()
33
33
  return
@@ -68,7 +68,7 @@ class LinearScheduler(_LinearScheduleBase):
68
68
  optimizer: torch.optim.Optimizer,
69
69
  param_group_field: str,
70
70
  num_iters: int,
71
- base_value: float,
71
+ initial_value: float,
72
72
  final_value: float,
73
73
  multiplier_field: str | None = None,
74
74
  skip_if_zero: bool = False,
@@ -82,12 +82,12 @@ class LinearScheduler(_LinearScheduleBase):
82
82
  optimizer: Optimizer whose param groups are updated in-place.
83
83
  param_group_field: Name of the field that receives the scheduled value.
84
84
  num_iters: Number of scheduler iterations before clamping at ``final_value``.
85
- base_value: Value used on the first iteration.
85
+ initial_value: Value used on the first iteration.
86
86
  final_value: Value used once ``num_iters`` iterations are consumed.
87
87
  multiplier_field: Optional per-group multiplier applied to the scheduled value.
88
88
  skip_if_zero: Leave groups untouched when their target field equals zero.
89
- apply_if_field: Require this flag to be present in a param group before updating.
90
- ignore_if_field: Skip groups that declare this flag.
89
+ apply_if_field: Require this key to be present in a param group before updating.
90
+ ignore_if_field: Skip groups that declare this key in their dictionaries.
91
91
 
92
92
  """
93
93
  self.apply_if_field = apply_if_field
@@ -98,7 +98,7 @@ class LinearScheduler(_LinearScheduleBase):
98
98
  super().__init__(
99
99
  param_name=param_group_field,
100
100
  num_iters=num_iters,
101
- base_value=base_value,
101
+ initial_value=initial_value,
102
102
  final_value=final_value,
103
103
  )
104
104
  self.param_group_field = param_group_field
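
For code that builds `LinearScheduler` directly, the only API change is the rename of `base_value` to `initial_value`. A hedged sketch with the remaining keyword names taken from the signature above and a throwaway optimizer as a stand-in:

```python
# Sketch only: keyword names come from the LinearScheduler signature above;
# the SGD optimizer is just a stand-in for illustration.
import torch

from kostyl.ml.optim.schedulers import LinearScheduler

optim = torch.optim.SGD(torch.nn.Linear(4, 4).parameters(), lr=1.0)

scheduler = LinearScheduler(
    optimizer=optim,
    param_group_field="lr",
    num_iters=1_000,
    initial_value=1.0,  # 0.1.39 called this base_value
    final_value=0.0,
)
```
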
@@ -198,8 +198,8 @@ class PlateauWithAnnealingScheduler(_PlateauWithAnnealingCore):
198
198
  annealing_type: Type of annealing from plateau to final value ("cosine" or "linear").
199
199
  multiplier_field: Optional per-group multiplier applied to the scheduled value.
200
200
  skip_if_zero: Leave groups untouched when their target field equals zero.
201
- apply_if_field: Require this flag to be present in a param group before updating.
202
- ignore_if_field: Skip groups that declare this flag.
201
+ apply_if_field: Require this key to be present in a param group before updating.
202
+ ignore_if_field: Skip groups that declare this key in their dictionaries.
203
203
 
204
204
  """
205
205
  self.apply_if_field = apply_if_field
@@ -1,13 +1,13 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: kostyl-toolkit
3
- Version: 0.1.39
3
+ Version: 0.1.41
4
4
  Summary: Kickass Orchestration System for Training, Yielding & Logging
5
5
  Requires-Dist: case-converter>=1.2.0
6
6
  Requires-Dist: loguru>=0.7.3
7
7
  Requires-Dist: case-converter>=1.2.0 ; extra == 'ml'
8
8
  Requires-Dist: pydantic>=2.12.4 ; extra == 'ml'
9
- Requires-Dist: torch>=2.9.1 ; extra == 'ml'
10
- Requires-Dist: transformers>=4.57.1 ; extra == 'ml'
9
+ Requires-Dist: torch ; extra == 'ml'
10
+ Requires-Dist: transformers ; extra == 'ml'
11
11
  Requires-Python: >=3.12
12
12
  Provides-Extra: ml
13
13
  Description-Content-Type: text/markdown
@@ -1,40 +1,42 @@
1
1
  kostyl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  kostyl/ml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  kostyl/ml/base_uploader.py,sha256=KxHuohCcNK18kTVFBBqDu_IOQefluhSXOzwC56O66wc,484
4
- kostyl/ml/configs/__init__.py,sha256=9AbM3ugsvc0bRKHA99DTRuUc2N7iSV9BH5FTDOp0cpw,913
5
- kostyl/ml/configs/hyperparams.py,sha256=wU6NoZoTG9WlnU4z3g4ViUfjjp8fb8YWUS_2gAd5wsY,4131
4
+ kostyl/ml/configs/__init__.py,sha256=ytr2HFtsqMIs1pqQ-Ma0jv8c6Ni1-UxPnDTLu16cL68,1189
5
+ kostyl/ml/configs/hyperparams.py,sha256=poOsR6xbE6jHO02GzxXfbz-6PA0HqULgcTgiuT_4oOM,4401
6
6
  kostyl/ml/configs/mixins.py,sha256=xHHAoRoPbzP9ECFP9duzg6SzegHcoLI8Pr9NrLoWNHs,1411
7
7
  kostyl/ml/configs/training_settings.py,sha256=wT9CHuLaKrLwonsc87Ee421EyFis_c9fqOgn9bSClm8,2747
8
8
  kostyl/ml/data_collator.py,sha256=kxiaMDKwSKXGBtrF8yXxHcypf7t_6syU-NwO1LcX50k,4062
9
9
  kostyl/ml/dist_utils.py,sha256=UFNMLEHc0A5F6KvTRG8GQPpRDwG4m5dvM__UvXNc2aQ,4526
10
10
  kostyl/ml/integrations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- kostyl/ml/integrations/clearml/__init__.py,sha256=3TBVI-3fE9ZzuvOLEohW9TOK0BZTLD5JiYalAVDkocc,217
12
- kostyl/ml/integrations/clearml/checkpoint_uploader.py,sha256=PupFi7jKROsIddOz7X5DhV7nUNdDZg5kKaaLvzdCHlY,4012
13
- kostyl/ml/integrations/clearml/config_mixin.py,sha256=70QRicU7etiDzLX-MplqVX8uFm5siuPrM8KbTOriZnQ,3308
11
+ kostyl/ml/integrations/clearml/__init__.py,sha256=cIJMmiTgRrBNlCwejs6x4DrJydv7zYRPqUCWpVgQJYo,970
12
+ kostyl/ml/integrations/clearml/checkpoint_uploader.py,sha256=hwE2pbi3EK0iYw0BCDTK4KjuxS3lfOWDFq5aqnU7c1U,4381
13
+ kostyl/ml/integrations/clearml/config_mixin.py,sha256=F9Ni6k80KyJsVGcgHjAIIsdpmuDcVqO7eogwiNHccqg,3306
14
14
  kostyl/ml/integrations/clearml/dataset_utils.py,sha256=eij_sr2KDhm8GxEbVbK8aBjPsuVvLl9-PIGGaKVgXLA,1729
15
- kostyl/ml/integrations/clearml/loading_utils.py,sha256=NAMmB9NTGCXCHh-bR_nrQZyqImUVZqicNjExDyPM2mU,5224
15
+ kostyl/ml/integrations/clearml/loading_utils.py,sha256=1wo5QX8cCYC9vGkkupHmaQFirJj4ZUSm_K8zp685DNU,5226
16
16
  kostyl/ml/integrations/clearml/version_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWaJTyq8,1210
17
17
  kostyl/ml/integrations/lightning/__init__.py,sha256=r96os8kTuKIAymx3k9Td1JBrO2PH7nQAWUC54NsY5yY,392
18
18
  kostyl/ml/integrations/lightning/callbacks/__init__.py,sha256=EnKkNwwNDZnEqKRlpY4FVrqP88ECPF6nlT2bSLUIKRk,194
19
- kostyl/ml/integrations/lightning/callbacks/checkpoint.py,sha256=iJQKaQUzCMNza7SUKtCVDol_Cy3rZy9KfBKQ6kT6Swg,18434
19
+ kostyl/ml/integrations/lightning/callbacks/checkpoint.py,sha256=Qpp4LrLloXTIwuM2TDqYi3vvYKn1KZOK_VCi2bs26qY,18588
20
20
  kostyl/ml/integrations/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
21
21
  kostyl/ml/integrations/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
22
22
  kostyl/ml/integrations/lightning/loggers/tb_logger.py,sha256=jVzEArHr1bpO-HRdKjlLJ4BJnxomHPAuBKCdYQRHRn4,1023
23
23
  kostyl/ml/integrations/lightning/metrics_formatting.py,sha256=U6vdNENZLvp2dT1L3HqFKtXrHwGKoDXN93hvamPGHjM,1341
24
24
  kostyl/ml/integrations/lightning/mixins.py,sha256=hVIsIUu6Iryrz6S7GQTqog9vNq8LQyjJd2aoJ5Ws6KU,5253
25
- kostyl/ml/integrations/lightning/module.py,sha256=39hcVNZSGyj5tLpXyX8IoqMGWt5vf6-Bx5JnNJ2-Wag,5218
25
+ kostyl/ml/integrations/lightning/module.py,sha256=APodFpspSoBYmrtTnCpUFq3WFEvpmHp2hlWUxwwAChM,5219
26
26
  kostyl/ml/integrations/lightning/utils.py,sha256=QhbK5iTv07xkRKuyXiK05aoY_ObSGorzGu1WcFvvFtI,1712
27
+ kostyl/ml/optim/__init__.py,sha256=GgCr94WbiM-7MVDf6fPlbeUSxS33ufn_UAmqfKerElc,140
28
+ kostyl/ml/optim/factory.py,sha256=1WYbAOOGepVL55gSVCJs5WBsiiE4k4ayrS5svFh9kbE,9760
29
+ kostyl/ml/optim/schedulers/__init__.py,sha256=lVwWj_mtUDo5OLtr8NNljLQgWGJUKPJWV7bm6-997ag,1486
30
+ kostyl/ml/optim/schedulers/base.py,sha256=bjmwgdZpnSqpCnHPnKC6MEiRO79cwxMJpZq-eQVNs2M,1353
31
+ kostyl/ml/optim/schedulers/composite.py,sha256=ee4xlMDMMtjKPkbTF2ue9GTr9DuGCGjZWf11mHbi6aE,2387
32
+ kostyl/ml/optim/schedulers/cosine.py,sha256=LnFujAKaM8v5soPwiWG899nUMeaeAe8oTiaJqvoQxg8,8705
33
+ kostyl/ml/optim/schedulers/linear.py,sha256=hlHfU88WLwfke8XGmRUzRKmEpqyyaLb4RRqq6EIzIxU,5881
34
+ kostyl/ml/optim/schedulers/plateau.py,sha256=kjHivSp2FIMXExgS0q_Wbxk9pEHnuzhrIu4TQ59kYNw,11255
27
35
  kostyl/ml/params_groups.py,sha256=nUyw5d06Pvy9QPiYtZzLYR87xwXqJLxbHthgQH8oSCM,3583
28
- kostyl/ml/schedulers/__init__.py,sha256=VIo8MOP4w5Ll24XqFb3QGi2rKvys6c0dEFYPIdDoPlw,526
29
- kostyl/ml/schedulers/base.py,sha256=bjmwgdZpnSqpCnHPnKC6MEiRO79cwxMJpZq-eQVNs2M,1353
30
- kostyl/ml/schedulers/composite.py,sha256=ee4xlMDMMtjKPkbTF2ue9GTr9DuGCGjZWf11mHbi6aE,2387
31
- kostyl/ml/schedulers/cosine.py,sha256=y8ylrgVOkVcr2-ExoqqNW--tdDX88TBYPQCOppIf2_M,8685
32
- kostyl/ml/schedulers/linear.py,sha256=a-Dq5VYoM1Z7vBb4_2Np0MAOfRmO1QVZnOzEoY8nM6k,5834
33
- kostyl/ml/schedulers/plateau.py,sha256=N-hiostPtTR0W4xnEJYB_1dv0DRx39iufLkGUrSIoWE,11235
34
36
  kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
35
37
  kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
36
38
  kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
37
39
  kostyl/utils/logging.py,sha256=CgNFNogcK0hoZmygvBWlTcq5A3m2Pfv9eOAP_gwx0pM,6633
38
- kostyl_toolkit-0.1.39.dist-info/WHEEL,sha256=e_m4S054HL0hyR3CpOk-b7Q7fDX6BuFkgL5OjAExXas,80
39
- kostyl_toolkit-0.1.39.dist-info/METADATA,sha256=9AA-uH9-jSciHx3xh4WGxHkzTyjEESi0AFRitTkIM5w,4136
40
- kostyl_toolkit-0.1.39.dist-info/RECORD,,
40
+ kostyl_toolkit-0.1.41.dist-info/WHEEL,sha256=e_m4S054HL0hyR3CpOk-b7Q7fDX6BuFkgL5OjAExXas,80
41
+ kostyl_toolkit-0.1.41.dist-info/METADATA,sha256=RoZ8dXiU-CwUldPkhjY2achhvcQ2EE4RzhsJKTXe-AM,4121
42
+ kostyl_toolkit-0.1.41.dist-info/RECORD,,
@@ -1,18 +0,0 @@
1
- from .composite import CompositeScheduler
2
- from .cosine import CosineParamScheduler
3
- from .cosine import CosineScheduler
4
- from .linear import LinearParamScheduler
5
- from .linear import LinearScheduler
6
- from .plateau import PlateauWithAnnealingParamScheduler
7
- from .plateau import PlateauWithAnnealingScheduler
8
-
9
-
10
- __all__ = [
11
- "CompositeScheduler",
12
- "CosineParamScheduler",
13
- "CosineScheduler",
14
- "LinearParamScheduler",
15
- "LinearScheduler",
16
- "PlateauWithAnnealingParamScheduler",
17
- "PlateauWithAnnealingScheduler",
18
- ]