kostyl-toolkit 0.1.30__tar.gz → 0.1.32__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/PKG-INFO +1 -1
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/clearml/pulling_utils.py +1 -1
- kostyl_toolkit-0.1.32/kostyl/ml/data_processing_utils.py +102 -0
- kostyl_toolkit-0.1.32/kostyl/ml/lightning/__init__.py +5 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/lightning/callbacks/__init__.py +0 -2
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/lightning/callbacks/checkpoint.py +1 -2
- {kostyl_toolkit-0.1.30/kostyl/ml/lightning/extenstions → kostyl_toolkit-0.1.32/kostyl/ml/lightning/extensions}/custom_module.py +21 -8
- kostyl_toolkit-0.1.32/kostyl/ml/lightning/training_utils.py +241 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/metrics_formatting.py +2 -3
- {kostyl_toolkit-0.1.30/kostyl/ml/lightning/callbacks → kostyl_toolkit-0.1.32/kostyl/ml}/registry_uploader.py +20 -43
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/pyproject.toml +1 -1
- kostyl_toolkit-0.1.30/kostyl/ml/lightning/__init__.py +0 -5
- kostyl_toolkit-0.1.30/kostyl/ml/lightning/steps_estimation.py +0 -44
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/README.md +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/__init__.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/__init__.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/clearml/__init__.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/clearml/dataset_utils.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/clearml/logging_utils.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/configs/__init__.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/configs/base_model.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/configs/hyperparams.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/configs/training_settings.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/dist_utils.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
- {kostyl_toolkit-0.1.30/kostyl/ml/lightning/extenstions → kostyl_toolkit-0.1.32/kostyl/ml/lightning/extensions}/__init__.py +0 -0
- {kostyl_toolkit-0.1.30/kostyl/ml/lightning/extenstions → kostyl_toolkit-0.1.32/kostyl/ml/lightning/extensions}/pretrained_model.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/lightning/loggers/__init__.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/lightning/loggers/tb_logger.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/params_groups.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/schedulers/__init__.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/schedulers/base.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/schedulers/composite.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/schedulers/cosine.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/schedulers/linear.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/utils/__init__.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/utils/dict_manipulations.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/utils/fs.py +0 -0
- {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/utils/logging.py +0 -0
|
@@ -9,7 +9,7 @@ from transformers import AutoTokenizer
|
|
|
9
9
|
from transformers import PreTrainedModel
|
|
10
10
|
from transformers import PreTrainedTokenizerBase
|
|
11
11
|
|
|
12
|
-
from kostyl.ml.lightning.
|
|
12
|
+
from kostyl.ml.lightning.extensions.pretrained_model import (
|
|
13
13
|
LightningCheckpointLoaderMixin,
|
|
14
14
|
)
|
|
15
15
|
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from copy import deepcopy
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from transformers import DataCollatorWithPadding
|
|
6
|
+
from transformers.data.data_collator import DataCollatorMixin
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class BatchCollatorWithKeyAlignment:
|
|
10
|
+
"""
|
|
11
|
+
Maps dataset keys to HuggingFace DataCollator expected keys and collates the batch.
|
|
12
|
+
|
|
13
|
+
HuggingFace collators expect specific keys depending on the collator type:
|
|
14
|
+
- `DataCollatorWithPadding`: "input_ids", "attention_mask", "token_type_ids" (optional).
|
|
15
|
+
- `DataCollatorForLanguageModeling`: "input_ids", "attention_mask", "special_tokens_mask" (optional).
|
|
16
|
+
- `DataCollatorForSeq2Seq`: "input_ids", "attention_mask", "labels".
|
|
17
|
+
- `DataCollatorForTokenClassification`: "input_ids", "attention_mask", "labels".
|
|
18
|
+
|
|
19
|
+
This wrapper allows you to map arbitrary dataset keys to these expected names before collation,
|
|
20
|
+
optionally truncating sequences to a maximum length.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
def __init__(
|
|
24
|
+
self,
|
|
25
|
+
collator: DataCollatorWithPadding | DataCollatorMixin,
|
|
26
|
+
keys_mapping: dict[str, str] | None = None,
|
|
27
|
+
keys_to_keep: set[str] | None = None,
|
|
28
|
+
max_length: int | None = None,
|
|
29
|
+
) -> None:
|
|
30
|
+
"""
|
|
31
|
+
Initialize the BatchCollatorWithKeyAlignment.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
collator: A callable (usually a Hugging Face DataCollator) that takes a list
|
|
35
|
+
of dictionaries and returns a collated batch (e.g., padded tensors).
|
|
36
|
+
keys_mapping: A dictionary mapping original keys to new keys.
|
|
37
|
+
keys_to_keep: A set of keys to retain as-is from the original items.
|
|
38
|
+
max_length: If provided, truncates "input_ids" and "attention_mask" to this length.
|
|
39
|
+
|
|
40
|
+
Raises:
|
|
41
|
+
ValueError: If both `keys_mapping` and `keys_to_keep` are None.
|
|
42
|
+
|
|
43
|
+
"""
|
|
44
|
+
if (keys_mapping is None) and (keys_to_keep is None):
|
|
45
|
+
raise ValueError("Either keys_mapping or keys_to_keep must be provided.")
|
|
46
|
+
|
|
47
|
+
if keys_mapping is None:
|
|
48
|
+
keys_mapping = {}
|
|
49
|
+
if keys_to_keep is None:
|
|
50
|
+
keys_to_keep = set()
|
|
51
|
+
|
|
52
|
+
self.collator = collator
|
|
53
|
+
self.keys_mapping = deepcopy(keys_mapping)
|
|
54
|
+
self.max_length = max_length
|
|
55
|
+
|
|
56
|
+
keys_to_keep_mapping = {v: v for v in keys_to_keep}
|
|
57
|
+
self.keys_mapping.update(keys_to_keep_mapping)
|
|
58
|
+
|
|
59
|
+
def _truncate_data(self, key: str, value: Any) -> Any:
|
|
60
|
+
match value:
|
|
61
|
+
case torch.Tensor():
|
|
62
|
+
if value.dim() > 2:
|
|
63
|
+
raise ValueError(
|
|
64
|
+
f"Expected value with dim <= 2 for key {key}, got {value.dim()}"
|
|
65
|
+
)
|
|
66
|
+
case list():
|
|
67
|
+
if isinstance(value[0], list):
|
|
68
|
+
raise ValueError(
|
|
69
|
+
f"Expected value with dim <= 2 for key {key}, got nested lists"
|
|
70
|
+
)
|
|
71
|
+
value = value[: self.max_length]
|
|
72
|
+
return value
|
|
73
|
+
|
|
74
|
+
def __call__(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
|
|
75
|
+
"""
|
|
76
|
+
Align keys and collate the batch.
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
batch: A list of dictionaries representing the data batch.
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
The collated batch returned by the underlying collator.
|
|
83
|
+
|
|
84
|
+
"""
|
|
85
|
+
aligned_batch = []
|
|
86
|
+
for item in batch:
|
|
87
|
+
new_item = {}
|
|
88
|
+
for k in item.keys():
|
|
89
|
+
new_key = self.keys_mapping.get(k, None)
|
|
90
|
+
if new_key is None:
|
|
91
|
+
continue
|
|
92
|
+
value = item[k]
|
|
93
|
+
if self.max_length is not None and new_key in (
|
|
94
|
+
"input_ids",
|
|
95
|
+
"attention_mask",
|
|
96
|
+
):
|
|
97
|
+
value = self._truncate_data(new_key, value)
|
|
98
|
+
new_item[new_key] = value
|
|
99
|
+
aligned_batch.append(new_item)
|
|
100
|
+
|
|
101
|
+
collated_batch = self.collator(aligned_batch)
|
|
102
|
+
return collated_batch
|
|
@@ -1,10 +1,8 @@
|
|
|
1
1
|
from .checkpoint import setup_checkpoint_callback
|
|
2
2
|
from .early_stopping import setup_early_stopping_callback
|
|
3
|
-
from .registry_uploader import ClearMLRegistryUploaderCallback
|
|
4
3
|
|
|
5
4
|
|
|
6
5
|
__all__ = [
|
|
7
|
-
"ClearMLRegistryUploaderCallback",
|
|
8
6
|
"setup_checkpoint_callback",
|
|
9
7
|
"setup_early_stopping_callback",
|
|
10
8
|
]
|
|
@@ -12,10 +12,9 @@ from lightning.pytorch.callbacks import ModelCheckpoint
|
|
|
12
12
|
from kostyl.ml.configs import CheckpointConfig
|
|
13
13
|
from kostyl.ml.dist_utils import is_main_process
|
|
14
14
|
from kostyl.ml.lightning import KostylLightningModule
|
|
15
|
+
from kostyl.ml.registry_uploader import RegistryUploaderCallback
|
|
15
16
|
from kostyl.utils import setup_logger
|
|
16
17
|
|
|
17
|
-
from .registry_uploader import RegistryUploaderCallback
|
|
18
|
-
|
|
19
18
|
|
|
20
19
|
logger = setup_logger("callbacks/checkpoint.py")
|
|
21
20
|
|
|
@@ -20,12 +20,17 @@ from kostyl.ml.schedulers.base import BaseScheduler
|
|
|
20
20
|
from kostyl.utils import setup_logger
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
|
|
23
|
+
module_logger = setup_logger(fmt="only_message")
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class KostylLightningModule(L.LightningModule):
|
|
27
27
|
"""Custom PyTorch Lightning Module with logging, checkpointing, and distributed training utilities."""
|
|
28
28
|
|
|
29
|
+
@property
|
|
30
|
+
def process_group(self) -> ProcessGroup | None:
|
|
31
|
+
"""Returns the data parallel process group for distributed training."""
|
|
32
|
+
return self.get_process_group()
|
|
33
|
+
|
|
29
34
|
def get_process_group(self) -> ProcessGroup | None:
|
|
30
35
|
"""
|
|
31
36
|
Retrieves the data parallel process group for distributed training.
|
|
@@ -45,7 +50,7 @@ class KostylLightningModule(L.LightningModule):
|
|
|
45
50
|
if self.device_mesh is not None:
|
|
46
51
|
dp_mesh = self.device_mesh["data_parallel"]
|
|
47
52
|
if dp_mesh.size() == 1:
|
|
48
|
-
|
|
53
|
+
module_logger.warning("Data parallel mesh size is 1, returning None")
|
|
49
54
|
return None
|
|
50
55
|
dp_pg = dp_mesh.get_group()
|
|
51
56
|
else:
|
|
@@ -129,11 +134,16 @@ class KostylLightningModule(L.LightningModule):
|
|
|
129
134
|
stage: str | None = None,
|
|
130
135
|
) -> None:
|
|
131
136
|
if stage is not None:
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
+
if not isinstance(dictionary, MetricCollection):
|
|
138
|
+
dictionary = apply_suffix(
|
|
139
|
+
metrics=dictionary,
|
|
140
|
+
suffix=stage,
|
|
141
|
+
add_dist_rank=False,
|
|
142
|
+
)
|
|
143
|
+
else:
|
|
144
|
+
module_logger.warning_once(
|
|
145
|
+
"Stage suffixing for MetricCollection is not implemented. Skipping suffixing."
|
|
146
|
+
)
|
|
137
147
|
super().log_dict(
|
|
138
148
|
dictionary,
|
|
139
149
|
prog_bar,
|
|
@@ -161,9 +171,12 @@ class KostylLightningModule(L.LightningModule):
|
|
|
161
171
|
"""
|
|
162
172
|
scheduler: BaseScheduler = self.lr_schedulers() # type: ignore
|
|
163
173
|
if not isinstance(scheduler, BaseScheduler):
|
|
174
|
+
module_logger.warning_once(
|
|
175
|
+
"Scheduler is not an instance of BaseScheduler. Skipping scheduled values logging."
|
|
176
|
+
)
|
|
164
177
|
return
|
|
165
178
|
scheduler_state_dict = scheduler.current_value()
|
|
166
|
-
scheduler_state_dict = apply_suffix(scheduler_state_dict, "
|
|
179
|
+
scheduler_state_dict = apply_suffix(scheduler_state_dict, "scheduled")
|
|
167
180
|
self.log_dict(
|
|
168
181
|
scheduler_state_dict,
|
|
169
182
|
prog_bar=False,
|
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from dataclasses import fields
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Literal
|
|
5
|
+
from typing import cast
|
|
6
|
+
|
|
7
|
+
import lightning as L
|
|
8
|
+
import torch
|
|
9
|
+
import torch.distributed as dist
|
|
10
|
+
from clearml import OutputModel
|
|
11
|
+
from clearml import Task
|
|
12
|
+
from lightning.pytorch.callbacks import Callback
|
|
13
|
+
from lightning.pytorch.callbacks import EarlyStopping
|
|
14
|
+
from lightning.pytorch.callbacks import LearningRateMonitor
|
|
15
|
+
from lightning.pytorch.callbacks import ModelCheckpoint
|
|
16
|
+
from lightning.pytorch.loggers import TensorBoardLogger
|
|
17
|
+
from lightning.pytorch.strategies import DDPStrategy
|
|
18
|
+
from lightning.pytorch.strategies import FSDPStrategy
|
|
19
|
+
from torch.distributed import ProcessGroup
|
|
20
|
+
from torch.distributed.fsdp import MixedPrecision
|
|
21
|
+
from torch.nn import Module
|
|
22
|
+
|
|
23
|
+
from kostyl.ml.configs import CheckpointConfig
|
|
24
|
+
from kostyl.ml.configs import DDPStrategyConfig
|
|
25
|
+
from kostyl.ml.configs import EarlyStoppingConfig
|
|
26
|
+
from kostyl.ml.configs import FSDP1StrategyConfig
|
|
27
|
+
from kostyl.ml.configs import SingleDeviceStrategyConfig
|
|
28
|
+
from kostyl.ml.lightning.callbacks import setup_checkpoint_callback
|
|
29
|
+
from kostyl.ml.lightning.callbacks import setup_early_stopping_callback
|
|
30
|
+
from kostyl.ml.lightning.loggers import setup_tb_logger
|
|
31
|
+
from kostyl.ml.registry_uploader import ClearMLRegistryUploaderCallback
|
|
32
|
+
from kostyl.utils.logging import setup_logger
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Module logger; rank information is included in each record (add_rank=True).
logger = setup_logger(add_rank=True)

# Type alias for every strategy configuration accepted by `setup_strategy`.
TRAINING_STRATEGIES = (
    FSDP1StrategyConfig
    | DDPStrategyConfig
    | SingleDeviceStrategyConfig
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def estimate_total_steps(
    trainer: L.Trainer, process_group: ProcessGroup | None = None
) -> int:
    """
    Estimate the total number of training steps for a Lightning trainer.

    The estimate is ``(len(train_dataloader) // accumulate_grad_batches // world_size)
    * max_epochs``, where the dataloader comes from ``trainer.datamodule``.

    Note:
        This calls ``datamodule.setup("fit")`` and ``datamodule.train_dataloader()``
        as a side effect. If ``torch.distributed`` is not initialized at call time,
        the world size is taken to be 1. Floor division is used throughout, so the
        result may be slightly lower than the number of optimizer steps Lightning
        actually takes when batches do not divide evenly.

    Args:
        trainer: Trainer whose ``datamodule``, ``accumulate_grad_batches`` and
            ``max_epochs`` drive the estimate.
        process_group: Process group used to query the world size; ``None`` uses
            the default group.

    Returns:
        Estimated total number of training steps.

    Raises:
        ValueError: If ``max_epochs`` is unset or negative (infinite training), or
            if the trainer has no datamodule.

    """
    # Validate cheap trainer settings before the potentially expensive datamodule setup.
    max_epochs = trainer.max_epochs
    if max_epochs is None:
        raise ValueError("Trainer must have `max_epochs` set to estimate total steps.")
    if max_epochs < 0:
        # Lightning uses max_epochs=-1 for infinite training; no finite total exists.
        raise ValueError(
            f"Trainer `max_epochs` must be non-negative to estimate total steps, got {max_epochs}."
        )

    if dist.is_initialized():
        world_size = dist.get_world_size(process_group)
    else:
        world_size = 1

    datamodule = trainer.datamodule  # type: ignore
    if datamodule is None:
        raise ValueError("Trainer must have a datamodule to estimate total steps.")
    datamodule = cast(L.LightningDataModule, datamodule)

    logger.info("Loading `train_dataloader` to estimate number of stepping batches.")
    datamodule.setup("fit")

    dataloader_len = len(datamodule.train_dataloader())
    steps_per_epoch = dataloader_len // trainer.accumulate_grad_batches // world_size
    total_steps = steps_per_epoch * max_epochs

    logger.info(
        f"Total steps: {total_steps} (per-epoch: {steps_per_epoch})\n"
        f"-> Dataloader len: {dataloader_len}\n"
        f"-> Accumulate grad batches: {trainer.accumulate_grad_batches}\n"
        f"-> Epochs: {max_epochs}\n"
        f"-> World size: {world_size}"
    )
    return total_steps
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@dataclass
class Callbacks:
    """Container for the Lightning callbacks produced by `setup_callbacks`."""

    checkpoint: ModelCheckpoint
    lr_monitor: LearningRateMonitor
    early_stopping: EarlyStopping | None = None

    def to_list(self) -> list[Callback]:
        """Return the configured callbacks in field-declaration order, skipping ``None`` entries."""
        present: list[Callback] = []
        for fld in fields(self):
            callback = getattr(self, fld.name)
            if callback is not None:
                present.append(callback)
        return present
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def setup_callbacks(
    task: Task,
    root_path: Path,
    checkpoint_cfg: CheckpointConfig,
    uploading_mode: Literal["only-best", "every-checkpoint"],
    output_model: OutputModel,
    early_stopping_cfg: EarlyStoppingConfig | None = None,
    config_dict: dict[str, str] | None = None,
    enable_tag_versioning: bool = False,
) -> Callbacks:
    """
    Build the PyTorch Lightning callbacks used for a training run.

    A ClearML registry uploader is created and handed to the checkpoint callback,
    which is responsible for triggering uploads. The uploader itself is not part of
    the returned container.

    Args:
        task: ClearML task; its name and ID form the checkpoint sub-directory.
        root_path: Root directory under which a ``checkpoints`` folder is used.
        checkpoint_cfg: Configuration for checkpoint saving behavior.
        uploading_mode: Model upload strategy:
            - `"only-best"`: Upload only the best checkpoint based on monitored metric.
            - `"every-checkpoint"`: Upload every saved checkpoint.
        output_model: ClearML OutputModel instance the uploader writes to.
        early_stopping_cfg: Early stopping configuration; ``None`` disables it.
        config_dict: Optional configuration stored alongside the uploaded model.
        enable_tag_versioning: Whether the uploader auto-increments version tags
            (e.g., "v1.0") on the output model.

    Returns:
        Callbacks with the checkpoint and step-level LR monitor callbacks, plus the
        early stopping callback when configured (otherwise ``None``).

    """
    step_lr_monitor = LearningRateMonitor(
        logging_interval="step", log_weight_decay=True, log_momentum=False
    )
    registry_uploader = ClearMLRegistryUploaderCallback(
        output_model=output_model,
        config_dict=config_dict,
        verbose=True,
        enable_tag_versioning=enable_tag_versioning,
    )
    checkpoint_dir = root_path / "checkpoints" / task.name / task.id
    checkpointer = setup_checkpoint_callback(
        checkpoint_dir,
        checkpoint_cfg,
        registry_uploader_callback=registry_uploader,
        uploading_mode=uploading_mode,
    )
    stopper = (
        None
        if early_stopping_cfg is None
        else setup_early_stopping_callback(early_stopping_cfg)
    )
    return Callbacks(
        checkpoint=checkpointer,
        lr_monitor=step_lr_monitor,
        early_stopping=stopper,
    )
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def setup_loggers(task: Task, root_path: Path) -> list[TensorBoardLogger]:
    """
    Create the experiment loggers for a training run.

    Args:
        task: ClearML task; its name and ID form the log sub-directory.
        root_path: Root directory under which a ``runs`` folder is used.

    Returns:
        A single-element list containing the TensorBoard logger.

    """
    tb_log_dir = root_path / "runs" / task.name / task.id
    tb_logger = setup_tb_logger(tb_log_dir)
    return [tb_logger]
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def setup_strategy(
|
|
180
|
+
strategy_settings: TRAINING_STRATEGIES,
|
|
181
|
+
devices: list[int] | int,
|
|
182
|
+
auto_wrap_policy: set[type[Module]] | None = None,
|
|
183
|
+
) -> Literal["auto"] | FSDPStrategy | DDPStrategy:
|
|
184
|
+
"""
|
|
185
|
+
Configure and return a PyTorch Lightning training strategy.
|
|
186
|
+
|
|
187
|
+
Args:
|
|
188
|
+
strategy_settings: Strategy configuration object. Must be one of:
|
|
189
|
+
- `FSDP1StrategyConfig`: Fully Sharded Data Parallel strategy (requires 2+ devices).
|
|
190
|
+
- `DDPStrategyConfig`: Distributed Data Parallel strategy (requires 2+ devices).
|
|
191
|
+
- `SingleDeviceStrategyConfig`: Single device training (requires exactly 1 device).
|
|
192
|
+
devices: Device(s) to use for training. Either a list of device IDs or
|
|
193
|
+
a single integer representing the number of devices.
|
|
194
|
+
auto_wrap_policy: Set of module types that should be wrapped for FSDP.
|
|
195
|
+
Required when using `FSDP1StrategyConfig`, ignored otherwise.
|
|
196
|
+
|
|
197
|
+
Returns:
|
|
198
|
+
Configured strategy: `FSDPStrategy`, `DDPStrategy`, or `"auto"` for single device.
|
|
199
|
+
|
|
200
|
+
Raises:
|
|
201
|
+
ValueError: If device count doesn't match strategy requirements or
|
|
202
|
+
if `auto_wrap_policy` is missing for FSDP.
|
|
203
|
+
|
|
204
|
+
"""
|
|
205
|
+
if isinstance(devices, list):
|
|
206
|
+
num_devices = len(devices)
|
|
207
|
+
else:
|
|
208
|
+
num_devices = devices
|
|
209
|
+
|
|
210
|
+
match strategy_settings:
|
|
211
|
+
case FSDP1StrategyConfig():
|
|
212
|
+
if num_devices < 2:
|
|
213
|
+
raise ValueError("FSDP strategy requires multiple devices.")
|
|
214
|
+
|
|
215
|
+
if auto_wrap_policy is None:
|
|
216
|
+
raise ValueError("auto_wrap_policy must be provided for FSDP strategy.")
|
|
217
|
+
|
|
218
|
+
mixed_precision_config = MixedPrecision(
|
|
219
|
+
param_dtype=getattr(torch, strategy_settings.param_dtype),
|
|
220
|
+
reduce_dtype=getattr(torch, strategy_settings.reduce_dtype),
|
|
221
|
+
buffer_dtype=getattr(torch, strategy_settings.buffer_dtype),
|
|
222
|
+
)
|
|
223
|
+
strategy = FSDPStrategy(
|
|
224
|
+
auto_wrap_policy=auto_wrap_policy,
|
|
225
|
+
mixed_precision=mixed_precision_config,
|
|
226
|
+
)
|
|
227
|
+
case DDPStrategyConfig():
|
|
228
|
+
if num_devices < 2:
|
|
229
|
+
raise ValueError("DDP strategy requires at least two devices.")
|
|
230
|
+
strategy = DDPStrategy(
|
|
231
|
+
find_unused_parameters=strategy_settings.find_unused_parameters
|
|
232
|
+
)
|
|
233
|
+
case SingleDeviceStrategyConfig():
|
|
234
|
+
if num_devices != 1:
|
|
235
|
+
raise ValueError("SingleDevice strategy requires exactly one device.")
|
|
236
|
+
strategy = "auto"
|
|
237
|
+
case _:
|
|
238
|
+
raise ValueError(
|
|
239
|
+
f"Unsupported strategy type: {type(strategy_settings.trainer.strategy)}"
|
|
240
|
+
)
|
|
241
|
+
return strategy
|
|
@@ -3,14 +3,13 @@ from collections.abc import Mapping
|
|
|
3
3
|
import torch.distributed as dist
|
|
4
4
|
from torch import Tensor
|
|
5
5
|
from torchmetrics import Metric
|
|
6
|
-
from torchmetrics import MetricCollection
|
|
7
6
|
|
|
8
7
|
|
|
9
8
|
def apply_suffix(
|
|
10
|
-
metrics: Mapping[str, Metric | Tensor | int | float]
|
|
9
|
+
metrics: Mapping[str, Metric | Tensor | int | float],
|
|
11
10
|
suffix: str,
|
|
12
11
|
add_dist_rank: bool = False,
|
|
13
|
-
) -> Mapping[str, Metric | Tensor | int | float]
|
|
12
|
+
) -> Mapping[str, Metric | Tensor | int | float]:
|
|
14
13
|
"""Add stage prefix to metric names."""
|
|
15
14
|
new_metrics_dict = {}
|
|
16
15
|
for key, value in metrics.items():
|
|
@@ -5,7 +5,6 @@ from pathlib import Path
|
|
|
5
5
|
from typing import override
|
|
6
6
|
|
|
7
7
|
from clearml import OutputModel
|
|
8
|
-
from clearml import Task
|
|
9
8
|
|
|
10
9
|
from kostyl.ml.clearml.logging_utils import find_version_in_tags
|
|
11
10
|
from kostyl.ml.clearml.logging_utils import increment_version
|
|
@@ -29,69 +28,50 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
|
|
|
29
28
|
|
|
30
29
|
def __init__(
|
|
31
30
|
self,
|
|
32
|
-
|
|
33
|
-
output_model_name: str,
|
|
34
|
-
output_model_tags: list[str] | None = None,
|
|
35
|
-
verbose: bool = True,
|
|
36
|
-
enable_tag_versioning: bool = True,
|
|
37
|
-
label_enumeration: dict[str, int] | None = None,
|
|
31
|
+
output_model: OutputModel,
|
|
38
32
|
config_dict: dict[str, str] | None = None,
|
|
33
|
+
verbose: bool = True,
|
|
34
|
+
enable_tag_versioning: bool = False,
|
|
39
35
|
) -> None:
|
|
40
36
|
"""
|
|
41
37
|
Initializes the ClearMLRegistryUploaderCallback.
|
|
42
38
|
|
|
43
39
|
Args:
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
output_model_name: Name for the ClearML output model.
|
|
47
|
-
output_model_tags: Tags for the output model.
|
|
48
|
-
verbose: Whether to log messages.
|
|
49
|
-
label_enumeration: Optional mapping of label names to integer IDs.
|
|
40
|
+
output_model: ClearML OutputModel instance representing the model to upload.
|
|
41
|
+
verbose: Whether to log messages during upload.
|
|
50
42
|
config_dict: Optional configuration dictionary to associate with the model.
|
|
51
43
|
enable_tag_versioning: Whether to enable versioning in tags. If True,
|
|
52
44
|
the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".
|
|
53
45
|
|
|
54
46
|
"""
|
|
55
47
|
super().__init__()
|
|
56
|
-
|
|
57
|
-
output_model_tags = []
|
|
58
|
-
|
|
59
|
-
self.task = task
|
|
60
|
-
self.output_model_name = output_model_name
|
|
61
|
-
self.output_model_tags = output_model_tags
|
|
48
|
+
self.output_model = output_model
|
|
62
49
|
self.config_dict = config_dict
|
|
63
|
-
self.label_enumeration = label_enumeration
|
|
64
50
|
self.verbose = verbose
|
|
65
51
|
self.enable_tag_versioning = enable_tag_versioning
|
|
66
52
|
|
|
67
53
|
self.best_model_path: str = ""
|
|
68
54
|
|
|
69
|
-
self._output_model: OutputModel | None = None
|
|
70
55
|
self._last_uploaded_model_path: str = ""
|
|
71
56
|
self._upload_callback: Callable | None = None
|
|
57
|
+
|
|
58
|
+
self._validate_tags()
|
|
72
59
|
return
|
|
73
60
|
|
|
74
|
-
def
|
|
61
|
+
def _validate_tags(self) -> None:
|
|
62
|
+
output_model_tags = self.output_model.tags or []
|
|
75
63
|
if self.enable_tag_versioning:
|
|
76
|
-
version = find_version_in_tags(
|
|
64
|
+
version = find_version_in_tags(output_model_tags)
|
|
77
65
|
if version is None:
|
|
78
|
-
|
|
66
|
+
output_model_tags.append("v1.0")
|
|
79
67
|
else:
|
|
80
68
|
new_version = increment_version(version)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
return OutputModel(
|
|
88
|
-
task=self.task,
|
|
89
|
-
name=self.output_model_name,
|
|
90
|
-
framework="PyTorch",
|
|
91
|
-
tags=self.output_model_tags,
|
|
92
|
-
config_dict=None,
|
|
93
|
-
label_enumeration=self.label_enumeration,
|
|
94
|
-
)
|
|
69
|
+
output_model_tags.remove(version)
|
|
70
|
+
output_model_tags.append(new_version)
|
|
71
|
+
if "LightningCheckpoint" not in output_model_tags:
|
|
72
|
+
output_model_tags.append("LightningCheckpoint")
|
|
73
|
+
self.output_model.tags = output_model_tags
|
|
74
|
+
return None
|
|
95
75
|
|
|
96
76
|
@override
|
|
97
77
|
def upload_checkpoint(
|
|
@@ -105,18 +85,15 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
|
|
|
105
85
|
logger.info("Model unchanged since last upload")
|
|
106
86
|
return
|
|
107
87
|
|
|
108
|
-
if self._output_model is None:
|
|
109
|
-
self._output_model = self._create_output_model()
|
|
110
|
-
|
|
111
88
|
if self.verbose:
|
|
112
89
|
logger.info(f"Uploading model from {path}")
|
|
113
90
|
|
|
114
|
-
self.
|
|
91
|
+
self.output_model.update_weights(
|
|
115
92
|
path,
|
|
116
93
|
auto_delete_file=False,
|
|
117
94
|
async_enable=False,
|
|
118
95
|
)
|
|
119
|
-
self.
|
|
96
|
+
self.output_model.update_design(config_dict=self.config_dict)
|
|
120
97
|
|
|
121
98
|
self._last_uploaded_model_path = path
|
|
122
99
|
return
|
|
@@ -1,44 +0,0 @@
|
|
|
1
|
-
from typing import cast
|
|
2
|
-
|
|
3
|
-
import lightning as L
|
|
4
|
-
import torch.distributed as dist
|
|
5
|
-
from torch.distributed import ProcessGroup
|
|
6
|
-
|
|
7
|
-
from kostyl.utils.logging import setup_logger
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
logger = setup_logger(add_rank=True)
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
def estimate_total_steps(
|
|
14
|
-
trainer: L.Trainer, process_group: ProcessGroup | None = None
|
|
15
|
-
) -> int:
|
|
16
|
-
"""Estimates the total number of training steps for a given PyTorch Lightning Trainer."""
|
|
17
|
-
if dist.is_initialized():
|
|
18
|
-
world_size = dist.get_world_size(process_group)
|
|
19
|
-
else:
|
|
20
|
-
world_size = 1
|
|
21
|
-
|
|
22
|
-
datamodule = trainer.datamodule # type: ignore
|
|
23
|
-
if datamodule is None:
|
|
24
|
-
raise ValueError("Trainer must have a datamodule to estimate total steps.")
|
|
25
|
-
datamodule = cast(L.LightningDataModule, datamodule)
|
|
26
|
-
|
|
27
|
-
logger.info("Loading `train_dataloader` to estimate number of stepping batches.")
|
|
28
|
-
datamodule.setup("fit")
|
|
29
|
-
|
|
30
|
-
dataloader_len = len(datamodule.train_dataloader())
|
|
31
|
-
steps_per_epoch = dataloader_len // trainer.accumulate_grad_batches // world_size
|
|
32
|
-
|
|
33
|
-
if trainer.max_epochs is None:
|
|
34
|
-
raise ValueError("Trainer must have `max_epochs` set to estimate total steps.")
|
|
35
|
-
total_steps = steps_per_epoch * trainer.max_epochs
|
|
36
|
-
|
|
37
|
-
logger.info(
|
|
38
|
-
f"Total steps: {total_steps} (per-epoch: {steps_per_epoch})\n"
|
|
39
|
-
f"-> Dataloader len: {dataloader_len}\n"
|
|
40
|
-
f"-> Accumulate grad batches: {trainer.accumulate_grad_batches}\n"
|
|
41
|
-
f"-> Epochs: {trainer.max_epochs}\n "
|
|
42
|
-
f"-> World size: {world_size}"
|
|
43
|
-
)
|
|
44
|
-
return total_steps
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/lightning/callbacks/early_stopping.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|