kostyl-toolkit 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kostyl_toolkit-0.1.0/PKG-INFO +102 -0
- kostyl_toolkit-0.1.0/README.md +89 -0
- kostyl_toolkit-0.1.0/kostyl/__init__.py +0 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/__init__.py +0 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/clearml/__init__.py +0 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/clearml/logging_utils.py +46 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/clearml/pulling_utils.py +83 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/configs/__init__.py +30 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/configs/config_mixins.py +146 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/configs/hyperparams.py +84 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/configs/training_params.py +110 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/dist_utils.py +72 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/__init__.py +5 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/callbacks/__init__.py +10 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/callbacks/checkpoint.py +56 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/callbacks/early_stopping.py +18 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/callbacks/registry_uploading.py +126 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/extenstions/__init__.py +5 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/extenstions/custom_module.py +179 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/extenstions/pretrained_model.py +115 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/loggers/__init__.py +0 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/loggers/tb_logger.py +31 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/lightning/steps_estimation.py +44 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/metrics_formatting.py +41 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/params_groups.py +50 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/schedulers/__init__.py +6 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/schedulers/base.py +48 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/schedulers/composite.py +68 -0
- kostyl_toolkit-0.1.0/kostyl/ml_core/schedulers/cosine.py +219 -0
- kostyl_toolkit-0.1.0/kostyl/utils/__init__.py +10 -0
- kostyl_toolkit-0.1.0/kostyl/utils/dict_manipulations.py +40 -0
- kostyl_toolkit-0.1.0/kostyl/utils/logging.py +147 -0
- kostyl_toolkit-0.1.0/pyproject.toml +105 -0
kostyl_toolkit-0.1.0/PKG-INFO
@@ -0,0 +1,102 @@
Metadata-Version: 2.3
Name: kostyl-toolkit
Version: 0.1.0
Summary: Kickass Orchestration System for Training, Yielding & Logging
Requires-Dist: case-converter>=1.2.0
Requires-Dist: clearml[s3]>=2.0.2
Requires-Dist: lightning>=2.5.6
Requires-Dist: loguru>=0.7.3
Requires-Dist: pydantic>=2.12.4
Requires-Dist: transformers>=4.57.1
Requires-Python: >=3.12
Description-Content-Type: text/markdown

# Kostyl Toolkit

Kickass Orchestration System for Training, Yielding & Logging — a batteries-included toolbox that glues PyTorch Lightning, Hugging Face Transformers, and ClearML into a single workflow.

## Overview
- Rapidly bootstrap Lightning experiments with opinionated defaults (`KostylLightningModule`, custom schedulers, grad clipping and metric formatting).
- Keep model configs source-controlled via Pydantic mixins, with ClearML syncing out of the box (`ConfigLoadingMixin`, `ClearMLConfigMixin`).
- Reuse Lightning checkpoints directly inside Transformers models through `LightningCheckpointLoaderMixin`.
- Ship distributed-friendly utilities (deterministic logging, FSDP helpers, LR scaling, ClearML tag management).

## Installation
```bash
# Latest release from PyPI
pip install kostyl-toolkit

# or with uv
uv pip install kostyl-toolkit
```

Development setup:
```bash
uv sync              # creates the virtualenv declared in pyproject.toml
source .venv/bin/activate.fish
pre-commit install   # optional but recommended
```

## Quick Start
```python
from lightning import Trainer
from transformers import AutoModelForSequenceClassification

from kostyl.ml_core.configs.hyperparams import HyperparamsConfig
from kostyl.ml_core.configs.training_params import TrainingParams
from kostyl.ml_core.lightning.extenstions.custom_module import KostylLightningModule


class TextClassifier(KostylLightningModule):
    def __init__(self, hyperparams: HyperparamsConfig):
        super().__init__()
        self.hyperparams = hyperparams  # grad clipping + scheduler knobs
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "distilbert-base-uncased",
            num_labels=2,
        )

    def training_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        self.log("train/loss", outputs.loss)
        return outputs.loss

train_cfg = TrainingParams.from_file("configs/training.yaml")
hyperparams = HyperparamsConfig.from_file("configs/hyperparams.yaml")

module = TextClassifier(hyperparams)

trainer = Trainer(**train_cfg.trainer.model_dump())
trainer.fit(module)
```

Restoring a plain Transformers model from a Lightning checkpoint:
```python
from kostyl.ml_core.lightning.extenstions.pretrained_model import LightningCheckpointLoaderMixin


model = LightningCheckpointLoaderMixin.from_lighting_checkpoint(
    "checkpoints/epoch=03-step=500.ckpt",
    config_key="config",
    weights_prefix="model.",
)
```

## Components
- **Configurations** (`kostyl/ml_core/configs`): strongly-typed training, optimizer, and scheduler configs with ClearML syncing helpers.
- **Lightning Extensions** (`kostyl/ml_core/lightning`): custom LightningModule base class, callbacks, logging bridges, and the checkpoint loader mixin.
- **Schedulers** (`kostyl/ml_core/schedulers`): extensible LR schedulers (base/composite/cosine) with serialization helpers and on-step logging.
- **ClearML Utilities** (`kostyl/ml_core/clearml`): tag/version helpers and logging bridges for ClearML Tasks.
- **Distributed + Metrics Utils** (`kostyl/ml_core/dist_utils.py`, `metrics_formatting.py`): world-size-aware LR scaling, rank-aware metric naming, and per-class formatting.
- **Logging Helpers** (`kostyl/utils/logging.py`): rank-aware Loguru setup and uniform handling of incompatible checkpoint keys.

## Project Layout
```
kostyl/
  ml_core/
    configs/      # Pydantic configs + ClearML mixins
    lightning/    # Lightning module, callbacks, loggers, extensions
    schedulers/   # Base + composite/cosine schedulers
    clearml/      # Logging + pulling utilities
  utils/          # Dict helpers, logging utilities
```
kostyl_toolkit-0.1.0/README.md
@@ -0,0 +1,89 @@
# Kostyl Toolkit

Kickass Orchestration System for Training, Yielding & Logging — a batteries-included toolbox that glues PyTorch Lightning, Hugging Face Transformers, and ClearML into a single workflow.

## Overview
- Rapidly bootstrap Lightning experiments with opinionated defaults (`KostylLightningModule`, custom schedulers, grad clipping and metric formatting).
- Keep model configs source-controlled via Pydantic mixins, with ClearML syncing out of the box (`ConfigLoadingMixin`, `ClearMLConfigMixin`).
- Reuse Lightning checkpoints directly inside Transformers models through `LightningCheckpointLoaderMixin`.
- Ship distributed-friendly utilities (deterministic logging, FSDP helpers, LR scaling, ClearML tag management).

## Installation
```bash
# Latest release from PyPI
pip install kostyl-toolkit

# or with uv
uv pip install kostyl-toolkit
```

Development setup:
```bash
uv sync              # creates the virtualenv declared in pyproject.toml
source .venv/bin/activate.fish
pre-commit install   # optional but recommended
```

## Quick Start
```python
from lightning import Trainer
from transformers import AutoModelForSequenceClassification

from kostyl.ml_core.configs.hyperparams import HyperparamsConfig
from kostyl.ml_core.configs.training_params import TrainingParams
from kostyl.ml_core.lightning.extenstions.custom_module import KostylLightningModule


class TextClassifier(KostylLightningModule):
    def __init__(self, hyperparams: HyperparamsConfig):
        super().__init__()
        self.hyperparams = hyperparams  # grad clipping + scheduler knobs
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "distilbert-base-uncased",
            num_labels=2,
        )

    def training_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        self.log("train/loss", outputs.loss)
        return outputs.loss

train_cfg = TrainingParams.from_file("configs/training.yaml")
hyperparams = HyperparamsConfig.from_file("configs/hyperparams.yaml")

module = TextClassifier(hyperparams)

trainer = Trainer(**train_cfg.trainer.model_dump())
trainer.fit(module)
```

Restoring a plain Transformers model from a Lightning checkpoint:
```python
from kostyl.ml_core.lightning.extenstions.pretrained_model import LightningCheckpointLoaderMixin


model = LightningCheckpointLoaderMixin.from_lighting_checkpoint(
    "checkpoints/epoch=03-step=500.ckpt",
    config_key="config",
    weights_prefix="model.",
)
```

## Components
- **Configurations** (`kostyl/ml_core/configs`): strongly-typed training, optimizer, and scheduler configs with ClearML syncing helpers.
- **Lightning Extensions** (`kostyl/ml_core/lightning`): custom LightningModule base class, callbacks, logging bridges, and the checkpoint loader mixin.
- **Schedulers** (`kostyl/ml_core/schedulers`): extensible LR schedulers (base/composite/cosine) with serialization helpers and on-step logging.
- **ClearML Utilities** (`kostyl/ml_core/clearml`): tag/version helpers and logging bridges for ClearML Tasks.
- **Distributed + Metrics Utils** (`kostyl/ml_core/dist_utils.py`, `metrics_formatting.py`): world-size-aware LR scaling, rank-aware metric naming, and per-class formatting.
- **Logging Helpers** (`kostyl/utils/logging.py`): rank-aware Loguru setup and uniform handling of incompatible checkpoint keys.

## Project Layout
```
kostyl/
  ml_core/
    configs/      # Pydantic configs + ClearML mixins
    lightning/    # Lightning module, callbacks, loggers, extensions
    schedulers/   # Base + composite/cosine schedulers
    clearml/      # Logging + pulling utilities
  utils/          # Dict helpers, logging utilities
```
kostyl_toolkit-0.1.0/kostyl/__init__.py: File without changes
kostyl_toolkit-0.1.0/kostyl/ml_core/__init__.py: File without changes
kostyl_toolkit-0.1.0/kostyl/ml_core/clearml/__init__.py: File without changes
kostyl_toolkit-0.1.0/kostyl/ml_core/clearml/logging_utils.py
@@ -0,0 +1,46 @@
import re

from kostyl.utils import setup_logger


logger = setup_logger(name="clearml_logging_utils.py", fmt="only_message")


def increment_version(s: str) -> str:
    """
    Increments the minor part of a version string.

    Examples:
        v1.00 -> v1.01
        v2.99 -> v2.100
        v.3.009 -> v.3.010

    """
    s = s.strip()
    m = re.fullmatch(r"v(\.?)(\d+)\.(\d+)", s)
    if not m:
        raise ValueError(f"Invalid version format: {s!r}. Expected 'vX.Y' or 'v.X.Y'.")

    vdot, major_str, minor_str = m.groups()
    major = int(major_str)
    minor = int(minor_str) + 1

    # preserve leading zeros based on original width, length may increase (99 -> 100)
    minor_out = str(minor).zfill(len(minor_str))
    prefix = f"v{vdot}"  # 'v' or 'v.'
    return f"{prefix}{major}.{minor_out}"


def find_version_in_tags(tags: list[str]) -> str | None:
    """
    Finds the first version tag in the list of tags.

    Note:
        Version tags must be in the format 'vX.Y' or 'v.X.Y' (an optional dot after 'v' is supported).

    """
    version_pattern = re.compile(r"^v(\.?)(\d+)\.(\d+)$")
    for tag in tags:
        if version_pattern.match(tag):
            return tag
    return None
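For reference, a minimal usage sketch of the two helpers above; the tag list and its values are hypothetical:

```python
# Hypothetical usage of the version helpers; tag values are made up.
from kostyl.ml_core.clearml.logging_utils import find_version_in_tags, increment_version

tags = ["production", "v1.09", "nlp"]

current = find_version_in_tags(tags)       # "v1.09" — first tag matching vX.Y / v.X.Y
assert current is not None
next_version = increment_version(current)  # "v1.10" — minor bumped, zero-padding preserved

print(current, "->", next_version)
```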
kostyl_toolkit-0.1.0/kostyl/ml_core/clearml/pulling_utils.py
@@ -0,0 +1,83 @@
from pathlib import Path

from clearml import InputModel
from clearml import Task
from transformers import AutoTokenizer
from transformers import PreTrainedModel
from transformers import PreTrainedTokenizerBase

from kostyl.ml_core.lightning.extenstions.pretrained_model import (
    LightningCheckpointLoaderMixin,
)


def get_tokenizer_from_clearml(
    model_id: str, task: Task | None = None, ignore_remote_overrides: bool = True
) -> PreTrainedTokenizerBase:
    """
    Retrieve a Hugging Face tokenizer stored in ClearML.

    Args:
        model_id (str): The ClearML InputModel identifier that holds the tokenizer artifacts.
        task (Task | None, optional): An optional ClearML Task used to associate and sync
            the model. Defaults to None.
        ignore_remote_overrides (bool, optional): Whether to ignore remote hyperparameter
            overrides when connecting the ClearML task. Defaults to True.

    Returns:
        PreTrainedTokenizerBase: The instantiated tokenizer loaded from the local copy
            of the referenced ClearML InputModel.

    """
    clearml_tokenizer = InputModel(model_id=model_id)
    if task is not None:
        clearml_tokenizer.connect(task, ignore_remote_overrides=ignore_remote_overrides)

    tokenizer = AutoTokenizer.from_pretrained(
        clearml_tokenizer.get_local_copy(raise_on_error=True)
    )
    return tokenizer


def get_model_from_clearml[TModel: PreTrainedModel | LightningCheckpointLoaderMixin](
    model_id: str,
    model: type[TModel],
    task: Task | None = None,
    ignore_remote_overrides: bool = True,
) -> TModel:
    """
    Retrieve a pretrained model from ClearML and instantiate it using the appropriate loader.

    Args:
        model_id: Identifier of the ClearML input model to retrieve.
        model: The model class that implements either PreTrainedModel or LightningCheckpointLoaderMixin.
        task: Optional ClearML task used to resolve the input model reference. If provided, the input model
            will be connected to this task, with remote overrides optionally ignored.
        ignore_remote_overrides: When connecting the input model to the provided task, determines whether
            remote configuration overrides should be ignored.

    Returns:
        An instantiated model loaded either from a ClearML package directory or a Lightning checkpoint.

    """
    input_model = InputModel(model_id=model_id)

    if task is not None:
        input_model.connect(task, ignore_remote_overrides=ignore_remote_overrides)

    local_path = Path(input_model.get_local_copy(raise_on_error=True))

    if local_path.is_dir() and input_model._is_package():
        model_instance = model.from_pretrained(local_path)
    elif local_path.suffix == ".ckpt":
        if not issubclass(model, LightningCheckpointLoaderMixin):
            raise ValueError(
                f"Model class {model.__name__} is not compatible with Lightning checkpoints."
            )
        model_instance = model.from_lighting_checkpoint(local_path)
    else:
        raise ValueError(
            f"Unsupported model format for path: {local_path}. "
            "Expected a ClearML package directory or a .ckpt file."
        )
    return model_instance
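A hedged usage sketch of the pulling helpers above; the ClearML model IDs, project/task names, and the choice of `BertForSequenceClassification` are placeholders, not part of the package:

```python
# Hypothetical: pull a tokenizer and a model registered in ClearML.
from clearml import Task
from transformers import BertForSequenceClassification

from kostyl.ml_core.clearml.pulling_utils import (
    get_model_from_clearml,
    get_tokenizer_from_clearml,
)

task = Task.init(project_name="demo", task_name="pull-artifacts")

tokenizer = get_tokenizer_from_clearml("tokenizer-model-id", task=task)
model = get_model_from_clearml(
    "classifier-model-id",
    model=BertForSequenceClassification,  # package directories are loaded via from_pretrained
    task=task,
)
```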
kostyl_toolkit-0.1.0/kostyl/ml_core/configs/__init__.py
@@ -0,0 +1,30 @@
from .config_mixins import ConfigLoadingMixin
from .hyperparams import HyperparamsConfig
from .hyperparams import Lr
from .hyperparams import Optimizer
from .hyperparams import WeightDecay
from .training_params import CheckpointConfig
from .training_params import ClearMLTrainingParameters
from .training_params import DataConfig
from .training_params import DDPStrategyConfig
from .training_params import EarlyStoppingConfig
from .training_params import FSDP1StrategyConfig
from .training_params import SingleDeviceStrategyConfig
from .training_params import TrainingParams


__all__ = [
    "CheckpointConfig",
    "ClearMLTrainingParameters",
    "ConfigLoadingMixin",
    "DDPStrategyConfig",
    "DataConfig",
    "EarlyStoppingConfig",
    "FSDP1StrategyConfig",
    "HyperparamsConfig",
    "Lr",
    "Optimizer",
    "SingleDeviceStrategyConfig",
    "TrainingParams",
    "WeightDecay",
]
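Because of these re-exports, downstream code can import the config classes from the subpackage root instead of the individual modules; a one-line illustration:

```python
# Equivalent to importing from .hyperparams and .training_params directly.
from kostyl.ml_core.configs import HyperparamsConfig, TrainingParams
```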
kostyl_toolkit-0.1.0/kostyl/ml_core/configs/config_mixins.py
@@ -0,0 +1,146 @@
from pathlib import Path

import clearml
import yaml
from caseconverter import pascalcase
from caseconverter import snakecase
from pydantic import BaseModel

from kostyl.utils import convert_to_flat_dict
from kostyl.utils import flattened_dict_to_nested


def load_config(path: Path | str) -> dict:
    """Load a configuration from file."""
    if isinstance(path, str):
        path = Path(path)

    if not path.is_file():
        raise ValueError(f"Config file {path} does not exist or is not a file.")

    match path.suffix:
        case ".yaml" | ".yml":
            config = yaml.safe_load(path.open("r"))
        case _:
            raise ValueError(f"Unsupported config file format: {path.suffix}")
    return config


class ConfigLoadingMixin[TConfig: ConfigLoadingMixin](BaseModel):
    """Pydantic mixin class providing basic configuration loading functionality."""

    @classmethod
    def from_file(
        cls: type[TConfig],
        path: str | Path,
    ) -> TConfig:
        """
        Create an instance of the class from a configuration file.

        Args:
            path (str | Path): Path to the configuration file.

        Returns:
            An instance of the class created from the configuration file.

        """
        config = load_config(path)
        instance = cls.model_validate(config)
        return instance

    @classmethod
    def from_dict(
        cls: type[TConfig],
        state_dict: dict,
    ) -> TConfig:
        """
        Creates an instance from a dictionary.

        Args:
            state_dict (dict): A dictionary representing the state of the
                class that must be validated and used for initialization.

        Returns:
            An initialized instance of the class based on the
            provided state dictionary.

        """
        instance = cls.model_validate(state_dict)
        return instance


class ClearMLConfigMixin[TConfig: ClearMLConfigMixin](ConfigLoadingMixin[TConfig]):
    """Pydantic mixin class providing ClearML configuration loading and syncing functionality."""

    @classmethod
    def connect_as_file(
        cls: type[TConfig],
        task: clearml.Task,
        path: str | Path,
        alias: str | None = None,
    ) -> TConfig:
        """
        Connects the configuration file to a ClearML task and creates an instance of the class from it.

        This method connects the specified configuration file to the given ClearML task for version control
        and monitoring, then loads and validates the configuration into the class.

        Args:
            cls: The class type to instantiate.
            task: The ClearML Task object to connect the configuration to.
            path: Path to the configuration file (supports YAML format).
            alias: Optional alias for the configuration in ClearML. Defaults to PascalCase of the class name if None.

        Returns:
            An instance of the class created from the connected configuration file.

        """
        if isinstance(path, Path):
            str_path = str(path)
        else:
            str_path = path

        name = alias if alias is not None else pascalcase(cls.__name__)
        connected_path = task.connect_configuration(str_path, name=pascalcase(name))

        if not isinstance(connected_path, str):
            connected_path_str = str(connected_path)
        else:
            connected_path_str = connected_path

        model = cls.from_file(path=connected_path_str)
        return model

    @classmethod
    def connect_as_dict(
        cls: type[TConfig],
        task: clearml.Task,
        path: str | Path,
        alias: str | None = None,
    ) -> TConfig:
        """
        Connects configuration from a file as a dictionary to a ClearML task and creates an instance of the class.

        This class method loads the configuration from a file as a dictionary, flattens it, and syncs it with
        the ClearML task parameters. It then creates an instance of the class from the synced dictionary.

        Args:
            cls: The class type of the config to be created (must be a ClearMLConfigMixin subclass).
            task: The ClearML task to connect the configuration to.
            path: Path to the configuration file to load parameters from.
            alias: Optional alias name for the configuration. If None, uses snake_case of the class name.

        Returns:
            An instance of the specified class created from the loaded configuration.

        """
        name = alias if alias is not None else snakecase(cls.__name__)

        config = load_config(path)

        flattened_config = convert_to_flat_dict(config)
        task.connect(flattened_config, name=pascalcase(name))
        config = flattened_dict_to_nested(flattened_config)

        model = cls.from_dict(config)
        return model
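A minimal sketch of how these mixins are intended to be used, assuming a hypothetical `DataloaderConfig` subclass, project/task names, and YAML path:

```python
# Hypothetical config class built on the mixin above.
import clearml

from kostyl.ml_core.configs.config_mixins import ClearMLConfigMixin


class DataloaderConfig(ClearMLConfigMixin["DataloaderConfig"]):
    batch_size: int = 32
    num_workers: int = 4


task = clearml.Task.init(project_name="demo", task_name="train")

# Attach the raw YAML file to the task's CONFIGURATION section, then validate it.
cfg = DataloaderConfig.connect_as_file(task, "configs/dataloader.yaml")

# Or flatten it into the task's hyperparameter table (editable per run in the UI).
cfg = DataloaderConfig.connect_as_dict(task, "configs/dataloader.yaml")
```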
kostyl_toolkit-0.1.0/kostyl/ml_core/configs/hyperparams.py
@@ -0,0 +1,84 @@
from pydantic import BaseModel
from pydantic import Field
from pydantic import model_validator

from kostyl.utils.logging import setup_logger

from .config_mixins import ClearMLConfigMixin


logger = setup_logger(fmt="only_message")


class Optimizer(BaseModel):
    """Optimizer hyperparameters configuration."""

    adamw_beta1: float = 0.9
    adamw_beta2: float = 0.999


class Lr(BaseModel):
    """Learning rate hyperparameters configuration."""

    use_scheduler: bool = False
    warmup_iters_ratio: float | None = Field(
        default=None, gt=0, lt=1, validate_default=False
    )
    warmup_value: float | None = Field(default=None, gt=0, validate_default=False)
    base_value: float
    final_value: float | None = Field(default=None, gt=0, validate_default=False)

    @model_validator(mode="after")
    def validate_warmup(self) -> "Lr":
        """Validates the warmup parameters based on use_scheduler."""
        if (self.warmup_value is None) != (
            self.warmup_iters_ratio is None
        ) and self.use_scheduler:
            raise ValueError(
                "Both warmup_value and warmup_iters_ratio must be provided or neither"
            )
        elif (
            (self.warmup_value is not None) or (self.warmup_iters_ratio is not None)
        ) and (not self.use_scheduler):
            logger.warning(
                "use_scheduler is False, warmup_value and warmup_iters_ratio will be ignored."
            )
            self.warmup_value = None
            self.warmup_iters_ratio = None
        return self

    @model_validator(mode="after")
    def validate_final_value(self) -> "Lr":
        """Validates the final_value based on use_scheduler."""
        if self.use_scheduler and (self.final_value is None):
            raise ValueError("If use_scheduler is True, final_value must be provided.")
        if (not self.use_scheduler) and (self.final_value is not None):
            logger.warning("use_scheduler is False, final_value will be ignored.")
            self.final_value = None
        return self


class WeightDecay(BaseModel):
    """Weight decay hyperparameters configuration."""

    use_scheduler: bool = False
    base_value: float
    final_value: float | None = None

    @model_validator(mode="after")
    def validate_final_value(self) -> "WeightDecay":
        """Validates the final_value based on use_scheduler."""
        if self.use_scheduler and self.final_value is None:
            raise ValueError("If use_scheduler is True, final_value must be provided.")
        if not self.use_scheduler and self.final_value is not None:
            logger.warning("use_scheduler is False, final_value will be ignored.")
        return self


class HyperparamsConfig(ClearMLConfigMixin["HyperparamsConfig"]):
    """Model training hyperparameters configuration."""

    grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
    optimizer: Optimizer = Optimizer()
    lr: Lr
    weight_decay: WeightDecay