kostyl-toolkit 0.1.15__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. kostyl_toolkit-0.1.15/PKG-INFO +105 -0
  2. kostyl_toolkit-0.1.15/README.md +89 -0
  3. kostyl_toolkit-0.1.15/kostyl/__init__.py +0 -0
  4. kostyl_toolkit-0.1.15/kostyl/ml/__init__.py +0 -0
  5. kostyl_toolkit-0.1.15/kostyl/ml/clearml/__init__.py +0 -0
  6. kostyl_toolkit-0.1.15/kostyl/ml/clearml/dataset_utils.py +58 -0
  7. kostyl_toolkit-0.1.15/kostyl/ml/clearml/logging_utils.py +46 -0
  8. kostyl_toolkit-0.1.15/kostyl/ml/clearml/pulling_utils.py +91 -0
  9. kostyl_toolkit-0.1.15/kostyl/ml/configs/__init__.py +30 -0
  10. kostyl_toolkit-0.1.15/kostyl/ml/configs/base_model.py +143 -0
  11. kostyl_toolkit-0.1.15/kostyl/ml/configs/hyperparams.py +84 -0
  12. kostyl_toolkit-0.1.15/kostyl/ml/configs/training_settings.py +104 -0
  13. kostyl_toolkit-0.1.15/kostyl/ml/dist_utils.py +99 -0
  14. kostyl_toolkit-0.1.15/kostyl/ml/lightning/__init__.py +5 -0
  15. kostyl_toolkit-0.1.15/kostyl/ml/lightning/callbacks/__init__.py +10 -0
  16. kostyl_toolkit-0.1.15/kostyl/ml/lightning/callbacks/checkpoint.py +56 -0
  17. kostyl_toolkit-0.1.15/kostyl/ml/lightning/callbacks/early_stopping.py +18 -0
  18. kostyl_toolkit-0.1.15/kostyl/ml/lightning/callbacks/registry_uploading.py +118 -0
  19. kostyl_toolkit-0.1.15/kostyl/ml/lightning/extenstions/__init__.py +5 -0
  20. kostyl_toolkit-0.1.15/kostyl/ml/lightning/extenstions/custom_module.py +175 -0
  21. kostyl_toolkit-0.1.15/kostyl/ml/lightning/extenstions/pretrained_model.py +125 -0
  22. kostyl_toolkit-0.1.15/kostyl/ml/lightning/loggers/__init__.py +4 -0
  23. kostyl_toolkit-0.1.15/kostyl/ml/lightning/loggers/tb_logger.py +31 -0
  24. kostyl_toolkit-0.1.15/kostyl/ml/lightning/steps_estimation.py +44 -0
  25. kostyl_toolkit-0.1.15/kostyl/ml/metrics_formatting.py +41 -0
  26. kostyl_toolkit-0.1.15/kostyl/ml/params_groups.py +93 -0
  27. kostyl_toolkit-0.1.15/kostyl/ml/schedulers/__init__.py +6 -0
  28. kostyl_toolkit-0.1.15/kostyl/ml/schedulers/base.py +48 -0
  29. kostyl_toolkit-0.1.15/kostyl/ml/schedulers/composite.py +68 -0
  30. kostyl_toolkit-0.1.15/kostyl/ml/schedulers/cosine.py +219 -0
  31. kostyl_toolkit-0.1.15/kostyl/utils/__init__.py +10 -0
  32. kostyl_toolkit-0.1.15/kostyl/utils/dict_manipulations.py +40 -0
  33. kostyl_toolkit-0.1.15/kostyl/utils/fs.py +19 -0
  34. kostyl_toolkit-0.1.15/kostyl/utils/logging.py +177 -0
  35. kostyl_toolkit-0.1.15/pyproject.toml +124 -0
@@ -0,0 +1,105 @@
+ Metadata-Version: 2.3
+ Name: kostyl-toolkit
+ Version: 0.1.15
+ Summary: Kickass Orchestration System for Training, Yielding & Logging
+ Requires-Dist: case-converter>=1.2.0
+ Requires-Dist: loguru>=0.7.3
+ Requires-Dist: case-converter>=1.2.0 ; extra == 'ml-core'
+ Requires-Dist: clearml[s3]>=2.0.2 ; extra == 'ml-core'
+ Requires-Dist: lightning>=2.5.6 ; extra == 'ml-core'
+ Requires-Dist: pydantic>=2.12.4 ; extra == 'ml-core'
+ Requires-Dist: torch>=2.9.1 ; extra == 'ml-core'
+ Requires-Dist: transformers>=4.57.1 ; extra == 'ml-core'
+ Requires-Python: >=3.12
+ Provides-Extra: ml-core
+ Description-Content-Type: text/markdown
+
+ # Kostyl Toolkit
+
+ Kickass Orchestration System for Training, Yielding & Logging — a batteries-included toolbox that glues PyTorch Lightning, Hugging Face Transformers, and ClearML into a single workflow.
+
+ ## Overview
+ - Rapidly bootstrap Lightning experiments with opinionated defaults (`KostylLightningModule`, custom schedulers, grad clipping, and metric formatting).
+ - Keep model configs source-controlled via Pydantic mixins, with ClearML syncing out of the box (`BaseModelWithConfigLoading`, `BaseModelWithClearmlSyncing`).
+ - Reuse Lightning checkpoints directly inside Transformers models through `LightningCheckpointLoaderMixin`.
+ - Ship distributed-friendly utilities (deterministic logging, FSDP helpers, LR scaling, ClearML tag management).
+
+ ## Installation
+ ```bash
+ # Latest release from PyPI
+ pip install kostyl-toolkit
+
+ # or with uv
+ uv pip install kostyl-toolkit
+ ```
+
+ Development setup:
+ ```bash
+ uv sync              # creates the virtualenv declared in pyproject.toml
+ source .venv/bin/activate.fish
+ pre-commit install   # optional but recommended
+ ```
+
+ ## Quick Start
+ ```python
+ from lightning import Trainer
+ from transformers import AutoModelForSequenceClassification
+
+ from kostyl.ml.configs.hyperparams import HyperparamsConfig
+ from kostyl.ml.configs.training_settings import TrainingSettings
+ from kostyl.ml.lightning.extenstions.custom_module import KostylLightningModule
+
+
+ class TextClassifier(KostylLightningModule):
+     def __init__(self, hyperparams: HyperparamsConfig):
+         super().__init__()
+         self.hyperparams = hyperparams  # grad clipping + scheduler knobs
+         self.model = AutoModelForSequenceClassification.from_pretrained(
+             "distilbert-base-uncased",
+             num_labels=2,
+         )
+
+     def training_step(self, batch, batch_idx):
+         outputs = self.model(**batch)
+         self.log("train/loss", outputs.loss)
+         return outputs.loss
+
+
+ train_cfg = TrainingSettings.from_file("configs/training.yaml")
+ hyperparams = HyperparamsConfig.from_file("configs/hyperparams.yaml")
+
+ module = TextClassifier(hyperparams)
+
+ trainer = Trainer(**train_cfg.trainer.model_dump())
+ trainer.fit(module)
+ ```
+
+ Restoring a plain Transformers model from a Lightning checkpoint:
+ ```python
+ from kostyl.ml.lightning.extenstions.pretrained_model import LightningCheckpointLoaderMixin
+
+
+ model = LightningCheckpointLoaderMixin.from_lighting_checkpoint(
+     "checkpoints/epoch=03-step=500.ckpt",
+     config_key="config",
+     weights_prefix="model.",
+ )
+ ```
+
+ ## Components
+ - **Configurations** (`kostyl/ml/configs`): strongly typed training, optimizer, and scheduler configs with ClearML syncing helpers.
+ - **Lightning Extensions** (`kostyl/ml/lightning`): custom LightningModule base class, callbacks, logging bridges, and the checkpoint loader mixin.
+ - **Schedulers** (`kostyl/ml/schedulers`): extensible LR schedulers (base/composite/cosine) with serialization helpers and on-step logging.
+ - **ClearML Utilities** (`kostyl/ml/clearml`): tag/version helpers and logging bridges for ClearML Tasks.
+ - **Distributed + Metrics Utils** (`kostyl/ml/dist_utils.py`, `metrics_formatting.py`): world-size-aware LR scaling, rank-aware metric naming, and per-class formatting.
+ - **Logging Helpers** (`kostyl/utils/logging.py`): rank-aware Loguru setup and uniform handling of incompatible checkpoint keys.
+
+ ## Project Layout
+ ```
+ kostyl/
+   ml/
+     configs/      # Pydantic configs + ClearML mixins
+     lightning/    # Lightning module, callbacks, loggers, extensions
+     schedulers/   # Base + composite/cosine schedulers
+     clearml/      # Logging + pulling utilities
+   utils/          # Dict helpers, logging utilities
+ ```
@@ -0,0 +1,89 @@
+ # Kostyl Toolkit
+
+ Kickass Orchestration System for Training, Yielding & Logging — a batteries-included toolbox that glues PyTorch Lightning, Hugging Face Transformers, and ClearML into a single workflow.
+
+ ## Overview
+ - Rapidly bootstrap Lightning experiments with opinionated defaults (`KostylLightningModule`, custom schedulers, grad clipping, and metric formatting).
+ - Keep model configs source-controlled via Pydantic mixins, with ClearML syncing out of the box (`BaseModelWithConfigLoading`, `BaseModelWithClearmlSyncing`).
+ - Reuse Lightning checkpoints directly inside Transformers models through `LightningCheckpointLoaderMixin`.
+ - Ship distributed-friendly utilities (deterministic logging, FSDP helpers, LR scaling, ClearML tag management).
+
+ ## Installation
+ ```bash
+ # Latest release from PyPI
+ pip install kostyl-toolkit
+
+ # or with uv
+ uv pip install kostyl-toolkit
+ ```
+
+ Development setup:
+ ```bash
+ uv sync              # creates the virtualenv declared in pyproject.toml
+ source .venv/bin/activate.fish
+ pre-commit install   # optional but recommended
+ ```
+
+ ## Quick Start
+ ```python
+ from lightning import Trainer
+ from transformers import AutoModelForSequenceClassification
+
+ from kostyl.ml.configs.hyperparams import HyperparamsConfig
+ from kostyl.ml.configs.training_settings import TrainingSettings
+ from kostyl.ml.lightning.extenstions.custom_module import KostylLightningModule
+
+
+ class TextClassifier(KostylLightningModule):
+     def __init__(self, hyperparams: HyperparamsConfig):
+         super().__init__()
+         self.hyperparams = hyperparams  # grad clipping + scheduler knobs
+         self.model = AutoModelForSequenceClassification.from_pretrained(
+             "distilbert-base-uncased",
+             num_labels=2,
+         )
+
+     def training_step(self, batch, batch_idx):
+         outputs = self.model(**batch)
+         self.log("train/loss", outputs.loss)
+         return outputs.loss
+
+
+ train_cfg = TrainingSettings.from_file("configs/training.yaml")
+ hyperparams = HyperparamsConfig.from_file("configs/hyperparams.yaml")
+
+ module = TextClassifier(hyperparams)
+
+ trainer = Trainer(**train_cfg.trainer.model_dump())
+ trainer.fit(module)
+ ```
+
+ Restoring a plain Transformers model from a Lightning checkpoint:
+ ```python
+ from kostyl.ml.lightning.extenstions.pretrained_model import LightningCheckpointLoaderMixin
+
+
+ model = LightningCheckpointLoaderMixin.from_lighting_checkpoint(
+     "checkpoints/epoch=03-step=500.ckpt",
+     config_key="config",
+     weights_prefix="model.",
+ )
+ ```
+
+ ## Components
+ - **Configurations** (`kostyl/ml/configs`): strongly typed training, optimizer, and scheduler configs with ClearML syncing helpers.
+ - **Lightning Extensions** (`kostyl/ml/lightning`): custom LightningModule base class, callbacks, logging bridges, and the checkpoint loader mixin.
+ - **Schedulers** (`kostyl/ml/schedulers`): extensible LR schedulers (base/composite/cosine) with serialization helpers and on-step logging.
+ - **ClearML Utilities** (`kostyl/ml/clearml`): tag/version helpers and logging bridges for ClearML Tasks.
+ - **Distributed + Metrics Utils** (`kostyl/ml/dist_utils.py`, `metrics_formatting.py`): world-size-aware LR scaling, rank-aware metric naming, and per-class formatting.
+ - **Logging Helpers** (`kostyl/utils/logging.py`): rank-aware Loguru setup and uniform handling of incompatible checkpoint keys.
+
+ ## Project Layout
+ ```
+ kostyl/
+   ml/
+     configs/      # Pydantic configs + ClearML mixins
+     lightning/    # Lightning module, callbacks, loggers, extensions
+     schedulers/   # Base + composite/cosine schedulers
+     clearml/      # Logging + pulling utilities
+   utils/          # Dict helpers, logging utilities
+ ```
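The `configs/*.yaml` files referenced in the Quick Start mirror the Pydantic models under `kostyl/ml/configs`. As a rough illustration (field names are taken from `HyperparamsConfig` in this release; the values are placeholders), the same hyperparameters can also be built directly in Python:

```python
from kostyl.ml.configs import HyperparamsConfig, Lr, Optimizer, WeightDecay

# Equivalent to a minimal configs/hyperparams.yaml; values are illustrative only.
hyperparams = HyperparamsConfig(
    grad_clip_val=1.0,
    optimizer=Optimizer(adamw_beta1=0.9, adamw_beta2=0.999),
    lr=Lr(base_value=3e-4),
    weight_decay=WeightDecay(base_value=0.01),
)
```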
File without changes
File without changes
File without changes
@@ -0,0 +1,58 @@
+ from collections.abc import Collection
+ from concurrent.futures import ThreadPoolExecutor
+ from concurrent.futures import as_completed
+ from pathlib import Path
+
+ from clearml import Dataset as ClearMLDataset
+
+
+ def collect_clearml_datasets(
+     datasets_mapping: dict[str, str],
+ ) -> dict[str, ClearMLDataset]:
+     """
+     Collect ClearML datasets by dataset ID.
+
+     Args:
+         datasets_mapping: Mapping where keys are human-readable names and values
+             are ClearML dataset IDs.
+
+     Returns:
+         A mapping of dataset names to fetched `ClearMLDataset` instances.
+
+     """
+     datasets_list = {}
+     for name, dataset_id in datasets_mapping.items():
+         clearml_dataset = ClearMLDataset.get(dataset_id, alias=name)
+         datasets_list[name] = clearml_dataset
+     return datasets_list
+
+
+ def download_clearml_datasets(datasets: Collection[ClearMLDataset]) -> None:
+     """
+     Download all ClearML datasets in parallel.
+
+     Args:
+         datasets: Collection of initialized `ClearMLDataset` instances to download.
+
+     """
+     with ThreadPoolExecutor() as executor:
+         futures = [executor.submit(ds.get_local_copy) for ds in datasets]
+         for future in as_completed(futures):
+             future.result()
+     return
+
+
+ def get_datasets_paths(datasets_mapping: dict[str, ClearMLDataset]) -> dict[str, Path]:
+     """
+     Return local filesystem paths for ClearML datasets.
+
+     Args:
+         datasets_mapping: Mapping of dataset names to initialized
+             `ClearMLDataset` instances.
+
+     Returns:
+         Mapping of dataset names to local `Path` objects pointing to the
+         downloaded dataset copies.
+
+     """
+     return {name: Path(ds.get_local_copy()) for name, ds in datasets_mapping.items()}
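Taken together, the three helpers above turn a name-to-ID mapping into local dataset paths. A minimal usage sketch; the dataset names and IDs are placeholders:

```python
from kostyl.ml.clearml.dataset_utils import (
    collect_clearml_datasets,
    download_clearml_datasets,
    get_datasets_paths,
)

# Placeholder ClearML dataset IDs; substitute the IDs of your registered datasets.
dataset_ids = {"train": "<clearml-dataset-id>", "val": "<clearml-dataset-id>"}

datasets = collect_clearml_datasets(dataset_ids)   # name -> ClearMLDataset
download_clearml_datasets(datasets.values())       # fetch local copies in parallel
paths = get_datasets_paths(datasets)               # name -> Path to the local copy
```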
@@ -0,0 +1,46 @@
+ import re
+
+ from kostyl.utils import setup_logger
+
+
+ logger = setup_logger(name="clearml_logging_utils.py", fmt="only_message")
+
+
+ def increment_version(s: str) -> str:
+     """
+     Increments the minor part of a version string.
+
+     Examples:
+         v1.00 -> v1.01
+         v2.99 -> v2.100
+         v.3.009 -> v.3.010
+
+     """
+     s = s.strip()
+     m = re.fullmatch(r"v(\.?)(\d+)\.(\d+)", s)
+     if not m:
+         raise ValueError(f"Invalid version format: {s!r}. Expected 'vX.Y' or 'v.X.Y'.")
+
+     vdot, major_str, minor_str = m.groups()
+     major = int(major_str)
+     minor = int(minor_str) + 1
+
+     # preserve leading zeros based on original width; length may increase (99 -> 100)
+     minor_out = str(minor).zfill(len(minor_str))
+     prefix = f"v{vdot}"  # 'v' or 'v.'
+     return f"{prefix}{major}.{minor_out}"
+
+
+ def find_version_in_tags(tags: list[str]) -> str | None:
+     """
+     Finds the first version tag in the list of tags.
+
+     Note:
+         Version tags must be in the format 'vX.Y' or 'v.X.Y' (an optional dot after 'v' is supported).
+
+     """
+     version_pattern = re.compile(r"^v(\.?)(\d+)\.(\d+)$")
+     for tag in tags:
+         if version_pattern.match(tag):
+             return tag
+     return None
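A typical flow combines the two helpers: find the current version tag on a ClearML object, then bump it. A short sketch with a placeholder tag list:

```python
from kostyl.ml.clearml.logging_utils import find_version_in_tags, increment_version

tags = ["baseline", "v1.07"]  # placeholder tags, e.g. read from a ClearML Task or Model

current = find_version_in_tags(tags)                                # -> "v1.07"
next_version = increment_version(current) if current else "v1.00"  # -> "v1.08"
```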
@@ -0,0 +1,91 @@
+ from pathlib import Path
+ from typing import Any
+ from typing import cast
+
+ from clearml import InputModel
+ from clearml import Task
+ from transformers import AutoModel
+ from transformers import AutoTokenizer
+ from transformers import PreTrainedModel
+ from transformers import PreTrainedTokenizerBase
+
+ from kostyl.ml.lightning.extenstions.pretrained_model import (
+     LightningCheckpointLoaderMixin,
+ )
+
+
+ def get_tokenizer_from_clearml(
+     model_id: str, task: Task | None = None, ignore_remote_overrides: bool = True
+ ) -> PreTrainedTokenizerBase:
+     """
+     Retrieve a Hugging Face tokenizer stored in a ClearML InputModel.
+
+     Args:
+         model_id (str): The ClearML InputModel identifier that holds the tokenizer artifacts.
+         task (Task | None, optional): An optional ClearML Task used to associate and sync
+             the model. Defaults to None.
+         ignore_remote_overrides (bool, optional): Whether to ignore remote hyperparameter
+             overrides when connecting the ClearML task. Defaults to True.
+
+     Returns:
+         PreTrainedTokenizerBase: The instantiated tokenizer loaded from the local copy
+         of the referenced ClearML InputModel.
+
+     """
+     clearml_tokenizer = InputModel(model_id=model_id)
+     if task is not None:
+         clearml_tokenizer.connect(task, ignore_remote_overrides=ignore_remote_overrides)
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         clearml_tokenizer.get_local_copy(raise_on_error=True)
+     )
+     return tokenizer
+
+
+ def get_model_from_clearml[
+     TModel: PreTrainedModel | LightningCheckpointLoaderMixin | AutoModel
+ ](
+     model_id: str,
+     model: type[TModel],
+     task: Task | None = None,
+     ignore_remote_overrides: bool = True,
+     **kwargs: Any,
+ ) -> TModel:
+     """
+     Retrieve a pretrained model from ClearML and instantiate it using the appropriate loader.
+
+     Args:
+         model_id: Identifier of the ClearML input model to retrieve.
+         model: The model class that implements either PreTrainedModel or LightningCheckpointLoaderMixin.
+         task: Optional ClearML task used to resolve the input model reference. If provided, the input model
+             will be connected to this task, with remote overrides optionally ignored.
+         ignore_remote_overrides: When connecting the input model to the provided task, determines whether
+             remote configuration overrides should be ignored.
+         **kwargs: Additional keyword arguments to pass to the model loading method.
+
+     Returns:
+         An instantiated model loaded either from a ClearML package directory or a Lightning checkpoint.
+
+     """
+     input_model = InputModel(model_id=model_id)
+
+     if task is not None:
+         input_model.connect(task, ignore_remote_overrides=ignore_remote_overrides)
+
+     local_path = Path(input_model.get_local_copy(raise_on_error=True))
+
+     if local_path.is_dir() and input_model._is_package():
+         model_instance = model.from_pretrained(local_path, **kwargs)
+     elif local_path.suffix == ".ckpt":
+         if not issubclass(model, LightningCheckpointLoaderMixin):
+             raise ValueError(
+                 f"Model class {model.__name__} is not compatible with Lightning checkpoints."
+             )
+         model_instance = model.from_lighting_checkpoint(local_path, **kwargs)
+     else:
+         raise ValueError(
+             f"Unsupported model format for path: {local_path}. "
+             "Expected a ClearML package directory or a .ckpt file."
+         )
+     model_instance = cast(TModel, model_instance)
+     return model_instance
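A usage sketch for the two pulling helpers; the model IDs are placeholders, and `AutoModel` stands in for any `PreTrainedModel` subclass (a `LightningCheckpointLoaderMixin` subclass would be required if the registered artifact is a `.ckpt` file):

```python
from transformers import AutoModel

from kostyl.ml.clearml.pulling_utils import (
    get_model_from_clearml,
    get_tokenizer_from_clearml,
)

# Placeholder ClearML model IDs; substitute the IDs of your registered artifacts.
tokenizer = get_tokenizer_from_clearml(model_id="<tokenizer-model-id>")
model = get_model_from_clearml(model_id="<model-id>", model=AutoModel)
```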
@@ -0,0 +1,30 @@
+ from .base_model import KostylBaseModel
+ from .hyperparams import HyperparamsConfig
+ from .hyperparams import Lr
+ from .hyperparams import Optimizer
+ from .hyperparams import WeightDecay
+ from .training_settings import CheckpointConfig
+ from .training_settings import DataConfig
+ from .training_settings import DDPStrategyConfig
+ from .training_settings import EarlyStoppingConfig
+ from .training_settings import FSDP1StrategyConfig
+ from .training_settings import LightningTrainerParameters
+ from .training_settings import SingleDeviceStrategyConfig
+ from .training_settings import TrainingSettings
+
+
+ __all__ = [
+     "CheckpointConfig",
+     "DDPStrategyConfig",
+     "DataConfig",
+     "EarlyStoppingConfig",
+     "FSDP1StrategyConfig",
+     "HyperparamsConfig",
+     "KostylBaseModel",
+     "LightningTrainerParameters",
+     "Lr",
+     "Optimizer",
+     "SingleDeviceStrategyConfig",
+     "TrainingSettings",
+     "WeightDecay",
+ ]
@@ -0,0 +1,143 @@
+ from pathlib import Path
+ from typing import Self
+ from typing import TypeVar
+
+ from caseconverter import pascalcase
+ from caseconverter import snakecase
+ from clearml import Task
+ from pydantic import BaseModel as PydanticBaseModel
+
+ from kostyl.utils.dict_manipulations import convert_to_flat_dict
+ from kostyl.utils.dict_manipulations import flattened_dict_to_nested
+ from kostyl.utils.fs import load_config
+
+
+ TConfig = TypeVar("TConfig", bound=PydanticBaseModel)
+
+
+ class BaseModelWithConfigLoading(PydanticBaseModel):
+     """Pydantic class providing basic configuration loading functionality."""
+
+     @classmethod
+     def from_file(
+         cls: type[Self],  # pyright: ignore
+         path: str | Path,
+     ) -> Self:
+         """
+         Create an instance of the class from a configuration file.
+
+         Args:
+             cls: The class type to instantiate.
+             path (str | Path): Path to the configuration file.
+
+         Returns:
+             An instance of the class created from the configuration file.
+
+         """
+         config = load_config(path)
+         instance = cls.model_validate(config)
+         return instance
+
+     @classmethod
+     def from_dict(
+         cls: type[Self],  # pyright: ignore
+         state_dict: dict,
+     ) -> Self:
+         """
+         Creates an instance from a dictionary.
+
+         Args:
+             cls: The class type to instantiate.
+             state_dict (dict): A dictionary representing the state of the
+                 class that must be validated and used for initialization.
+
+         Returns:
+             An initialized instance of the class based on the
+             provided state dictionary.
+
+         """
+         instance = cls.model_validate(state_dict)
+         return instance
+
+
+ class BaseModelWithClearmlSyncing(BaseModelWithConfigLoading):
+     """Pydantic class providing ClearML configuration loading and syncing functionality."""
+
+     @classmethod
+     def connect_as_file(
+         cls: type[Self],  # pyright: ignore
+         task: Task,
+         path: str | Path,
+         alias: str | None = None,
+     ) -> Self:
+         """
+         Connects the configuration file to a ClearML task and creates an instance of the class from it.
+
+         This method connects the specified configuration file to the given ClearML task for version control and monitoring,
+         then loads and validates the configuration into the class.
+
+         Args:
+             cls: The class type to instantiate.
+             task: The ClearML Task object to connect the configuration to.
+             path: Path to the configuration file (supports YAML format).
+             alias: Optional alias for the configuration in ClearML. Defaults to PascalCase of the class name if None.
+
+         Returns:
+             An instance of the class created from the connected configuration file.
+
+         """
+         if isinstance(path, Path):
+             str_path = str(path)
+         else:
+             str_path = path
+
+         name = alias if alias is not None else pascalcase(cls.__name__)
+         connected_path = task.connect_configuration(str_path, name=pascalcase(name))
+
+         if not isinstance(connected_path, str):
+             connected_path_str = str(connected_path)
+         else:
+             connected_path_str = connected_path
+
+         model = cls.from_file(path=connected_path_str)
+         return model
+
+     @classmethod
+     def connect_as_dict(
+         cls: type[Self],  # pyright: ignore
+         task: Task,
+         path: str | Path,
+         alias: str | None = None,
+     ) -> Self:
+         """
+         Connects configuration from a file as a dictionary to a ClearML task and creates an instance of the class.
+
+         This class method loads the configuration from a file as a dictionary, flattens it, and syncs it with the
+         ClearML task parameters. It then creates an instance of the class from the synced dictionary.
+
+         Args:
+             cls: The class type of the model to be created (must be a BaseModelWithClearmlSyncing subclass).
+             task: The ClearML task to connect the configuration to.
+             path: Path to the configuration file to load parameters from.
+             alias: Optional alias name for the configuration. If None, uses snake_case of the class name.
+
+         Returns:
+             An instance of the specified class created from the loaded configuration.
+
+         """
+         name = alias if alias is not None else snakecase(cls.__name__)
+
+         config = load_config(path)
+
+         flattened_config = convert_to_flat_dict(config)
+         task.connect(flattened_config, name=pascalcase(name))
+         config = flattened_dict_to_nested(flattened_config)
+
+         model = cls.from_dict(state_dict=config)
+         return model
+
+
+ class KostylBaseModel(BaseModelWithClearmlSyncing):
+     """A Pydantic model class with basic configuration loading functionality."""
+
+     pass
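Concrete configs such as `HyperparamsConfig` inherit these loaders, so a config can be read locally or synced with a ClearML task in one call. A minimal sketch, assuming a `configs/hyperparams.yaml` exists and a ClearML task has been initialized (both are placeholders):

```python
from clearml import Task

from kostyl.ml.configs import HyperparamsConfig

task = Task.init(project_name="demo", task_name="train")  # placeholder project/task names

# Plain local loading, no ClearML involved.
hparams = HyperparamsConfig.from_file("configs/hyperparams.yaml")

# Load the same file, flatten it, and sync it with the task's hyperparameters.
hparams_synced = HyperparamsConfig.connect_as_dict(task=task, path="configs/hyperparams.yaml")
```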
@@ -0,0 +1,84 @@
+ from pydantic import BaseModel
+ from pydantic import Field
+ from pydantic import model_validator
+
+ from kostyl.utils.logging import setup_logger
+
+ from .base_model import KostylBaseModel
+
+
+ logger = setup_logger(fmt="only_message")
+
+
+ class Optimizer(BaseModel):
+     """Optimizer hyperparameters configuration."""
+
+     adamw_beta1: float = 0.9
+     adamw_beta2: float = 0.999
+
+
+ class Lr(BaseModel):
+     """Learning rate hyperparameters configuration."""
+
+     use_scheduler: bool = False
+     warmup_iters_ratio: float | None = Field(
+         default=None, gt=0, lt=1, validate_default=False
+     )
+     warmup_value: float | None = Field(default=None, gt=0, validate_default=False)
+     base_value: float
+     final_value: float | None = Field(default=None, gt=0, validate_default=False)
+
+     @model_validator(mode="after")
+     def validate_warmup(self) -> "Lr":
+         """Validates the warmup parameters based on use_scheduler."""
+         if (self.warmup_value is None) != (
+             self.warmup_iters_ratio is None
+         ) and self.use_scheduler:
+             raise ValueError(
+                 "Both warmup_value and warmup_iters_ratio must be provided or neither"
+             )
+         elif (
+             (self.warmup_value is not None) or (self.warmup_iters_ratio is not None)
+         ) and (not self.use_scheduler):
+             logger.warning(
+                 "use_scheduler is False, warmup_value and warmup_iters_ratio will be ignored."
+             )
+             self.warmup_value = None
+             self.warmup_iters_ratio = None
+         return self
+
+     @model_validator(mode="after")
+     def validate_final_value(self) -> "Lr":
+         """Validates the final_value based on use_scheduler."""
+         if self.use_scheduler and (self.final_value is None):
+             raise ValueError("If use_scheduler is True, final_value must be provided.")
+         if (not self.use_scheduler) and (self.final_value is not None):
+             logger.warning("use_scheduler is False, final_value will be ignored.")
+             self.final_value = None
+         return self
+
+
+ class WeightDecay(BaseModel):
+     """Weight decay hyperparameters configuration."""
+
+     use_scheduler: bool = False
+     base_value: float
+     final_value: float | None = None
+
+     @model_validator(mode="after")
+     def validate_final_value(self) -> "WeightDecay":
+         """Validates the final_value based on use_scheduler."""
+         if self.use_scheduler and self.final_value is None:
+             raise ValueError("If use_scheduler is True, final_value must be provided.")
+         if not self.use_scheduler and self.final_value is not None:
+             logger.warning("use_scheduler is False, final_value will be ignored.")
+         return self
+
+
+ class HyperparamsConfig(KostylBaseModel):
+     """Model training hyperparameters configuration."""
+
+     grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
+     optimizer: Optimizer = Optimizer()
+     lr: Lr
+     weight_decay: WeightDecay
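The validators above make the scheduler-related fields meaningful only when `use_scheduler` is set. A short sketch with placeholder values:

```python
from kostyl.ml.configs.hyperparams import Lr, WeightDecay

# No scheduler: only base_value is required; stray final/warmup values are reset with a warning.
lr_plain = Lr(base_value=3e-4)

# Scheduler enabled: final_value is mandatory, and warmup fields must be given together or not at all.
lr_scheduled = Lr(
    use_scheduler=True,
    base_value=3e-4,
    final_value=3e-5,
    warmup_value=1e-6,
    warmup_iters_ratio=0.05,
)

# Omitting final_value while use_scheduler=True would raise a ValueError.
wd = WeightDecay(use_scheduler=True, base_value=0.05, final_value=0.005)
```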