kostyl-toolkit 0.1.38__tar.gz → 0.1.40__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/PKG-INFO +6 -8
  2. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/configs/__init__.py +10 -2
  3. kostyl_toolkit-0.1.40/kostyl/ml/configs/hyperparams.py +126 -0
  4. kostyl_toolkit-0.1.40/kostyl/ml/integrations/clearml/__init__.py +29 -0
  5. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/clearml/checkpoint_uploader.py +40 -29
  6. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/clearml/config_mixin.py +1 -1
  7. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/clearml/loading_utils.py +2 -2
  8. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/lightning/callbacks/checkpoint.py +6 -2
  9. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/lightning/loggers/tb_logger.py +6 -3
  10. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/lightning/module.py +1 -1
  11. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/lightning/utils.py +0 -7
  12. kostyl_toolkit-0.1.40/kostyl/ml/optim/__init__.py +8 -0
  13. kostyl_toolkit-0.1.40/kostyl/ml/optim/factory.py +257 -0
  14. kostyl_toolkit-0.1.40/kostyl/ml/optim/schedulers/__init__.py +56 -0
  15. {kostyl_toolkit-0.1.38/kostyl/ml → kostyl_toolkit-0.1.40/kostyl/ml/optim}/schedulers/cosine.py +2 -2
  16. {kostyl_toolkit-0.1.38/kostyl/ml → kostyl_toolkit-0.1.40/kostyl/ml/optim}/schedulers/linear.py +9 -9
  17. {kostyl_toolkit-0.1.38/kostyl/ml → kostyl_toolkit-0.1.40/kostyl/ml/optim}/schedulers/plateau.py +2 -2
  18. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/pyproject.toml +5 -6
  19. kostyl_toolkit-0.1.38/kostyl/ml/configs/hyperparams.py +0 -94
  20. kostyl_toolkit-0.1.38/kostyl/ml/integrations/clearml/__init__.py +0 -7
  21. kostyl_toolkit-0.1.38/kostyl/ml/schedulers/__init__.py +0 -18
  22. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/README.md +0 -0
  23. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/__init__.py +0 -0
  24. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/__init__.py +0 -0
  25. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/base_uploader.py +0 -0
  26. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/configs/mixins.py +0 -0
  27. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/configs/training_settings.py +0 -0
  28. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/data_collator.py +0 -0
  29. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/dist_utils.py +0 -0
  30. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/__init__.py +0 -0
  31. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/clearml/dataset_utils.py +0 -0
  32. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/clearml/version_utils.py +0 -0
  33. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/lightning/__init__.py +0 -0
  34. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/lightning/callbacks/__init__.py +0 -0
  35. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/lightning/callbacks/early_stopping.py +0 -0
  36. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/lightning/loggers/__init__.py +0 -0
  37. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/lightning/metrics_formatting.py +0 -0
  38. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/integrations/lightning/mixins.py +0 -0
  39. {kostyl_toolkit-0.1.38/kostyl/ml → kostyl_toolkit-0.1.40/kostyl/ml/optim}/schedulers/base.py +0 -0
  40. {kostyl_toolkit-0.1.38/kostyl/ml → kostyl_toolkit-0.1.40/kostyl/ml/optim}/schedulers/composite.py +0 -0
  41. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/ml/params_groups.py +0 -0
  42. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/utils/__init__.py +0 -0
  43. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/utils/dict_manipulations.py +0 -0
  44. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/utils/fs.py +0 -0
  45. {kostyl_toolkit-0.1.38 → kostyl_toolkit-0.1.40}/kostyl/utils/logging.py +0 -0
@@ -1,17 +1,15 @@
  Metadata-Version: 2.3
  Name: kostyl-toolkit
- Version: 0.1.38
+ Version: 0.1.40
  Summary: Kickass Orchestration System for Training, Yielding & Logging
  Requires-Dist: case-converter>=1.2.0
  Requires-Dist: loguru>=0.7.3
- Requires-Dist: case-converter>=1.2.0 ; extra == 'ml-core'
- Requires-Dist: clearml[s3]>=2.0.2 ; extra == 'ml-core'
- Requires-Dist: lightning>=2.5.6 ; extra == 'ml-core'
- Requires-Dist: pydantic>=2.12.4 ; extra == 'ml-core'
- Requires-Dist: torch>=2.9.1 ; extra == 'ml-core'
- Requires-Dist: transformers>=4.57.1 ; extra == 'ml-core'
+ Requires-Dist: case-converter>=1.2.0 ; extra == 'ml'
+ Requires-Dist: pydantic>=2.12.4 ; extra == 'ml'
+ Requires-Dist: torch>=2.9.1 ; extra == 'ml'
+ Requires-Dist: transformers>=4.57.1 ; extra == 'ml'
  Requires-Python: >=3.12
- Provides-Extra: ml-core
+ Provides-Extra: ml
  Description-Content-Type: text/markdown

  # Kostyl Toolkit
@@ -1,6 +1,10 @@
+ from .hyperparams import OPTIMIZER_CONFIG
+ from .hyperparams import AdamConfig
+ from .hyperparams import AdamWithPrecisionConfig
  from .hyperparams import HyperparamsConfig
  from .hyperparams import Lr
- from .hyperparams import Optimizer
+ from .hyperparams import MuonConfig
+ from .hyperparams import ScheduledParamConfig
  from .hyperparams import WeightDecay
  from .mixins import ConfigLoadingMixin
  from .training_settings import CheckpointConfig
@@ -14,6 +18,9 @@ from .training_settings import TrainingSettings


  __all__ = [
+ "OPTIMIZER_CONFIG",
+ "AdamConfig",
+ "AdamWithPrecisionConfig",
  "CheckpointConfig",
  "ConfigLoadingMixin",
  "DDPStrategyConfig",
@@ -23,7 +30,8 @@ __all__ = [
  "HyperparamsConfig",
  "LightningTrainerParameters",
  "Lr",
- "Optimizer",
+ "MuonConfig",
+ "ScheduledParamConfig",
  "SingleDeviceStrategyConfig",
  "TrainingSettings",
  "WeightDecay",
@@ -0,0 +1,126 @@
+ from typing import Literal
+
+ from pydantic import BaseModel
+ from pydantic import Field
+ from pydantic import model_validator
+
+ from kostyl.utils.logging import setup_logger
+
+
+ logger = setup_logger(fmt="only_message")
+
+
+ class AdamConfig(BaseModel):
+ """Adam optimizer hyperparameters configuration."""
+
+ type: Literal["AdamW", "Adam"] = "AdamW"
+ betas: tuple[float, float] = (0.9, 0.999)
+
+
+ class MuonConfig(BaseModel):
+ """Muon optimizer hyperparameters configuration."""
+
+ type: Literal["Muon"]
+ momentum: float = 0.95
+ nesterov: bool = True
+ ns_coefficients: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
+ ns_steps: int = 5
+
+
+ class AdamWithPrecisionConfig(BaseModel):
+ """Adam optimizer with low-precision hyperparameters configuration."""
+
+ type: Literal[
+ "Adam8bit", "Adam4bit", "AdamFp8", "AdamW8bit", "AdamW4bit", "AdamWFp8"
+ ]
+ betas: tuple[float, float] = (0.9, 0.999)
+ block_size: int
+ bf16_stochastic_round: bool = False
+
+
+ OPTIMIZER_CONFIG = AdamConfig | AdamWithPrecisionConfig | MuonConfig
+ SCHEDULER = Literal[
+ "linear",
+ "cosine",
+ "plateau-with-cosine-annealing",
+ "plateau-with-linear-annealing",
+ ]
+
+
+ class ScheduledParamConfig(BaseModel):
+ """Base configuration for a scheduled hyperparameter."""
+
+ scheduler_type: SCHEDULER | None = None
+
+ freeze_ratio: float | None = Field(default=None, ge=0, le=1)
+ warmup_ratio: float | None = Field(default=None, gt=0, lt=1, validate_default=False)
+ warmup_value: float | None = Field(default=None, gt=0, validate_default=False)
+ base_value: float
+ final_value: float | None = Field(default=None, gt=0, validate_default=False)
+ plateau_ratio: float | None = Field(
+ default=None, gt=0, lt=1, validate_default=False
+ )
+
+ @model_validator(mode="after")
+ def _validate_freeze_ratio(self) -> "ScheduledParamConfig":
+ if self.scheduler_type is None and self.freeze_ratio is not None:
+ logger.warning("use_scheduler is False, freeze_ratio will be ignored.")
+ self.freeze_ratio = None
+ return self
+
+ @model_validator(mode="after")
+ def _validate_warmup(self) -> "ScheduledParamConfig":
+ if ((self.warmup_value is not None) or (self.warmup_ratio is not None)) and self.scheduler_type is None: # fmt: skip
+ logger.warning(
+ "scheduler_type is None, warmup_value and warmup_ratio will be ignored."
+ )
+ self.warmup_value = None
+ self.warmup_ratio = None
+ if (self.warmup_value is None) != (self.warmup_ratio is None): # fmt: skip
+ raise ValueError(
+ "Both warmup_value and warmup_ratio must be provided or neither"
+ )
+ return self
+
+ @model_validator(mode="after")
+ def _validate_final_value(self) -> "ScheduledParamConfig":
+ if (self.scheduler_type in {"linear"}) and (self.final_value is not None):
+ raise ValueError("If scheduler_type is 'linear', final_value must be None.")
+ if (self.scheduler_type is None) and (self.final_value is not None):
+ logger.warning("use_scheduler is False, final_value will be ignored.")
+ self.final_value = None
+ return self
+
+ @model_validator(mode="after")
+ def _validate_plateau_ratio(self) -> "ScheduledParamConfig":
+ if self.scheduler_type is not None:
+ if self.scheduler_type.startswith("plateau") and self.plateau_ratio is None:
+ raise ValueError(
+ "If scheduler_type is 'plateau-with-*', plateau_ratio must be provided."
+ )
+ if (
+ not self.scheduler_type.startswith("plateau")
+ and self.plateau_ratio is not None
+ ):
+ logger.warning(
+ "scheduler_type is not 'plateau-with-*', plateau_ratio will be ignored."
+ )
+ self.plateau_ratio = None
+ return self
+
+
+ class Lr(ScheduledParamConfig):
+ """Learning rate hyperparameters configuration."""
+
+
+ class WeightDecay(ScheduledParamConfig):
+ """Weight decay hyperparameters configuration."""
+
+
+ class HyperparamsConfig(BaseModel):
+ """Model training hyperparameters configuration."""
+
+ grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
+ optimizer: OPTIMIZER_CONFIG
+ lr: Lr
+ weight_decay: WeightDecay
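For orientation, a minimal sketch of how the new config model might be populated, assuming the classes are re-exported from `kostyl.ml.configs` as shown in the updated `__init__.py` above; the concrete values are illustrative only, not taken from the package.

```python
from kostyl.ml.configs import AdamConfig, HyperparamsConfig, Lr, WeightDecay

# Hypothetical values: a cosine-scheduled learning rate with warmup, plus an
# unscheduled, constant weight decay (scheduler_type defaults to None).
hparams = HyperparamsConfig(
    grad_clip_val=1.0,
    optimizer=AdamConfig(type="AdamW", betas=(0.9, 0.95)),
    lr=Lr(
        scheduler_type="cosine",
        base_value=3e-4,
        final_value=3e-5,
        warmup_ratio=0.05,   # must be given together with warmup_value
        warmup_value=1e-6,
    ),
    weight_decay=WeightDecay(base_value=0.1),
)
print(hparams.optimizer.type)  # "AdamW"
```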
@@ -0,0 +1,29 @@
+ try:
+ import clearml # noqa: F401
+ except ImportError as e:
+ raise ImportError(
+ "ClearML integration requires the 'clearml' package. "
+ "Please install it via 'pip install clearml'."
+ ) from e
+ from .checkpoint_uploader import ClearMLCheckpointUploader
+ from .config_mixin import ConfigSyncingClearmlMixin
+ from .dataset_utils import collect_clearml_datasets
+ from .dataset_utils import download_clearml_datasets
+ from .dataset_utils import get_datasets_paths
+ from .loading_utils import load_model_from_clearml
+ from .loading_utils import load_tokenizer_from_clearml
+ from .version_utils import find_version_in_tags
+ from .version_utils import increment_version
+
+
+ __all__ = [
+ "ClearMLCheckpointUploader",
+ "ConfigSyncingClearmlMixin",
+ "collect_clearml_datasets",
+ "download_clearml_datasets",
+ "find_version_in_tags",
+ "get_datasets_paths",
+ "increment_version",
+ "load_model_from_clearml",
+ "load_tokenizer_from_clearml",
+ ]
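In practice the guard above means that any import from the subpackage fails fast with an actionable message when `clearml` is absent. A hedged sketch of consumer-side handling, with the module path taken from the file listing above:

```python
# Sketch: degrade gracefully when the optional ClearML integration is missing.
try:
    from kostyl.ml.integrations.clearml import ClearMLCheckpointUploader
except ImportError:
    # Raised by the guard above when the 'clearml' package is not installed.
    ClearMLCheckpointUploader = None  # type: ignore[assignment, misc]

if ClearMLCheckpointUploader is None:
    print("ClearML not installed; registry checkpoint uploads are disabled.")
```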
@@ -1,5 +1,5 @@
  from collections.abc import Callable
- from functools import partial
+ from datetime import datetime
  from pathlib import Path
  from typing import override

@@ -24,7 +24,7 @@ class ClearMLCheckpointUploader(ModelCheckpointUploader):
  comment: str | None = None,
  framework: str | None = None,
  base_model_id: str | None = None,
- new_model_per_upload: bool = True,
+ upload_as_new_model: bool = True,
  verbose: bool = True,
  ) -> None:
  """
@@ -38,20 +38,22 @@ class ClearMLCheckpointUploader(ModelCheckpointUploader):
  comment: A comment / description for the model.
  framework: The framework of the model (e.g., "PyTorch", "TensorFlow").
  base_model_id: Optional ClearML model ID to use as a base for the new model
- new_model_per_upload: Whether to create a new ClearML model
- for every upload or update weights of the same model. When updating weights,
- the last uploaded checkpoint will be replaced (and deleted).
+ upload_as_new_model: Whether to create a new ClearML model
+ for every upload or update weights of the same model. When True,
+ each checkpoint is uploaded as a separate model with timestamp added to the name.
+ When False, weights of the same model are updated.
  verbose: Whether to log messages during upload.

  """
  super().__init__()
- if base_model_id is not None and new_model_per_upload:
+ if base_model_id is not None and upload_as_new_model:
  raise ValueError(
- "Cannot set base_model_id when new_model_per_upload is True."
+ "Cannot set base_model_id when upload_as_new_model is True."
  )

  self.verbose = verbose
- self.new_model_per_upload = new_model_per_upload
+ self.upload_as_new_model = upload_as_new_model
+ self.model_name = model_name
  self.best_model_path: str = ""
  self.config_dict = config_dict
  self._output_model: OutputModel | None = None
@@ -59,15 +61,13 @@ class ClearMLCheckpointUploader(ModelCheckpointUploader):
  self._upload_callback: Callable | None = None

  self._validate_tags(tags)
- self.model_fabric = partial(
- OutputModel,
- name=model_name,
- label_enumeration=label_enumeration,
- tags=tags,
- comment=comment,
- framework=framework,
- base_model_id=base_model_id,
- )
+ self.model_fabric_kwargs = {
+ "label_enumeration": label_enumeration,
+ "tags": tags,
+ "comment": comment,
+ "framework": framework,
+ "base_model_id": base_model_id,
+ }
  return

  @staticmethod
@@ -78,16 +78,22 @@ class ClearMLCheckpointUploader(ModelCheckpointUploader):
  tags.append("LightningCheckpoint")
  return None

- @property
- def output_model_(self) -> OutputModel:
- """Returns the OutputModel instance based on `new_model_per_upload` setting."""
- if self.new_model_per_upload:
- model = self.model_fabric()
- self._output_model = self.model_fabric()
- else:
- if self._output_model is None:
- self._output_model = self.model_fabric()
- model = self._output_model
+ def _create_new_model(self) -> OutputModel:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ model_name_with_timestamp = f"{self.model_name}_{timestamp}"
+ model = OutputModel(
+ name=model_name_with_timestamp,
+ **self.model_fabric_kwargs,
+ )
+ return model
+
+ def _get_output_model(self) -> OutputModel:
+ if self._output_model is None:
+ self._output_model = OutputModel(
+ name=self.model_name,
+ **self.model_fabric_kwargs,
+ )
+ model = self._output_model
  return model

  @override
@@ -105,12 +111,17 @@ class ClearMLCheckpointUploader(ModelCheckpointUploader):
  if self.verbose:
  logger.info(f"Uploading model from {path}")

- self.output_model_.update_weights(
+ if self.upload_as_new_model:
+ output_model = self._create_new_model()
+ else:
+ output_model = self._get_output_model()
+
+ output_model.update_weights(
  path,
  auto_delete_file=False,
  async_enable=False,
  )
- self.output_model_.update_design(config_dict=self.config_dict)
+ output_model.update_design(config_dict=self.config_dict)

  self._last_uploaded_model_path = path
  return
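A hedged sketch of the renamed flag in use; the keyword names follow what is visible in these hunks, but the full constructor signature (and which arguments are required) is not shown in the diff, so the values below are placeholders.

```python
from kostyl.ml.integrations.clearml import ClearMLCheckpointUploader

# With upload_as_new_model=True (the default), each upload creates a fresh
# OutputModel named "<model_name>_<YYYYMMDD_HHMMSS>".
uploader = ClearMLCheckpointUploader(
    model_name="my-encoder",           # illustrative name
    config_dict={"hidden_size": 768},  # illustrative config payload
    tags=["baseline"],
    upload_as_new_model=True,
)

# With upload_as_new_model=False, a single OutputModel is reused and only its
# weights are updated; combining base_model_id with upload_as_new_model=True
# raises ValueError, as shown in the constructor above.
```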
@@ -10,7 +10,7 @@ from kostyl.utils.dict_manipulations import flattened_dict_to_nested
  from kostyl.utils.fs import load_config


- class BaseModelWithClearmlSyncing[TConfig: ConfigLoadingMixin]:
+ class ConfigSyncingClearmlMixin[TConfig: ConfigLoadingMixin]:
  """Mixin providing ClearML task configuration syncing functionality for Pydantic models."""

  @classmethod
@@ -31,7 +31,7 @@ except ImportError:
  LIGHTING_MIXIN_AVAILABLE = False


- def get_tokenizer_from_clearml(
+ def load_tokenizer_from_clearml(
  model_id: str,
  task: Task | None = None,
  ignore_remote_overrides: bool = True,
@@ -66,7 +66,7 @@ def get_tokenizer_from_clearml(
  return tokenizer, clearml_tokenizer


- def get_model_from_clearml[
+ def load_model_from_clearml[
  TModel: PreTrainedModel | LightningCheckpointLoaderMixin | AutoModel
  ](
  model_id: str,
@@ -15,7 +15,7 @@ from kostyl.ml.dist_utils import is_local_zero_rank
  from kostyl.utils import setup_logger


- logger = setup_logger("callbacks/checkpoint.py")
+ logger = setup_logger()


  class ModelCheckpointWithCheckpointUploader(ModelCheckpoint):
@@ -278,6 +278,10 @@ class ModelCheckpointWithCheckpointUploader(ModelCheckpoint):
  case "only-best":
  if filepath == self.best_model_path:
  self.registry_uploader.upload_checkpoint(filepath)
+ case _:
+ logger.warning_once(
+ "Unknown upload strategy for checkpoint uploader. Skipping upload."
+ )
  return

@@ -286,7 +290,7 @@ def setup_checkpoint_callback(
  ckpt_cfg: CheckpointConfig,
  checkpoint_uploader: ModelCheckpointUploader | None = None,
  upload_strategy: Literal["only-best", "every-checkpoint"] | None = None,
- remove_folder_if_exists: bool = True,
+ remove_folder_if_exists: bool = False,
  ) -> ModelCheckpointWithCheckpointUploader | ModelCheckpoint:
  """
  Create and configure a checkpoint callback for model saving.
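In other words, `setup_checkpoint_callback` now preserves an existing checkpoint folder by default; callers that relied on the previous wipe-on-exists behavior would presumably need to opt in explicitly, e.g. `setup_checkpoint_callback(ckpt_cfg, remove_folder_if_exists=True)`.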
@@ -11,14 +11,17 @@ logger = setup_logger()


  def setup_tb_logger(
- runs_dir: Path,
+ runs_dir: Path, remove_folder_if_exists: bool = False
  ) -> TensorBoardLogger:
  """Sets up a TensorBoardLogger for PyTorch Lightning."""
  if runs_dir.exists():
  if is_local_zero_rank():
  logger.warning(f"TensorBoard log directory {runs_dir} already exists.")
- rmtree(runs_dir)
- logger.warning(f"Removed existing TensorBoard log directory {runs_dir}.")
+ if remove_folder_if_exists:
+ rmtree(runs_dir)
+ logger.warning(
+ f"Removed existing TensorBoard log directory {runs_dir}."
+ )
  else:
  logger.info(f"Creating TensorBoard log directory {runs_dir}.")
  runs_dir.mkdir(parents=True, exist_ok=True)
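The new flag defaults to keeping an existing run directory. A small sketch, assuming `setup_tb_logger` is imported from the module path shown in the file listing:

```python
from pathlib import Path

from kostyl.ml.integrations.lightning.loggers.tb_logger import setup_tb_logger

# Default (remove_folder_if_exists=False): an existing directory is kept and only
# a warning is logged. Pass True to restore the old wipe-on-exists behavior.
tb_logger = setup_tb_logger(Path("runs/exp-001"), remove_folder_if_exists=False)
```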
@@ -14,7 +14,7 @@ from transformers import PretrainedConfig
  from transformers import PreTrainedModel

  from kostyl.ml.integrations.lightning.metrics_formatting import apply_suffix
- from kostyl.ml.schedulers.base import BaseScheduler
+ from kostyl.ml.optim.schedulers import BaseScheduler
  from kostyl.utils import setup_logger


@@ -4,16 +4,9 @@ import lightning as L
  import torch.distributed as dist
  from torch.distributed import ProcessGroup

- from kostyl.ml.configs import DDPStrategyConfig
- from kostyl.ml.configs import FSDP1StrategyConfig
- from kostyl.ml.configs import SingleDeviceStrategyConfig
  from kostyl.utils.logging import setup_logger


- TRAINING_STRATEGIES = (
- FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
- )
-
  logger = setup_logger()


@@ -0,0 +1,8 @@
+ from .factory import create_optimizer
+ from .factory import create_scheduler
+
+
+ __all__ = [
+ "create_optimizer",
+ "create_scheduler",
+ ]
@@ -0,0 +1,257 @@
+ from typing import Any
+
+ from torch.optim import Optimizer
+
+ from kostyl.ml.configs import OPTIMIZER_CONFIG
+ from kostyl.ml.configs import AdamConfig
+ from kostyl.ml.configs import AdamWithPrecisionConfig
+ from kostyl.ml.configs import MuonConfig
+ from kostyl.ml.configs import ScheduledParamConfig
+ from kostyl.utils import setup_logger
+
+ from .schedulers import SCHEDULER_MAPPING
+ from .schedulers import CosineScheduler
+ from .schedulers import LinearScheduler
+ from .schedulers import PlateauWithAnnealingScheduler
+
+
+ logger = setup_logger(fmt="only_message")
+
+
+ def create_scheduler(
+ config: ScheduledParamConfig,
+ param_group_field: str,
+ num_iters: int,
+ optim: Optimizer,
+ multiplier_field: str | None = None,
+ skip_if_zero: bool = False,
+ apply_if_field: str | None = None,
+ ignore_if_field: str | None = None,
+ ) -> LinearScheduler | CosineScheduler | PlateauWithAnnealingScheduler:
+ """
+ Converts a ScheduledParamConfig to a scheduler instance.
+
+ Args:
+ config: Configuration object for the scheduler.
+ param_group_field: The field name in the optimizer's param groups to schedule.
+ num_iters: Total number of iterations.
+ optim: The optimizer instance.
+ multiplier_field: Optional per-group field name that contains a multiplier applied to the scheduled value. If None, no multiplier is applied.
+ skip_if_zero: Leave groups untouched when their target field equals zero.
+ Default is False.
+ apply_if_field: Require this key to be present in a param group before updating.
+ ignore_if_field: Skip groups that declare this key in their dictionaries.
+
+ Returns:
+ A scheduler instance based on the configuration.
+
+ """
+ if config.scheduler_type is None:
+ raise ValueError("scheduler_type must be specified in the config.")
+
+ if "plateau" in config.scheduler_type:
+ scheduler_type = "plateau"
+ else:
+ scheduler_type = config.scheduler_type
+ scheduler_cls = SCHEDULER_MAPPING[scheduler_type] # type: ignore
+
+ if issubclass(scheduler_cls, PlateauWithAnnealingScheduler):
+ if "cosine" in config.scheduler_type:
+ annealing_type = "cosine"
+ elif "linear" in config.scheduler_type:
+ annealing_type = "linear"
+ else:
+ raise ValueError(f"Unknown annealing_type: {config.scheduler_type}")
+ scheduler = scheduler_cls(
+ optimizer=optim,
+ param_group_field=param_group_field,
+ num_iters=num_iters,
+ plateau_value=config.base_value,
+ final_value=config.final_value, # type: ignore
+ warmup_ratio=config.warmup_ratio,
+ warmup_value=config.warmup_value,
+ freeze_ratio=config.freeze_ratio,
+ plateau_ratio=config.plateau_ratio, # type: ignore
+ annealing_type=annealing_type,
+ multiplier_field=multiplier_field,
+ skip_if_zero=skip_if_zero,
+ apply_if_field=apply_if_field,
+ ignore_if_field=ignore_if_field,
+ )
+ elif issubclass(scheduler_cls, LinearScheduler):
+ scheduler = scheduler_cls(
+ optimizer=optim,
+ param_group_field=param_group_field,
+ num_iters=num_iters,
+ initial_value=config.base_value,
+ final_value=config.final_value, # type: ignore
+ multiplier_field=multiplier_field,
+ skip_if_zero=skip_if_zero,
+ apply_if_field=apply_if_field,
+ ignore_if_field=ignore_if_field,
+ )
+ elif issubclass(scheduler_cls, CosineScheduler):
+ scheduler = scheduler_cls(
+ optimizer=optim,
+ param_group_field=param_group_field,
+ num_iters=num_iters,
+ base_value=config.base_value,
+ final_value=config.final_value, # type: ignore
+ warmup_ratio=config.warmup_ratio,
+ warmup_value=config.warmup_value,
+ freeze_ratio=config.freeze_ratio,
+ multiplier_field=multiplier_field,
+ skip_if_zero=skip_if_zero,
+ apply_if_field=apply_if_field,
+ ignore_if_field=ignore_if_field,
+ )
+ else:
+ raise ValueError(f"Unsupported scheduler type: {config.scheduler_type}")
+ return scheduler
+
+
+ def create_optimizer( # noqa: C901
+ parameters_groups: dict[str, Any],
+ optimizer_config: OPTIMIZER_CONFIG,
+ lr: float,
+ weight_decay: float,
+ ) -> Optimizer:
+ """
+ Creates an optimizer based on the configuration.
+
+ Args:
+ parameters_groups: Dictionary containing model parameters
+ (key "params" and per-group options, i.e. "lr", "weight_decay" and etc.).
+ optimizer_config: Configuration for the optimizer.
+ lr: Learning rate.
+ weight_decay: Weight decay.
+
+ Returns:
+ An instantiated optimizer.
+
+ """
+ if isinstance(optimizer_config, AdamConfig):
+ match optimizer_config.type:
+ case "Adam":
+ from torch.optim import Adam
+
+ optimizer = Adam(
+ params=parameters_groups["params"],
+ lr=lr,
+ weight_decay=weight_decay,
+ betas=optimizer_config.betas,
+ )
+
+ case "AdamW":
+ from torch.optim import AdamW
+
+ optimizer = AdamW(
+ params=parameters_groups["params"],
+ lr=lr,
+ weight_decay=weight_decay,
+ betas=optimizer_config.betas,
+ )
+ return optimizer
+ case _:
+ raise ValueError(f"Unsupported optimizer type: {optimizer_config.type}")
+ elif isinstance(optimizer_config, MuonConfig):
+ from torch.optim import Muon
+
+ optimizer = Muon(
+ params=parameters_groups["params"],
+ lr=lr,
+ weight_decay=weight_decay,
+ momentum=optimizer_config.momentum,
+ nesterov=optimizer_config.nesterov,
+ ns_coefficients=optimizer_config.ns_coefficients,
+ ns_steps=optimizer_config.ns_steps,
+ )
+ elif isinstance(optimizer_config, AdamWithPrecisionConfig):
+ try:
+ import torchao # noqa: F401
+ except ImportError as e:
+ raise ImportError(
+ "torchao is required for low-precision Adam optimizers. "
+ "Please install it via 'pip install torchao'."
+ ) from e
+ match optimizer_config.type:
+ case "Adam8bit":
+ from torchao.optim import Adam8bit
+
+ logger.warning(
+ "Ignoring weight_decay for Adam8bit optimizer as it is not supported."
+ )
+
+ optimizer = Adam8bit(
+ params=parameters_groups["params"],
+ lr=lr,
+ betas=optimizer_config.betas,
+ block_size=optimizer_config.block_size,
+ bf16_stochastic_round=optimizer_config.bf16_stochastic_round,
+ )
+ case "Adam4bit":
+ from torchao.optim import Adam4bit
+
+ logger.warning(
+ "Ignoring weight_decay for Adam4bit optimizer as it is not supported."
+ )
+
+ optimizer = Adam4bit(
+ params=parameters_groups["params"],
+ lr=lr,
+ betas=optimizer_config.betas,
+ block_size=optimizer_config.block_size,
+ bf16_stochastic_round=optimizer_config.bf16_stochastic_round,
+ )
+ case "AdamFp8":
+ from torchao.optim import AdamFp8
+
+ logger.warning(
+ "Ignoring weight_decay for AdamFp8 optimizer as it is not supported."
+ )
+
+ optimizer = AdamFp8(
+ params=parameters_groups["params"],
+ lr=lr,
+ betas=optimizer_config.betas,
+ block_size=optimizer_config.block_size,
+ bf16_stochastic_round=optimizer_config.bf16_stochastic_round,
+ )
+ case "AdamW8bit":
+ from torchao.optim import AdamW8bit
+
+ optimizer = AdamW8bit(
+ params=parameters_groups["params"],
+ lr=lr,
+ weight_decay=weight_decay,
+ betas=optimizer_config.betas,
+ block_size=optimizer_config.block_size,
+ bf16_stochastic_round=optimizer_config.bf16_stochastic_round,
+ )
+ case "AdamW4bit":
+ from torchao.optim import AdamW4bit
+
+ optimizer = AdamW4bit(
+ params=parameters_groups["params"],
+ lr=lr,
+ weight_decay=weight_decay,
+ betas=optimizer_config.betas,
+ block_size=optimizer_config.block_size,
+ bf16_stochastic_round=optimizer_config.bf16_stochastic_round,
+ )
+ case "AdamWFp8":
+ from torchao.optim import AdamWFp8
+
+ optimizer = AdamWFp8(
+ params=parameters_groups["params"],
+ lr=lr,
+ weight_decay=weight_decay,
+ betas=optimizer_config.betas,
+ block_size=optimizer_config.block_size,
+ bf16_stochastic_round=optimizer_config.bf16_stochastic_round,
+ )
+ case _:
+ raise ValueError(f"Unsupported optimizer type: {optimizer_config.type}")
+ else:
+ raise ValueError("Unsupported optimizer configuration type.")
+ return optimizer
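To tie the pieces together, a minimal sketch of the new factory entry points, assuming they are called with a plain parameter-group dict as described in the docstrings above; the model and iteration count are placeholders, not taken from the package.

```python
import torch.nn as nn

from kostyl.ml.configs import AdamConfig, Lr
from kostyl.ml.optim import create_optimizer, create_scheduler

model = nn.Linear(16, 4)  # placeholder model
lr_cfg = Lr(scheduler_type="cosine", base_value=3e-4, final_value=3e-5)

optimizer = create_optimizer(
    parameters_groups={"params": model.parameters()},
    optimizer_config=AdamConfig(),  # AdamW with default betas
    lr=lr_cfg.base_value,
    weight_decay=0.01,
)
lr_scheduler = create_scheduler(
    config=lr_cfg,
    param_group_field="lr",  # schedules optimizer.param_groups[i]["lr"]
    num_iters=1_000,         # placeholder total step count
    optim=optimizer,
)
```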
@@ -0,0 +1,56 @@
+ from typing import TypedDict
+
+ from .base import BaseScheduler
+ from .composite import CompositeScheduler
+ from .cosine import CosineParamScheduler
+ from .cosine import CosineScheduler
+ from .linear import LinearParamScheduler
+ from .linear import LinearScheduler
+ from .plateau import PlateauWithAnnealingParamScheduler
+ from .plateau import PlateauWithAnnealingScheduler
+
+
+ class SchedulerMapping(TypedDict):
+ """Map names to scheduler classes."""
+
+ linear: type[LinearScheduler]
+ cosine: type[CosineScheduler]
+ plateau: type[PlateauWithAnnealingScheduler]
+ composite: type[CompositeScheduler]
+
+
+ class ParamSchedulerMapping(TypedDict):
+ """Map names to scheduler classes."""
+
+ linear: type[LinearParamScheduler]
+ cosine: type[CosineParamScheduler]
+ plateau: type[PlateauWithAnnealingParamScheduler]
+
+
+ SCHEDULER_MAPPING: SchedulerMapping = {
+ "linear": LinearScheduler,
+ "cosine": CosineScheduler,
+ "plateau": PlateauWithAnnealingScheduler,
+ "composite": CompositeScheduler,
+ }
+
+
+ PARAM_SCHEDULER_MAPPING: ParamSchedulerMapping = {
+ "linear": LinearParamScheduler,
+ "cosine": CosineParamScheduler,
+ "plateau": PlateauWithAnnealingParamScheduler,
+ }
+
+
+ __all__ = [
+ "PARAM_SCHEDULER_MAPPING",
+ "SCHEDULER_MAPPING",
+ "BaseScheduler",
+ "CompositeScheduler",
+ "CosineParamScheduler",
+ "CosineScheduler",
+ "LinearParamScheduler",
+ "LinearScheduler",
+ "PlateauWithAnnealingParamScheduler",
+ "PlateauWithAnnealingScheduler",
+ ]
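These mappings provide the name-based lookup that `create_scheduler` relies on; a short sketch of using them directly:

```python
from kostyl.ml.optim.schedulers import PARAM_SCHEDULER_MAPPING, SCHEDULER_MAPPING

scheduler_cls = SCHEDULER_MAPPING["cosine"]               # -> CosineScheduler
param_scheduler_cls = PARAM_SCHEDULER_MAPPING["linear"]   # -> LinearParamScheduler
print(scheduler_cls.__name__, param_scheduler_cls.__name__)
```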
@@ -145,8 +145,8 @@ class CosineScheduler(_CosineSchedulerCore):
  freeze_ratio: Optional fraction of iterations to keep the value frozen at zero at the beginning.
  multiplier_field: Optional per-group multiplier applied to the scheduled value.
  skip_if_zero: Leave groups untouched when their target field equals zero.
- apply_if_field: Require this flag to be present in a param group before updating.
- ignore_if_field: Skip groups that declare this flag.
+ apply_if_field: Require this key to be present in a param group before updating.
+ ignore_if_field: Skip groups that declare this key in their dictionaries.

  """
  self.apply_if_field = apply_if_field
@@ -13,21 +13,21 @@ class _LinearScheduleBase(BaseScheduler):
  self,
  param_name: str,
  num_iters: int,
- start_value: float,
+ initial_value: float,
  final_value: float,
  ) -> None:
  self.param_name = param_name
  self.num_iters = num_iters
- self.start_value = start_value
+ self.initial_value = initial_value
  self.final_value = final_value

  self.scheduled_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
- self.current_value_ = self.start_value
+ self.current_value_ = self.initial_value
  return

  def _create_scheduler(self) -> None:
  self.scheduled_values = np.linspace(
- self.start_value, self.final_value, num=self.num_iters, dtype=np.float64
+ self.initial_value, self.final_value, num=self.num_iters, dtype=np.float64
  )
  self._verify()
  return
@@ -68,7 +68,7 @@ class LinearScheduler(_LinearScheduleBase):
  optimizer: torch.optim.Optimizer,
  param_group_field: str,
  num_iters: int,
- start_value: float,
+ initial_value: float,
  final_value: float,
  multiplier_field: str | None = None,
  skip_if_zero: bool = False,
@@ -82,12 +82,12 @@ class LinearScheduler(_LinearScheduleBase):
  optimizer: Optimizer whose param groups are updated in-place.
  param_group_field: Name of the field that receives the scheduled value.
  num_iters: Number of scheduler iterations before clamping at ``final_value``.
- start_value: Value used on the first iteration.
+ initial_value: Value used on the first iteration.
  final_value: Value used once ``num_iters`` iterations are consumed.
  multiplier_field: Optional per-group multiplier applied to the scheduled value.
  skip_if_zero: Leave groups untouched when their target field equals zero.
- apply_if_field: Require this flag to be present in a param group before updating.
- ignore_if_field: Skip groups that declare this flag.
+ apply_if_field: Require this key to be present in a param group before updating.
+ ignore_if_field: Skip groups that declare this key in their dictionaries.

  """
  self.apply_if_field = apply_if_field
@@ -98,7 +98,7 @@ class LinearScheduler(_LinearScheduleBase):
  super().__init__(
  param_name=param_group_field,
  num_iters=num_iters,
- start_value=start_value,
+ initial_value=initial_value,
  final_value=final_value,
  )
  self.param_group_field = param_group_field
@@ -198,8 +198,8 @@ class PlateauWithAnnealingScheduler(_PlateauWithAnnealingCore):
  annealing_type: Type of annealing from plateau to final value ("cosine" or "linear").
  multiplier_field: Optional per-group multiplier applied to the scheduled value.
  skip_if_zero: Leave groups untouched when their target field equals zero.
- apply_if_field: Require this flag to be present in a param group before updating.
- ignore_if_field: Skip groups that declare this flag.
+ apply_if_field: Require this key to be present in a param group before updating.
+ ignore_if_field: Skip groups that declare this key in their dictionaries.

  """
  self.apply_if_field = apply_if_field
@@ -1,6 +1,6 @@
  [project]
  name = "kostyl-toolkit"
- version = "0.1.38"
+ version = "0.1.40"
  description = "Kickass Orchestration System for Training, Yielding & Logging "
  readme = "README.md"
  requires-python = ">=3.12"
@@ -10,10 +10,8 @@ dependencies = [
  ]

  [project.optional-dependencies]
- ml-core = [
+ ml = [
  "case-converter>=1.2.0",
- "clearml[s3]>=2.0.2",
- "lightning>=2.5.6",
  "pydantic>=2.12.4",
  "torch>=2.9.1",
  "transformers>=4.57.1",
@@ -30,18 +28,19 @@ dev = [
  "pyarrow>=22.0.0",
  ]

- ml-core = [
+ ml = [
  "case-converter>=1.2.0",
  "clearml[s3]>=2.0.2",
  "lightning>=2.5.6",
  "pydantic>=2.12.4",
  "torch>=2.9.1",
+ "torchao>=0.15.0",
  "transformers>=4.57.1",
  ]


  [tool.uv]
- default-groups = ["dev", "ml-core"]
+ default-groups = ["dev", "ml"]

  [build-system]
  requires = ["uv_build>=0.9.9,<0.10.0"]
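With the extra renamed, the optional ML stack is installed as, for example, `pip install "kostyl-toolkit[ml]"` instead of `...[ml-core]`. Note that the published `ml` extra no longer lists `clearml` or `lightning`, and `torchao` (needed by the new low-precision optimizers) appears only in the uv `ml` dependency group above, so end users would presumably install those packages separately.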
@@ -1,94 +0,0 @@
- from typing import Literal
-
- from pydantic import BaseModel
- from pydantic import Field
- from pydantic import model_validator
-
- from kostyl.utils.logging import setup_logger
-
-
- logger = setup_logger(fmt="only_message")
-
-
- class AdamConfig(BaseModel):
- """AdamW optimizer hyperparameters configuration."""
-
- type: Literal["AdamW"] = "AdamW"
- betas: tuple[float, float] = (0.9, 0.999)
- is_adamw: bool = True
-
-
- class AdamWithPrecisionConfig(BaseModel):
- """Adam optimizer with low-precision hyperparameters configuration."""
-
- type: Literal["Adam8bit", "Adam4bit", "AdamFp8"]
- betas: tuple[float, float] = (0.9, 0.999)
- block_size: int
- bf16_stochastic_round: bool = False
- is_adamw: bool = True
-
-
- Optimizer = AdamConfig | AdamWithPrecisionConfig
-
-
- class Lr(BaseModel):
- """Learning rate hyperparameters configuration."""
-
- use_scheduler: bool = False
- warmup_iters_ratio: float | None = Field(
- default=None, gt=0, lt=1, validate_default=False
- )
- warmup_value: float | None = Field(default=None, gt=0, validate_default=False)
- base_value: float
- final_value: float | None = Field(default=None, gt=0, validate_default=False)
-
- @model_validator(mode="after")
- def validate_warmup(self) -> "Lr":
- """Validates the warmup parameters based on use_scheduler."""
- if (self.warmup_value is None) != (self.warmup_iters_ratio is None): # fmt: skip
- raise ValueError(
- "Both warmup_value and warmup_iters_ratio must be provided or neither"
- )
- if ((self.warmup_value is not None) or (self.warmup_iters_ratio is not None)) and not self.use_scheduler: # fmt: skip
- logger.warning(
- "use_scheduler is False, warmup_value and warmup_iters_ratio will be ignored."
- )
- self.warmup_value = None
- self.warmup_iters_ratio = None
- return self
-
- @model_validator(mode="after")
- def validate_final_value(self) -> "Lr":
- """Validates the final_value based on use_scheduler."""
- if self.use_scheduler and (self.final_value is None):
- raise ValueError("If use_scheduler is True, final_value must be provided.")
- if (not self.use_scheduler) and (self.final_value is not None):
- logger.warning("use_scheduler is False, final_value will be ignored.")
- self.final_value = None
- return self
-
-
- class WeightDecay(BaseModel):
- """Weight decay hyperparameters configuration."""
-
- use_scheduler: bool = False
- base_value: float
- final_value: float | None = None
-
- @model_validator(mode="after")
- def validate_final_value(self) -> "WeightDecay":
- """Validates the final_value based on use_scheduler."""
- if self.use_scheduler and self.final_value is None:
- raise ValueError("If use_scheduler is True, final_value must be provided.")
- if not self.use_scheduler and self.final_value is not None:
- logger.warning("use_scheduler is False, final_value will be ignored.")
- return self
-
-
- class HyperparamsConfig(BaseModel):
- """Model training hyperparameters configuration."""
-
- grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
- optimizer: Optimizer
- lr: Lr
- weight_decay: WeightDecay
@@ -1,7 +0,0 @@
- try:
- import clearml # noqa: F401
- except ImportError as e:
- raise ImportError(
- "ClearML integration requires the 'clearml' package. "
- "Please install it via 'pip install clearml'."
- ) from e
@@ -1,18 +0,0 @@
- from .composite import CompositeScheduler
- from .cosine import CosineParamScheduler
- from .cosine import CosineScheduler
- from .linear import LinearParamScheduler
- from .linear import LinearScheduler
- from .plateau import PlateauWithAnnealingParamScheduler
- from .plateau import PlateauWithAnnealingScheduler
-
-
- __all__ = [
- "CompositeScheduler",
- "CosineParamScheduler",
- "CosineScheduler",
- "LinearParamScheduler",
- "LinearScheduler",
- "PlateauWithAnnealingParamScheduler",
- "PlateauWithAnnealingScheduler",
- ]