kostyl-toolkit 0.1.0 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. kostyl_toolkit-0.1.0/PKG-INFO +102 -0
  2. kostyl_toolkit-0.1.0/README.md +89 -0
  3. kostyl_toolkit-0.1.0/kostyl/__init__.py +0 -0
  4. kostyl_toolkit-0.1.0/kostyl/ml_core/__init__.py +0 -0
  5. kostyl_toolkit-0.1.0/kostyl/ml_core/clearml/__init__.py +0 -0
  6. kostyl_toolkit-0.1.0/kostyl/ml_core/clearml/logging_utils.py +46 -0
  7. kostyl_toolkit-0.1.0/kostyl/ml_core/clearml/pulling_utils.py +83 -0
  8. kostyl_toolkit-0.1.0/kostyl/ml_core/configs/__init__.py +30 -0
  9. kostyl_toolkit-0.1.0/kostyl/ml_core/configs/config_mixins.py +146 -0
  10. kostyl_toolkit-0.1.0/kostyl/ml_core/configs/hyperparams.py +84 -0
  11. kostyl_toolkit-0.1.0/kostyl/ml_core/configs/training_params.py +110 -0
  12. kostyl_toolkit-0.1.0/kostyl/ml_core/dist_utils.py +72 -0
  13. kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/__init__.py +5 -0
  14. kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/callbacks/__init__.py +10 -0
  15. kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/callbacks/checkpoint.py +56 -0
  16. kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/callbacks/early_stopping.py +18 -0
  17. kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/callbacks/registry_uploading.py +126 -0
  18. kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/extenstions/__init__.py +5 -0
  19. kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/extenstions/custom_module.py +179 -0
  20. kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/extenstions/pretrained_model.py +115 -0
  21. kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/loggers/__init__.py +0 -0
  22. kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/loggers/tb_logger.py +31 -0
  23. kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/steps_estimation.py +44 -0
  24. kostyl_toolkit-0.1.0/kostyl/ml_core/metrics_formatting.py +41 -0
  25. kostyl_toolkit-0.1.0/kostyl/ml_core/params_groups.py +50 -0
  26. kostyl_toolkit-0.1.0/kostyl/ml_core/schedulers/__init__.py +6 -0
  27. kostyl_toolkit-0.1.0/kostyl/ml_core/schedulers/base.py +48 -0
  28. kostyl_toolkit-0.1.0/kostyl/ml_core/schedulers/composite.py +68 -0
  29. kostyl_toolkit-0.1.0/kostyl/ml_core/schedulers/cosine.py +219 -0
  30. kostyl_toolkit-0.1.0/kostyl/utils/__init__.py +10 -0
  31. kostyl_toolkit-0.1.0/kostyl/utils/dict_manipulations.py +40 -0
  32. kostyl_toolkit-0.1.0/kostyl/utils/logging.py +147 -0
  33. kostyl_toolkit-0.1.0/pyproject.toml +105 -0
@@ -0,0 +1,102 @@
+ Metadata-Version: 2.3
+ Name: kostyl-toolkit
+ Version: 0.1.0
+ Summary: Kickass Orchestration System for Training, Yielding & Logging
+ Requires-Dist: case-converter>=1.2.0
+ Requires-Dist: clearml[s3]>=2.0.2
+ Requires-Dist: lightning>=2.5.6
+ Requires-Dist: loguru>=0.7.3
+ Requires-Dist: pydantic>=2.12.4
+ Requires-Dist: transformers>=4.57.1
+ Requires-Python: >=3.12
+ Description-Content-Type: text/markdown
+
+ # Kostyl Toolkit
+
+ Kickass Orchestration System for Training, Yielding & Logging — a batteries-included toolbox that glues PyTorch Lightning, Hugging Face Transformers, and ClearML into a single workflow.
+
+ ## Overview
+ - Rapidly bootstrap Lightning experiments with opinionated defaults (`KostylLightningModule`, custom schedulers, grad clipping and metric formatting).
+ - Keep model configs source-controlled via Pydantic mixins, with ClearML syncing out of the box (`ConfigLoadingMixin`, `ClearMLConfigMixin`).
+ - Reuse Lightning checkpoints directly inside Transformers models through `LightningCheckpointLoaderMixin`.
+ - Ship distributed-friendly utilities (deterministic logging, FSDP helpers, LR scaling, ClearML tag management).
+
+ ## Installation
+ ```bash
+ # Latest release from PyPI
+ pip install kostyl-toolkit
+
+ # or with uv
+ uv pip install kostyl-toolkit
+ ```
+
+ Development setup:
+ ```bash
+ uv sync                  # creates the virtualenv declared in pyproject.toml
+ source .venv/bin/activate.fish
+ pre-commit install       # optional but recommended
+ ```
+
+ ## Quick Start
+ ```python
+ from lightning import Trainer
+ from transformers import AutoModelForSequenceClassification
+
+ from kostyl.ml_core.configs.hyperparams import HyperparamsConfig
+ from kostyl.ml_core.configs.training_params import TrainingParams
+ from kostyl.ml_core.lightning.extenstions.custom_module import KostylLightningModule
+
+
+ class TextClassifier(KostylLightningModule):
+     def __init__(self, hyperparams: HyperparamsConfig):
+         super().__init__()
+         self.hyperparams = hyperparams  # grad clipping + scheduler knobs
+         self.model = AutoModelForSequenceClassification.from_pretrained(
+             "distilbert-base-uncased",
+             num_labels=2,
+         )
+
+     def training_step(self, batch, batch_idx):
+         outputs = self.model(**batch)
+         self.log("train/loss", outputs.loss)
+         return outputs.loss
+
+ train_cfg = TrainingParams.from_file("configs/training.yaml")
+ hyperparams = HyperparamsConfig.from_file("configs/hyperparams.yaml")
+
+ module = TextClassifier(hyperparams)
+
+ trainer = Trainer(**train_cfg.trainer.model_dump())
+ trainer.fit(module)
+ ```
+
+ Restoring a plain Transformers model from a Lightning checkpoint:
+ ```python
+ from kostyl.ml_core.lightning.extenstions.pretrained_model import LightningCheckpointLoaderMixin
+
+
+ model = LightningCheckpointLoaderMixin.from_lighting_checkpoint(
+     "checkpoints/epoch=03-step=500.ckpt",
+     config_key="config",
+     weights_prefix="model.",
+ )
+ ```
+
+ ## Components
+ - **Configurations** (`kostyl/ml_core/configs`): strongly-typed training, optimizer, and scheduler configs with ClearML syncing helpers.
+ - **Lightning Extensions** (`kostyl/ml_core/lightning`): custom LightningModule base class, callbacks, logging bridges, and the checkpoint loader mixin.
+ - **Schedulers** (`kostyl/ml_core/schedulers`): extensible LR schedulers (base/composite/cosine) with serialization helpers and on-step logging.
+ - **ClearML Utilities** (`kostyl/ml_core/clearml`): tag/version helpers and logging bridges for ClearML Tasks.
+ - **Distributed + Metrics Utils** (`kostyl/ml_core/dist_utils.py`, `metrics_formatting.py`): world-size-aware LR scaling, rank-aware metric naming, and per-class formatting.
+ - **Logging Helpers** (`kostyl/utils/logging.py`): rank-aware Loguru setup and uniform handling of incompatible checkpoint keys.
+
+ ## Project Layout
+ ```
+ kostyl/
+   ml_core/
+     configs/      # Pydantic configs + ClearML mixins
+     lightning/    # Lightning module, callbacks, loggers, extensions
+     schedulers/   # Base + composite/cosine schedulers
+     clearml/      # Logging + pulling utilities
+   utils/          # Dict helpers, logging utilities
+ ```
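Editor's note on the checkpoint-restore snippet above: in the rest of the package (see `pulling_utils.py` below), `from_lighting_checkpoint` is expected to be called on a concrete model class that mixes in `LightningCheckpointLoaderMixin`, not on the mixin itself. A minimal illustrative sketch, not part of the released package — the class name and checkpoint path are placeholders, and the keyword defaults are assumed from the snippet above:

```python
from transformers import BertForSequenceClassification

from kostyl.ml_core.lightning.extenstions.pretrained_model import LightningCheckpointLoaderMixin


class BertClassifier(LightningCheckpointLoaderMixin, BertForSequenceClassification):
    """Placeholder: a plain Transformers model that can also load Lightning checkpoints."""


model = BertClassifier.from_lighting_checkpoint(
    "checkpoints/epoch=03-step=500.ckpt",  # placeholder path
    config_key="config",        # key holding the model config inside the checkpoint (as in the README snippet)
    weights_prefix="model.",    # prefix stripped from state-dict keys (as in the README snippet)
)
```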
@@ -0,0 +1,89 @@
+ # Kostyl Toolkit
+
+ Kickass Orchestration System for Training, Yielding & Logging — a batteries-included toolbox that glues PyTorch Lightning, Hugging Face Transformers, and ClearML into a single workflow.
+
+ ## Overview
+ - Rapidly bootstrap Lightning experiments with opinionated defaults (`KostylLightningModule`, custom schedulers, grad clipping and metric formatting).
+ - Keep model configs source-controlled via Pydantic mixins, with ClearML syncing out of the box (`ConfigLoadingMixin`, `ClearMLConfigMixin`).
+ - Reuse Lightning checkpoints directly inside Transformers models through `LightningCheckpointLoaderMixin`.
+ - Ship distributed-friendly utilities (deterministic logging, FSDP helpers, LR scaling, ClearML tag management).
+
+ ## Installation
+ ```bash
+ # Latest release from PyPI
+ pip install kostyl-toolkit
+
+ # or with uv
+ uv pip install kostyl-toolkit
+ ```
+
+ Development setup:
+ ```bash
+ uv sync                  # creates the virtualenv declared in pyproject.toml
+ source .venv/bin/activate.fish
+ pre-commit install       # optional but recommended
+ ```
+
+ ## Quick Start
+ ```python
+ from lightning import Trainer
+ from transformers import AutoModelForSequenceClassification
+
+ from kostyl.ml_core.configs.hyperparams import HyperparamsConfig
+ from kostyl.ml_core.configs.training_params import TrainingParams
+ from kostyl.ml_core.lightning.extenstions.custom_module import KostylLightningModule
+
+
+ class TextClassifier(KostylLightningModule):
+     def __init__(self, hyperparams: HyperparamsConfig):
+         super().__init__()
+         self.hyperparams = hyperparams  # grad clipping + scheduler knobs
+         self.model = AutoModelForSequenceClassification.from_pretrained(
+             "distilbert-base-uncased",
+             num_labels=2,
+         )
+
+     def training_step(self, batch, batch_idx):
+         outputs = self.model(**batch)
+         self.log("train/loss", outputs.loss)
+         return outputs.loss
+
+ train_cfg = TrainingParams.from_file("configs/training.yaml")
+ hyperparams = HyperparamsConfig.from_file("configs/hyperparams.yaml")
+
+ module = TextClassifier(hyperparams)
+
+ trainer = Trainer(**train_cfg.trainer.model_dump())
+ trainer.fit(module)
+ ```
+
+ Restoring a plain Transformers model from a Lightning checkpoint:
+ ```python
+ from kostyl.ml_core.lightning.extenstions.pretrained_model import LightningCheckpointLoaderMixin
+
+
+ model = LightningCheckpointLoaderMixin.from_lighting_checkpoint(
+     "checkpoints/epoch=03-step=500.ckpt",
+     config_key="config",
+     weights_prefix="model.",
+ )
+ ```
+
+ ## Components
+ - **Configurations** (`kostyl/ml_core/configs`): strongly-typed training, optimizer, and scheduler configs with ClearML syncing helpers.
+ - **Lightning Extensions** (`kostyl/ml_core/lightning`): custom LightningModule base class, callbacks, logging bridges, and the checkpoint loader mixin.
+ - **Schedulers** (`kostyl/ml_core/schedulers`): extensible LR schedulers (base/composite/cosine) with serialization helpers and on-step logging.
+ - **ClearML Utilities** (`kostyl/ml_core/clearml`): tag/version helpers and logging bridges for ClearML Tasks.
+ - **Distributed + Metrics Utils** (`kostyl/ml_core/dist_utils.py`, `metrics_formatting.py`): world-size-aware LR scaling, rank-aware metric naming, and per-class formatting.
+ - **Logging Helpers** (`kostyl/utils/logging.py`): rank-aware Loguru setup and uniform handling of incompatible checkpoint keys.
+
+ ## Project Layout
+ ```
+ kostyl/
+   ml_core/
+     configs/      # Pydantic configs + ClearML mixins
+     lightning/    # Lightning module, callbacks, loggers, extensions
+     schedulers/   # Base + composite/cosine schedulers
+     clearml/      # Logging + pulling utilities
+   utils/          # Dict helpers, logging utilities
+ ```
File without changes
File without changes
@@ -0,0 +1,46 @@
+ import re
+
+ from kostyl.utils import setup_logger
+
+
+ logger = setup_logger(name="clearml_logging_utils.py", fmt="only_message")
+
+
+ def increment_version(s: str) -> str:
+     """
+     Increments the minor part of a version string.
+
+     Examples:
+         v1.00 -> v1.01
+         v2.99 -> v2.100
+         v.3.009 -> v.3.010
+
+     """
+     s = s.strip()
+     m = re.fullmatch(r"v(\.?)(\d+)\.(\d+)", s)
+     if not m:
+         raise ValueError(f"Invalid version format: {s!r}. Expected 'vX.Y' or 'v.X.Y'.")
+
+     vdot, major_str, minor_str = m.groups()
+     major = int(major_str)
+     minor = int(minor_str) + 1
+
+     # preserve leading zeros based on original width, length may increase (99 -> 100)
+     minor_out = str(minor).zfill(len(minor_str))
+     prefix = f"v{vdot}"  # 'v' or 'v.'
+     return f"{prefix}{major}.{minor_out}"
+
+
+ def find_version_in_tags(tags: list[str]) -> str | None:
+     """
+     Finds the first version tag in the list of tags.
+
+     Note:
+         Version tags must be in the format 'vX.Y' or 'v.X.Y' (an optional dot after 'v' is supported).
+
+     """
+     version_pattern = re.compile(r"^v(\.?)(\d+)\.(\d+)$")
+     for tag in tags:
+         if version_pattern.match(tag):
+             return tag
+     return None
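A small usage sketch for the two helpers above (not part of the package; the tag list is made up for illustration, only the `vX.Y` format matters):

```python
from kostyl.ml_core.clearml.logging_utils import find_version_in_tags, increment_version

# Hypothetical tags as they might appear on a ClearML task or model.
tags = ["production", "v1.09", "nlp"]

current = find_version_in_tags(tags)                # -> "v1.09"
if current is not None:
    bumped = increment_version(current)             # -> "v1.10" (minor width preserved)
    tags = [bumped if t == current else t for t in tags]
```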
@@ -0,0 +1,83 @@
+ from pathlib import Path
+
+ from clearml import InputModel
+ from clearml import Task
+ from transformers import AutoTokenizer
+ from transformers import PreTrainedModel
+ from transformers import PreTrainedTokenizerBase
+
+ from kostyl.ml_core.lightning.extenstions.pretrained_model import (
+     LightningCheckpointLoaderMixin,
+ )
+
+
+ def get_tokenizer_from_clearml(
+     model_id: str, task: Task | None = None, ignore_remote_overrides: bool = True
+ ) -> PreTrainedTokenizerBase:
+     """
+     Retrieve a Hugging Face tokenizer stored in a ClearML InputModel.
+
+     Args:
+         model_id (str): The ClearML InputModel identifier that holds the tokenizer artifacts.
+         task (Task | None, optional): An optional ClearML Task used to associate and sync
+             the model. Defaults to None.
+         ignore_remote_overrides (bool, optional): Whether to ignore remote hyperparameter
+             overrides when connecting the ClearML task. Defaults to True.
+
+     Returns:
+         PreTrainedTokenizerBase: The instantiated tokenizer loaded from the local copy
+             of the referenced ClearML InputModel.
+
+     """
+     clearml_tokenizer = InputModel(model_id=model_id)
+     if task is not None:
+         clearml_tokenizer.connect(task, ignore_remote_overrides=ignore_remote_overrides)
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         clearml_tokenizer.get_local_copy(raise_on_error=True)
+     )
+     return tokenizer
+
+
+ def get_model_from_clearml[TModel: PreTrainedModel | LightningCheckpointLoaderMixin](
+     model_id: str,
+     model: type[TModel],
+     task: Task | None = None,
+     ignore_remote_overrides: bool = True,
+ ) -> TModel:
+     """
+     Retrieve a pretrained model from ClearML and instantiate it using the appropriate loader.
+
+     Args:
+         model_id: Identifier of the ClearML input model to retrieve.
+         model: The model class that implements either PreTrainedModel or LightningCheckpointLoaderMixin.
+         task: Optional ClearML task used to resolve the input model reference. If provided, the input model
+             will be connected to this task, with remote overrides optionally ignored.
+         ignore_remote_overrides: When connecting the input model to the provided task, determines whether
+             remote configuration overrides should be ignored.
+
+     Returns:
+         An instantiated model loaded either from a ClearML package directory or a Lightning checkpoint.
+
+     """
+     input_model = InputModel(model_id=model_id)
+
+     if task is not None:
+         input_model.connect(task, ignore_remote_overrides=ignore_remote_overrides)
+
+     local_path = Path(input_model.get_local_copy(raise_on_error=True))
+
+     if local_path.is_dir() and input_model._is_package():
+         model_instance = model.from_pretrained(local_path)
+     elif local_path.suffix == ".ckpt":
+         if not issubclass(model, LightningCheckpointLoaderMixin):
+             raise ValueError(
+                 f"Model class {model.__name__} is not compatible with Lightning checkpoints."
+             )
+         model_instance = model.from_lighting_checkpoint(local_path)
+     else:
+         raise ValueError(
+             f"Unsupported model format for path: {local_path}. "
+             "Expected a ClearML package directory or a .ckpt file."
+         )
+     return model_instance
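A sketch of how these pulling helpers might be called from a training script — not part of the package. The ClearML IDs, project/task names, and the `BertClassifier` class (a Transformers model with the checkpoint-loader mixin applied, mirroring the sketch after the README above) are placeholders:

```python
from clearml import Task
from transformers import BertForSequenceClassification

from kostyl.ml_core.clearml.pulling_utils import get_model_from_clearml, get_tokenizer_from_clearml
from kostyl.ml_core.lightning.extenstions.pretrained_model import LightningCheckpointLoaderMixin


class BertClassifier(LightningCheckpointLoaderMixin, BertForSequenceClassification):
    """Placeholder model class; satisfies both loading branches of get_model_from_clearml."""


task = Task.init(project_name="demo", task_name="pull-artifacts")  # placeholder names

# Both IDs are placeholders for real ClearML InputModel IDs.
tokenizer = get_tokenizer_from_clearml("tokenizer-model-id", task=task)
model = get_model_from_clearml("classifier-model-id", model=BertClassifier, task=task)
```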
@@ -0,0 +1,30 @@
+ from .config_mixins import ConfigLoadingMixin
+ from .hyperparams import HyperparamsConfig
+ from .hyperparams import Lr
+ from .hyperparams import Optimizer
+ from .hyperparams import WeightDecay
+ from .training_params import CheckpointConfig
+ from .training_params import ClearMLTrainingParameters
+ from .training_params import DataConfig
+ from .training_params import DDPStrategyConfig
+ from .training_params import EarlyStoppingConfig
+ from .training_params import FSDP1StrategyConfig
+ from .training_params import SingleDeviceStrategyConfig
+ from .training_params import TrainingParams
+
+
+ __all__ = [
+     "CheckpointConfig",
+     "ClearMLTrainingParameters",
+     "ConfigLoadingMixin",
+     "DDPStrategyConfig",
+     "DataConfig",
+     "EarlyStoppingConfig",
+     "FSDP1StrategyConfig",
+     "HyperparamsConfig",
+     "Lr",
+     "Optimizer",
+     "SingleDeviceStrategyConfig",
+     "TrainingParams",
+     "WeightDecay",
+ ]
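Because of the re-exports above, the configs can be imported from the package level rather than from the individual modules. A short illustrative sketch (not part of the package; the YAML paths are placeholders):

```python
# Equivalent to importing from kostyl.ml_core.configs.hyperparams / .training_params directly.
from kostyl.ml_core.configs import HyperparamsConfig, TrainingParams

hyperparams = HyperparamsConfig.from_file("configs/hyperparams.yaml")  # placeholder path
train_cfg = TrainingParams.from_file("configs/training.yaml")          # placeholder path
```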
@@ -0,0 +1,146 @@
+ from pathlib import Path
+
+ import clearml
+ import yaml
+ from caseconverter import pascalcase
+ from caseconverter import snakecase
+ from pydantic import BaseModel
+
+ from kostyl.utils import convert_to_flat_dict
+ from kostyl.utils import flattened_dict_to_nested
+
+
+ def load_config(path: Path | str) -> dict:
+     """Load a configuration from file."""
+     if isinstance(path, str):
+         path = Path(path)
+
+     if not path.is_file():
+         raise ValueError(f"Config file {path} does not exist or is not a file.")
+
+     match path.suffix:
+         case ".yaml" | ".yml":
+             config = yaml.safe_load(path.open("r"))
+         case _:
+             raise ValueError(f"Unsupported config file format: {path.suffix}")
+     return config
+
+
+ class ConfigLoadingMixin[TConfig: ConfigLoadingMixin](BaseModel):
+     """Pydantic mixin class providing basic configuration loading functionality."""
+
+     @classmethod
+     def from_file(
+         cls: type[TConfig],
+         path: str | Path,
+     ) -> TConfig:
+         """
+         Create an instance of the class from a configuration file.
+
+         Args:
+             path (str | Path): Path to the configuration file.
+
+         Returns:
+             An instance of the class created from the configuration file.
+
+         """
+         config = load_config(path)
+         instance = cls.model_validate(config)
+         return instance
+
+     @classmethod
+     def from_dict(
+         cls: type[TConfig],
+         state_dict: dict,
+     ) -> TConfig:
+         """
+         Creates an instance from a dictionary.
+
+         Args:
+             state_dict (dict): A dictionary representing the state of the
+                 class that must be validated and used for initialization.
+
+         Returns:
+             An initialized instance of the class based on the
+                 provided state dictionary.
+
+         """
+         instance = cls.model_validate(state_dict)
+         return instance
+
+
+ class ClearMLConfigMixin[TConfig: ClearMLConfigMixin](ConfigLoadingMixin[TConfig]):
+     """Pydantic mixin class providing ClearML configuration loading and syncing functionality."""
+
+     @classmethod
+     def connect_as_file(
+         cls: type[TConfig],
+         task: clearml.Task,
+         path: str | Path,
+         alias: str | None = None,
+     ) -> TConfig:
+         """
+         Connects the configuration file to a ClearML task and creates an instance of the class from it.
+
+         This method connects the specified configuration file to the given ClearML task for version control and monitoring,
+         then loads and validates the configuration into the class.
+
+         Args:
+             cls: The class type to instantiate.
+             task: The ClearML Task object to connect the configuration to.
+             path: Path to the configuration file (supports YAML format).
+             alias: Optional alias for the configuration in ClearML. Defaults to PascalCase of the class name if None.
+
+         Returns:
+             An instance of the class created from the connected configuration file.
+
+         """
+         if isinstance(path, Path):
+             str_path = str(path)
+         else:
+             str_path = path
+
+         name = alias if alias is not None else pascalcase(cls.__name__)
+         connected_path = task.connect_configuration(str_path, name=pascalcase(name))
+
+         if not isinstance(connected_path, str):
+             connected_path_str = str(connected_path)
+         else:
+             connected_path_str = connected_path
+
+         model = cls.from_file(path=connected_path_str)
+         return model
+
+     @classmethod
+     def connect_as_dict(
+         cls: type[TConfig],
+         task: clearml.Task,
+         path: str | Path,
+         alias: str | None = None,
+     ) -> TConfig:
+         """
+         Connects configuration from a file as a dictionary to a ClearML task and creates an instance of the class.
+
+         This class method loads the configuration file as a dictionary, flattens it, and syncs it with the ClearML
+         task parameters. It then creates an instance of the class from the synced dictionary.
+
+         Args:
+             cls: The class type to instantiate (a ClearMLConfigMixin subclass).
+             task: The ClearML task to connect the configuration to.
+             path: Path to the configuration file to load parameters from.
+             alias: Optional alias name for the configuration. If None, uses snake_case of class name.
+
+         Returns:
+             An instance of the specified class created from the loaded configuration.
+
+         """
+         name = alias if alias is not None else snakecase(cls.__name__)
+
+         config = load_config(path)
+
+         flattened_config = convert_to_flat_dict(config)
+         task.connect(flattened_config, name=pascalcase(name))
+         config = flattened_dict_to_nested(flattened_config)
+
+         model = cls.from_dict(config)
+         return model
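A sketch of the two syncing entry points defined above, using `HyperparamsConfig` (defined in the next file) as the config class — not part of the package; the project/task names and YAML path are placeholders:

```python
from clearml import Task

from kostyl.ml_core.configs.hyperparams import HyperparamsConfig

task = Task.init(project_name="demo", task_name="train")  # placeholder names

# connect_as_file: attaches the raw YAML to the task via connect_configuration(),
# then parses whichever copy ClearML returns (local file or remotely overridden copy).
hyperparams = HyperparamsConfig.connect_as_file(task, "configs/hyperparams.yaml")

# connect_as_dict: flattens the YAML into key/value pairs, syncs them via task.connect(),
# and rebuilds the nested config from the (possibly overridden) flat dict.
hyperparams = HyperparamsConfig.connect_as_dict(task, "configs/hyperparams.yaml")
```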
@@ -0,0 +1,84 @@
+ from pydantic import BaseModel
+ from pydantic import Field
+ from pydantic import model_validator
+
+ from kostyl.utils.logging import setup_logger
+
+ from .config_mixins import ClearMLConfigMixin
+
+
+ logger = setup_logger(fmt="only_message")
+
+
+ class Optimizer(BaseModel):
+     """Optimizer hyperparameters configuration."""
+
+     adamw_beta1: float = 0.9
+     adamw_beta2: float = 0.999
+
+
+ class Lr(BaseModel):
+     """Learning rate hyperparameters configuration."""
+
+     use_scheduler: bool = False
+     warmup_iters_ratio: float | None = Field(
+         default=None, gt=0, lt=1, validate_default=False
+     )
+     warmup_value: float | None = Field(default=None, gt=0, validate_default=False)
+     base_value: float
+     final_value: float | None = Field(default=None, gt=0, validate_default=False)
+
+     @model_validator(mode="after")
+     def validate_warmup(self) -> "Lr":
+         """Validates the warmup parameters based on use_scheduler."""
+         if (self.warmup_value is None) != (
+             self.warmup_iters_ratio is None
+         ) and self.use_scheduler:
+             raise ValueError(
+                 "Both warmup_value and warmup_iters_ratio must be provided or neither"
+             )
+         elif (
+             (self.warmup_value is not None) or (self.warmup_iters_ratio is not None)
+         ) and (not self.use_scheduler):
+             logger.warning(
+                 "use_scheduler is False, warmup_value and warmup_iters_ratio will be ignored."
+             )
+             self.warmup_value = None
+             self.warmup_iters_ratio = None
+         return self
+
+     @model_validator(mode="after")
+     def validate_final_value(self) -> "Lr":
+         """Validates the final_value based on use_scheduler."""
+         if self.use_scheduler and (self.final_value is None):
+             raise ValueError("If use_scheduler is True, final_value must be provided.")
+         if (not self.use_scheduler) and (self.final_value is not None):
+             logger.warning("use_scheduler is False, final_value will be ignored.")
+             self.final_value = None
+         return self
+
+
+ class WeightDecay(BaseModel):
+     """Weight decay hyperparameters configuration."""
+
+     use_scheduler: bool = False
+     base_value: float
+     final_value: float | None = None
+
+     @model_validator(mode="after")
+     def validate_final_value(self) -> "WeightDecay":
+         """Validates the final_value based on use_scheduler."""
+         if self.use_scheduler and self.final_value is None:
+             raise ValueError("If use_scheduler is True, final_value must be provided.")
+         if not self.use_scheduler and self.final_value is not None:
+             logger.warning("use_scheduler is False, final_value will be ignored.")
+         return self
+
+
+ class HyperparamsConfig(ClearMLConfigMixin["HyperparamsConfig"]):
+     """Model training hyperparameters configuration."""
+
+     grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
+     optimizer: Optimizer = Optimizer()
+     lr: Lr
+     weight_decay: WeightDecay
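To illustrate the validators above, a small construction sketch (not part of the package; all values are arbitrary):

```python
from kostyl.ml_core.configs.hyperparams import HyperparamsConfig, Lr, Optimizer, WeightDecay

# With use_scheduler=True, final_value is required and the two warmup fields
# must be provided together (or both left unset).
lr = Lr(
    use_scheduler=True,
    base_value=3e-4,
    final_value=3e-5,
    warmup_iters_ratio=0.05,
    warmup_value=1e-6,
)

# With use_scheduler=False, a provided final_value would only trigger a warning and be ignored.
wd = WeightDecay(use_scheduler=False, base_value=0.01)

config = HyperparamsConfig(
    grad_clip_val=1.0,
    optimizer=Optimizer(),  # AdamW betas default to (0.9, 0.999)
    lr=lr,
    weight_decay=wd,
)
```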