kostyl-toolkit 0.1.30__tar.gz → 0.1.32__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/PKG-INFO +1 -1
  2. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/clearml/pulling_utils.py +1 -1
  3. kostyl_toolkit-0.1.32/kostyl/ml/data_processing_utils.py +102 -0
  4. kostyl_toolkit-0.1.32/kostyl/ml/lightning/__init__.py +5 -0
  5. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/lightning/callbacks/__init__.py +0 -2
  6. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/lightning/callbacks/checkpoint.py +1 -2
  7. {kostyl_toolkit-0.1.30/kostyl/ml/lightning/extenstions → kostyl_toolkit-0.1.32/kostyl/ml/lightning/extensions}/custom_module.py +21 -8
  8. kostyl_toolkit-0.1.32/kostyl/ml/lightning/training_utils.py +241 -0
  9. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/metrics_formatting.py +2 -3
  10. {kostyl_toolkit-0.1.30/kostyl/ml/lightning/callbacks → kostyl_toolkit-0.1.32/kostyl/ml}/registry_uploader.py +20 -43
  11. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/pyproject.toml +1 -1
  12. kostyl_toolkit-0.1.30/kostyl/ml/lightning/__init__.py +0 -5
  13. kostyl_toolkit-0.1.30/kostyl/ml/lightning/steps_estimation.py +0 -44
  14. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/README.md +0 -0
  15. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/__init__.py +0 -0
  16. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/__init__.py +0 -0
  17. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/clearml/__init__.py +0 -0
  18. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/clearml/dataset_utils.py +0 -0
  19. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/clearml/logging_utils.py +0 -0
  20. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/configs/__init__.py +0 -0
  21. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/configs/base_model.py +0 -0
  22. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/configs/hyperparams.py +0 -0
  23. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/configs/training_settings.py +0 -0
  24. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/dist_utils.py +0 -0
  25. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
  26. {kostyl_toolkit-0.1.30/kostyl/ml/lightning/extenstions → kostyl_toolkit-0.1.32/kostyl/ml/lightning/extensions}/__init__.py +0 -0
  27. {kostyl_toolkit-0.1.30/kostyl/ml/lightning/extenstions → kostyl_toolkit-0.1.32/kostyl/ml/lightning/extensions}/pretrained_model.py +0 -0
  28. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/lightning/loggers/__init__.py +0 -0
  29. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/lightning/loggers/tb_logger.py +0 -0
  30. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/params_groups.py +0 -0
  31. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/schedulers/__init__.py +0 -0
  32. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/schedulers/base.py +0 -0
  33. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/schedulers/composite.py +0 -0
  34. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/schedulers/cosine.py +0 -0
  35. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/ml/schedulers/linear.py +0 -0
  36. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/utils/__init__.py +0 -0
  37. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/utils/dict_manipulations.py +0 -0
  38. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/utils/fs.py +0 -0
  39. {kostyl_toolkit-0.1.30 → kostyl_toolkit-0.1.32}/kostyl/utils/logging.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: kostyl-toolkit
- Version: 0.1.30
+ Version: 0.1.32
  Summary: Kickass Orchestration System for Training, Yielding & Logging
  Requires-Dist: case-converter>=1.2.0
  Requires-Dist: loguru>=0.7.3
kostyl/ml/clearml/pulling_utils.py
@@ -9,7 +9,7 @@ from transformers import AutoTokenizer
  from transformers import PreTrainedModel
  from transformers import PreTrainedTokenizerBase

- from kostyl.ml.lightning.extenstions.pretrained_model import (
+ from kostyl.ml.lightning.extensions.pretrained_model import (
      LightningCheckpointLoaderMixin,
  )

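For downstream code, the visible effect of this rename is the import path: the misspelled `extenstions` package is now `extensions`. A minimal before/after sketch (paths taken from the file list above):

    # kostyl-toolkit 0.1.30 (misspelled package name)
    from kostyl.ml.lightning.extenstions.pretrained_model import LightningCheckpointLoaderMixin

    # kostyl-toolkit 0.1.32
    from kostyl.ml.lightning.extensions.pretrained_model import LightningCheckpointLoaderMixin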
kostyl/ml/data_processing_utils.py (new in 0.1.32)
@@ -0,0 +1,102 @@
+ from copy import deepcopy
+ from typing import Any
+
+ import torch
+ from transformers import DataCollatorWithPadding
+ from transformers.data.data_collator import DataCollatorMixin
+
+
+ class BatchCollatorWithKeyAlignment:
+     """
+     Maps dataset keys to HuggingFace DataCollator expected keys and collates the batch.
+
+     HuggingFace collators expect specific keys depending on the collator type:
+     - `DataCollatorWithPadding`: "input_ids", "attention_mask", "token_type_ids" (optional).
+     - `DataCollatorForLanguageModeling`: "input_ids", "attention_mask", "special_tokens_mask" (optional).
+     - `DataCollatorForSeq2Seq`: "input_ids", "attention_mask", "labels".
+     - `DataCollatorForTokenClassification`: "input_ids", "attention_mask", "labels".
+
+     This wrapper allows you to map arbitrary dataset keys to these expected names before collation,
+     optionally truncating sequences to a maximum length.
+     """
+
+     def __init__(
+         self,
+         collator: DataCollatorWithPadding | DataCollatorMixin,
+         keys_mapping: dict[str, str] | None = None,
+         keys_to_keep: set[str] | None = None,
+         max_length: int | None = None,
+     ) -> None:
+         """
+         Initialize the BatchCollatorWithKeyAlignment.
+
+         Args:
+             collator: A callable (usually a Hugging Face DataCollator) that takes a list
+                 of dictionaries and returns a collated batch (e.g., padded tensors).
+             keys_mapping: A dictionary mapping original keys to new keys.
+             keys_to_keep: A set of keys to retain as-is from the original items.
+             max_length: If provided, truncates "input_ids" and "attention_mask" to this length.
+
+         Raises:
+             ValueError: If both `keys_mapping` and `keys_to_keep` are None.
+
+         """
+         if (keys_mapping is None) and (keys_to_keep is None):
+             raise ValueError("Either keys_mapping or keys_to_keep must be provided.")
+
+         if keys_mapping is None:
+             keys_mapping = {}
+         if keys_to_keep is None:
+             keys_to_keep = set()
+
+         self.collator = collator
+         self.keys_mapping = deepcopy(keys_mapping)
+         self.max_length = max_length
+
+         keys_to_keep_mapping = {v: v for v in keys_to_keep}
+         self.keys_mapping.update(keys_to_keep_mapping)
+
+     def _truncate_data(self, key: str, value: Any) -> Any:
+         match value:
+             case torch.Tensor():
+                 if value.dim() > 2:
+                     raise ValueError(
+                         f"Expected value with dim <= 2 for key {key}, got {value.dim()}"
+                     )
+             case list():
+                 if isinstance(value[0], list):
+                     raise ValueError(
+                         f"Expected value with dim <= 2 for key {key}, got nested lists"
+                     )
+         value = value[: self.max_length]
+         return value
+
+     def __call__(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
+         """
+         Align keys and collate the batch.
+
+         Args:
+             batch: A list of dictionaries representing the data batch.
+
+         Returns:
+             The collated batch returned by the underlying collator.
+
+         """
+         aligned_batch = []
+         for item in batch:
+             new_item = {}
+             for k in item.keys():
+                 new_key = self.keys_mapping.get(k, None)
+                 if new_key is None:
+                     continue
+                 value = item[k]
+                 if self.max_length is not None and new_key in (
+                     "input_ids",
+                     "attention_mask",
+                 ):
+                     value = self._truncate_data(new_key, value)
+                 new_item[new_key] = value
+             aligned_batch.append(new_item)
+
+         collated_batch = self.collator(aligned_batch)
+         return collated_batch
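A hypothetical usage sketch for the new collator wrapper, based only on the signature and docstring above; the dataset key names, tokenizer checkpoint, and DataLoader wiring are illustrative assumptions, not part of the package:

    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer, DataCollatorWithPadding

    from kostyl.ml.data_processing_utils import BatchCollatorWithKeyAlignment

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # hypothetical checkpoint

    # Dataset items use "tokens"/"mask" instead of the HF-expected names.
    collate_fn = BatchCollatorWithKeyAlignment(
        collator=DataCollatorWithPadding(tokenizer=tokenizer),
        keys_mapping={"tokens": "input_ids", "mask": "attention_mask"},
        keys_to_keep={"labels"},
        max_length=512,  # truncate input_ids / attention_mask before padding
    )

    loader = DataLoader(my_dataset, batch_size=8, collate_fn=collate_fn)  # my_dataset is assumed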
kostyl_toolkit-0.1.32/kostyl/ml/lightning/__init__.py
@@ -0,0 +1,5 @@
+ from .extensions import KostylLightningModule
+ from .extensions import LightningCheckpointLoaderMixin
+
+
+ __all__ = ["KostylLightningModule", "LightningCheckpointLoaderMixin"]
kostyl/ml/lightning/callbacks/__init__.py
@@ -1,10 +1,8 @@
  from .checkpoint import setup_checkpoint_callback
  from .early_stopping import setup_early_stopping_callback
- from .registry_uploader import ClearMLRegistryUploaderCallback


  __all__ = [
-     "ClearMLRegistryUploaderCallback",
      "setup_checkpoint_callback",
      "setup_early_stopping_callback",
  ]
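`ClearMLRegistryUploaderCallback` is no longer re-exported from the callbacks package; per the file list (and the import in the new training_utils.py below), it now lives in `kostyl.ml.registry_uploader`. A sketch of the import change downstream code would need:

    # 0.1.30
    from kostyl.ml.lightning.callbacks import ClearMLRegistryUploaderCallback

    # 0.1.32
    from kostyl.ml.registry_uploader import ClearMLRegistryUploaderCallback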
kostyl/ml/lightning/callbacks/checkpoint.py
@@ -12,10 +12,9 @@ from lightning.pytorch.callbacks import ModelCheckpoint
  from kostyl.ml.configs import CheckpointConfig
  from kostyl.ml.dist_utils import is_main_process
  from kostyl.ml.lightning import KostylLightningModule
+ from kostyl.ml.registry_uploader import RegistryUploaderCallback
  from kostyl.utils import setup_logger

- from .registry_uploader import RegistryUploaderCallback
-

  logger = setup_logger("callbacks/checkpoint.py")

kostyl/ml/lightning/extenstions/custom_module.py → kostyl/ml/lightning/extensions/custom_module.py
@@ -20,12 +20,17 @@ from kostyl.ml.schedulers.base import BaseScheduler
  from kostyl.utils import setup_logger


- logger = setup_logger(fmt="only_message")
+ module_logger = setup_logger(fmt="only_message")


  class KostylLightningModule(L.LightningModule):
      """Custom PyTorch Lightning Module with logging, checkpointing, and distributed training utilities."""

+     @property
+     def process_group(self) -> ProcessGroup | None:
+         """Returns the data parallel process group for distributed training."""
+         return self.get_process_group()
+
      def get_process_group(self) -> ProcessGroup | None:
          """
          Retrieves the data parallel process group for distributed training.
@@ -45,7 +50,7 @@ class KostylLightningModule(L.LightningModule):
          if self.device_mesh is not None:
              dp_mesh = self.device_mesh["data_parallel"]
              if dp_mesh.size() == 1:
-                 logger.warning("Data parallel mesh size is 1, returning None")
+                 module_logger.warning("Data parallel mesh size is 1, returning None")
                  return None
              dp_pg = dp_mesh.get_group()
          else:
@@ -129,11 +134,16 @@ class KostylLightningModule(L.LightningModule):
          stage: str | None = None,
      ) -> None:
          if stage is not None:
-             dictionary = apply_suffix(
-                 metrics=dictionary,
-                 suffix=stage,
-                 add_dist_rank=False,
-             )
+             if not isinstance(dictionary, MetricCollection):
+                 dictionary = apply_suffix(
+                     metrics=dictionary,
+                     suffix=stage,
+                     add_dist_rank=False,
+                 )
+             else:
+                 module_logger.warning_once(
+                     "Stage suffixing for MetricCollection is not implemented. Skipping suffixing."
+                 )
          super().log_dict(
              dictionary,
              prog_bar,
@@ -161,9 +171,12 @@ class KostylLightningModule(L.LightningModule):
          """
          scheduler: BaseScheduler = self.lr_schedulers()  # type: ignore
          if not isinstance(scheduler, BaseScheduler):
+             module_logger.warning_once(
+                 "Scheduler is not an instance of BaseScheduler. Skipping scheduled values logging."
+             )
              return
          scheduler_state_dict = scheduler.current_value()
-         scheduler_state_dict = apply_suffix(scheduler_state_dict, "scheduler")
+         scheduler_state_dict = apply_suffix(scheduler_state_dict, "scheduled")
          self.log_dict(
              scheduler_state_dict,
              prog_bar=False,
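A rough sketch of how a subclass might exercise the two additions above: the new `process_group` property and the MetricCollection-aware `stage` handling in `log_dict`. The subclass, its helper method, and the metric names are illustrative assumptions; only the property and the `stage` parameter come from the diff.

    from kostyl.ml.lightning import KostylLightningModule

    class MyModule(KostylLightningModule):  # hypothetical subclass
        def validation_step(self, batch, batch_idx):
            loss = self.compute_loss(batch)  # assumed helper
            # Plain dicts still get the stage suffix; a MetricCollection is now
            # passed through unchanged, with a one-time warning.
            self.log_dict({"loss": loss}, stage="val")

        def on_train_start(self):
            # Convenience property added in 0.1.32: proxies get_process_group(),
            # returning None when there is no (or a size-1) data-parallel group.
            self.print(f"data-parallel group: {self.process_group}")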
kostyl/ml/lightning/training_utils.py (new in 0.1.32)
@@ -0,0 +1,241 @@
+ from dataclasses import dataclass
+ from dataclasses import fields
+ from pathlib import Path
+ from typing import Literal
+ from typing import cast
+
+ import lightning as L
+ import torch
+ import torch.distributed as dist
+ from clearml import OutputModel
+ from clearml import Task
+ from lightning.pytorch.callbacks import Callback
+ from lightning.pytorch.callbacks import EarlyStopping
+ from lightning.pytorch.callbacks import LearningRateMonitor
+ from lightning.pytorch.callbacks import ModelCheckpoint
+ from lightning.pytorch.loggers import TensorBoardLogger
+ from lightning.pytorch.strategies import DDPStrategy
+ from lightning.pytorch.strategies import FSDPStrategy
+ from torch.distributed import ProcessGroup
+ from torch.distributed.fsdp import MixedPrecision
+ from torch.nn import Module
+
+ from kostyl.ml.configs import CheckpointConfig
+ from kostyl.ml.configs import DDPStrategyConfig
+ from kostyl.ml.configs import EarlyStoppingConfig
+ from kostyl.ml.configs import FSDP1StrategyConfig
+ from kostyl.ml.configs import SingleDeviceStrategyConfig
+ from kostyl.ml.lightning.callbacks import setup_checkpoint_callback
+ from kostyl.ml.lightning.callbacks import setup_early_stopping_callback
+ from kostyl.ml.lightning.loggers import setup_tb_logger
+ from kostyl.ml.registry_uploader import ClearMLRegistryUploaderCallback
+ from kostyl.utils.logging import setup_logger
+
+
+ TRAINING_STRATEGIES = (
+     FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
+ )
+
+ logger = setup_logger(add_rank=True)
+
+
+ def estimate_total_steps(
+     trainer: L.Trainer, process_group: ProcessGroup | None = None
+ ) -> int:
+     """
+     Estimates the total number of training steps based on the
+     dataloader length, accumulation steps, and distributed world size.
+     """  # noqa: D205
+     if dist.is_initialized():
+         world_size = dist.get_world_size(process_group)
+     else:
+         world_size = 1
+
+     datamodule = trainer.datamodule  # type: ignore
+     if datamodule is None:
+         raise ValueError("Trainer must have a datamodule to estimate total steps.")
+     datamodule = cast(L.LightningDataModule, datamodule)
+
+     logger.info("Loading `train_dataloader` to estimate number of stepping batches.")
+     datamodule.setup("fit")
+
+     dataloader_len = len(datamodule.train_dataloader())
+     steps_per_epoch = dataloader_len // trainer.accumulate_grad_batches // world_size
+
+     if trainer.max_epochs is None:
+         raise ValueError("Trainer must have `max_epochs` set to estimate total steps.")
+     total_steps = steps_per_epoch * trainer.max_epochs
+
+     logger.info(
+         f"Total steps: {total_steps} (per-epoch: {steps_per_epoch})\n"
+         f"-> Dataloader len: {dataloader_len}\n"
+         f"-> Accumulate grad batches: {trainer.accumulate_grad_batches}\n"
+         f"-> Epochs: {trainer.max_epochs}\n "
+         f"-> World size: {world_size}"
+     )
+     return total_steps
+
+
+ @dataclass
+ class Callbacks:
+     """Dataclass to hold PyTorch Lightning callbacks."""
+
+     checkpoint: ModelCheckpoint
+     lr_monitor: LearningRateMonitor
+     early_stopping: EarlyStopping | None = None
+
+     def to_list(self) -> list[Callback]:
+         """Convert dataclass fields to a list of Callbacks. None values are omitted."""
+         callbacks: list[Callback] = [
+             getattr(self, field.name)
+             for field in fields(self)
+             if getattr(self, field.name) is not None
+         ]
+         return callbacks
+
+
+ def setup_callbacks(
+     task: Task,
+     root_path: Path,
+     checkpoint_cfg: CheckpointConfig,
+     uploading_mode: Literal["only-best", "every-checkpoint"],
+     output_model: OutputModel,
+     early_stopping_cfg: EarlyStoppingConfig | None = None,
+     config_dict: dict[str, str] | None = None,
+     enable_tag_versioning: bool = False,
+ ) -> Callbacks:
+     """
+     Set up PyTorch Lightning callbacks for training.
+
+     Creates and configures a set of callbacks including checkpoint saving,
+     learning rate monitoring, model registry uploading, and optional early stopping.
+
+     Args:
+         task: ClearML task for organizing checkpoints by task name and ID.
+         root_path: Root directory for saving checkpoints.
+         checkpoint_cfg: Configuration for checkpoint saving behavior.
+         uploading_mode: Model upload strategy:
+             - `"only-best"`: Upload only the best checkpoint based on monitored metric.
+             - `"every-checkpoint"`: Upload every saved checkpoint.
+         output_model: ClearML OutputModel instance for model registry integration.
+         early_stopping_cfg: Configuration for early stopping. If None, early stopping
+             is disabled.
+         config_dict: Optional configuration dictionary to store with the model
+             in the registry.
+         enable_tag_versioning: Whether to auto-increment version tags (e.g., "v1.0")
+             on the uploaded model.
+
+     Returns:
+         Callbacks dataclass containing configured checkpoint, lr_monitor,
+         and optionally early_stopping callbacks.
+
+     """
+     lr_monitor = LearningRateMonitor(
+         logging_interval="step", log_weight_decay=True, log_momentum=False
+     )
+     model_uploader = ClearMLRegistryUploaderCallback(
+         output_model=output_model,
+         config_dict=config_dict,
+         verbose=True,
+         enable_tag_versioning=enable_tag_versioning,
+     )
+     checkpoint_callback = setup_checkpoint_callback(
+         root_path / "checkpoints" / task.name / task.id,
+         checkpoint_cfg,
+         registry_uploader_callback=model_uploader,
+         uploading_mode=uploading_mode,
+     )
+     if early_stopping_cfg is not None:
+         early_stopping_callback = setup_early_stopping_callback(early_stopping_cfg)
+     else:
+         early_stopping_callback = None
+
+     callbacks = Callbacks(
+         checkpoint=checkpoint_callback,
+         lr_monitor=lr_monitor,
+         early_stopping=early_stopping_callback,
+     )
+     return callbacks
+
+
+ def setup_loggers(task: Task, root_path: Path) -> list[TensorBoardLogger]:
+     """
+     Set up PyTorch Lightning loggers for training.
+
+     Args:
+         task: ClearML task used to organize log directories by task name and ID.
+         root_path: Root directory for storing TensorBoard logs.
+
+     Returns:
+         List of configured TensorBoard loggers.
+
+     """
+     loggers = [
+         setup_tb_logger(root_path / "runs" / task.name / task.id),
+     ]
+     return loggers
+
+
+ def setup_strategy(
+     strategy_settings: TRAINING_STRATEGIES,
+     devices: list[int] | int,
+     auto_wrap_policy: set[type[Module]] | None = None,
+ ) -> Literal["auto"] | FSDPStrategy | DDPStrategy:
+     """
+     Configure and return a PyTorch Lightning training strategy.
+
+     Args:
+         strategy_settings: Strategy configuration object. Must be one of:
+             - `FSDP1StrategyConfig`: Fully Sharded Data Parallel strategy (requires 2+ devices).
+             - `DDPStrategyConfig`: Distributed Data Parallel strategy (requires 2+ devices).
+             - `SingleDeviceStrategyConfig`: Single device training (requires exactly 1 device).
+         devices: Device(s) to use for training. Either a list of device IDs or
+             a single integer representing the number of devices.
+         auto_wrap_policy: Set of module types that should be wrapped for FSDP.
+             Required when using `FSDP1StrategyConfig`, ignored otherwise.
+
+     Returns:
+         Configured strategy: `FSDPStrategy`, `DDPStrategy`, or `"auto"` for single device.
+
+     Raises:
+         ValueError: If device count doesn't match strategy requirements or
+             if `auto_wrap_policy` is missing for FSDP.
+
+     """
+     if isinstance(devices, list):
+         num_devices = len(devices)
+     else:
+         num_devices = devices
+
+     match strategy_settings:
+         case FSDP1StrategyConfig():
+             if num_devices < 2:
+                 raise ValueError("FSDP strategy requires multiple devices.")
+
+             if auto_wrap_policy is None:
+                 raise ValueError("auto_wrap_policy must be provided for FSDP strategy.")
+
+             mixed_precision_config = MixedPrecision(
+                 param_dtype=getattr(torch, strategy_settings.param_dtype),
+                 reduce_dtype=getattr(torch, strategy_settings.reduce_dtype),
+                 buffer_dtype=getattr(torch, strategy_settings.buffer_dtype),
+             )
+             strategy = FSDPStrategy(
+                 auto_wrap_policy=auto_wrap_policy,
+                 mixed_precision=mixed_precision_config,
+             )
+         case DDPStrategyConfig():
+             if num_devices < 2:
+                 raise ValueError("DDP strategy requires at least two devices.")
+             strategy = DDPStrategy(
+                 find_unused_parameters=strategy_settings.find_unused_parameters
+             )
+         case SingleDeviceStrategyConfig():
+             if num_devices != 1:
+                 raise ValueError("SingleDevice strategy requires exactly one device.")
+             strategy = "auto"
+         case _:
+             raise ValueError(
+                 f"Unsupported strategy type: {type(strategy_settings.trainer.strategy)}"
+             )
+     return strategy
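A hypothetical end-to-end wiring of the new helpers, inferred only from the signatures above; the ClearML project/task names, paths, and config objects are placeholders, and the assumption that the config classes are default-constructible is not documented in this diff:

    from pathlib import Path

    import lightning as L
    from clearml import OutputModel, Task

    from kostyl.ml.configs import CheckpointConfig, DDPStrategyConfig
    from kostyl.ml.lightning.training_utils import setup_callbacks, setup_loggers, setup_strategy

    task = Task.init(project_name="demo", task_name="train-run")  # placeholder names
    output_model = OutputModel(task=task, name="demo-model", framework="PyTorch")

    root = Path("/tmp/experiments")  # placeholder directory
    callbacks = setup_callbacks(
        task=task,
        root_path=root,
        checkpoint_cfg=CheckpointConfig(),  # assumed to be default-constructible
        uploading_mode="only-best",
        output_model=output_model,
    )
    strategy = setup_strategy(DDPStrategyConfig(), devices=2)  # assumed defaults

    trainer = L.Trainer(
        devices=2,
        strategy=strategy,
        callbacks=callbacks.to_list(),
        logger=setup_loggers(task, root),
    )
    # estimate_total_steps(trainer) can be used to size LR schedules once a
    # datamodule is attached to the trainer (it raises otherwise).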
kostyl/ml/metrics_formatting.py
@@ -3,14 +3,13 @@ from collections.abc import Mapping
  import torch.distributed as dist
  from torch import Tensor
  from torchmetrics import Metric
- from torchmetrics import MetricCollection


  def apply_suffix(
-     metrics: Mapping[str, Metric | Tensor | int | float] | MetricCollection,
+     metrics: Mapping[str, Metric | Tensor | int | float],
      suffix: str,
      add_dist_rank: bool = False,
- ) -> Mapping[str, Metric | Tensor | int | float] | MetricCollection:
+ ) -> Mapping[str, Metric | Tensor | int | float]:
      """Add stage prefix to metric names."""
      new_metrics_dict = {}
      for key, value in metrics.items():
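The narrowed signature above means `apply_suffix` now accepts only plain mappings; a MetricCollection has to be handled by the caller, as `KostylLightningModule.log_dict` now does. A minimal call sketch, without assuming the exact key format the function produces:

    from kostyl.ml.metrics_formatting import apply_suffix

    metrics = {"loss": 0.42, "accuracy": 0.91}
    # Returns a new mapping whose keys carry the stage suffix (the exact key
    # formatting is defined inside apply_suffix and not shown in this diff).
    val_metrics = apply_suffix(metrics, suffix="val", add_dist_rank=False)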
kostyl/ml/lightning/callbacks/registry_uploader.py → kostyl/ml/registry_uploader.py
@@ -5,7 +5,6 @@ from pathlib import Path
  from typing import override

  from clearml import OutputModel
- from clearml import Task

  from kostyl.ml.clearml.logging_utils import find_version_in_tags
  from kostyl.ml.clearml.logging_utils import increment_version
@@ -29,69 +28,50 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):

      def __init__(
          self,
-         task: Task,
-         output_model_name: str,
-         output_model_tags: list[str] | None = None,
-         verbose: bool = True,
-         enable_tag_versioning: bool = True,
-         label_enumeration: dict[str, int] | None = None,
+         output_model: OutputModel,
          config_dict: dict[str, str] | None = None,
+         verbose: bool = True,
+         enable_tag_versioning: bool = False,
      ) -> None:
          """
          Initializes the ClearMLRegistryUploaderCallback.

          Args:
-             task: ClearML task.
-             ckpt_callback: ModelCheckpoint instance used by Trainer.
-             output_model_name: Name for the ClearML output model.
-             output_model_tags: Tags for the output model.
-             verbose: Whether to log messages.
-             label_enumeration: Optional mapping of label names to integer IDs.
+             output_model: ClearML OutputModel instance representing the model to upload.
+             verbose: Whether to log messages during upload.
              config_dict: Optional configuration dictionary to associate with the model.
              enable_tag_versioning: Whether to enable versioning in tags. If True,
                  the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".

          """
          super().__init__()
-         if output_model_tags is None:
-             output_model_tags = []
-
-         self.task = task
-         self.output_model_name = output_model_name
-         self.output_model_tags = output_model_tags
+         self.output_model = output_model
          self.config_dict = config_dict
-         self.label_enumeration = label_enumeration
          self.verbose = verbose
          self.enable_tag_versioning = enable_tag_versioning

          self.best_model_path: str = ""

-         self._output_model: OutputModel | None = None
          self._last_uploaded_model_path: str = ""
          self._upload_callback: Callable | None = None
+
+         self._validate_tags()
          return

-     def _create_output_model(self) -> OutputModel:
+     def _validate_tags(self) -> None:
+         output_model_tags = self.output_model.tags or []
          if self.enable_tag_versioning:
-             version = find_version_in_tags(self.output_model_tags)
+             version = find_version_in_tags(output_model_tags)
              if version is None:
-                 self.output_model_tags.append("v1.0")
+                 output_model_tags.append("v1.0")
              else:
                  new_version = increment_version(version)
-                 self.output_model_tags.remove(version)
-                 self.output_model_tags.append(new_version)
-
-         if "LightningCheckpoint" not in self.output_model_tags:
-             self.output_model_tags.append("LightningCheckpoint")
-
-         return OutputModel(
-             task=self.task,
-             name=self.output_model_name,
-             framework="PyTorch",
-             tags=self.output_model_tags,
-             config_dict=None,
-             label_enumeration=self.label_enumeration,
-         )
+                 output_model_tags.remove(version)
+                 output_model_tags.append(new_version)
+         if "LightningCheckpoint" not in output_model_tags:
+             output_model_tags.append("LightningCheckpoint")
+         self.output_model.tags = output_model_tags
+         return None

      @override
      def upload_checkpoint(
@@ -105,18 +85,15 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
              logger.info("Model unchanged since last upload")
              return

-         if self._output_model is None:
-             self._output_model = self._create_output_model()
-
          if self.verbose:
              logger.info(f"Uploading model from {path}")

-         self._output_model.update_weights(
+         self.output_model.update_weights(
              path,
              auto_delete_file=False,
              async_enable=False,
          )
-         self._output_model.update_design(config_dict=self.config_dict)
+         self.output_model.update_design(config_dict=self.config_dict)

          self._last_uploaded_model_path = path
          return
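Since the callback no longer constructs the OutputModel lazily, callers now create one and hand it in. A hedged before/after sketch; the task, model name, and tags are placeholders, and the 0.1.30 call is reconstructed from the parameters removed above:

    from clearml import OutputModel, Task
    from kostyl.ml.registry_uploader import ClearMLRegistryUploaderCallback

    task = Task.init(project_name="demo", task_name="train-run")  # placeholders

    # 0.1.30: the callback built the OutputModel internally
    # uploader = ClearMLRegistryUploaderCallback(
    #     task=task, output_model_name="my-model", output_model_tags=["v1.0"]
    # )

    # 0.1.32: pass a ready OutputModel; tags are normalized in __init__ via _validate_tags()
    output_model = OutputModel(task=task, name="my-model", framework="PyTorch", tags=["v1.0"])
    uploader = ClearMLRegistryUploaderCallback(
        output_model=output_model,
        enable_tag_versioning=True,
    )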
pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "kostyl-toolkit"
- version = "0.1.30"
+ version = "0.1.32"
  description = "Kickass Orchestration System for Training, Yielding & Logging "
  readme = "README.md"
  requires-python = ">=3.12"
kostyl_toolkit-0.1.30/kostyl/ml/lightning/__init__.py (removed)
@@ -1,5 +0,0 @@
- from .extenstions import KostylLightningModule
- from .extenstions import LightningCheckpointLoaderMixin
-
-
- __all__ = ["KostylLightningModule", "LightningCheckpointLoaderMixin"]
kostyl_toolkit-0.1.30/kostyl/ml/lightning/steps_estimation.py (removed)
@@ -1,44 +0,0 @@
- from typing import cast
-
- import lightning as L
- import torch.distributed as dist
- from torch.distributed import ProcessGroup
-
- from kostyl.utils.logging import setup_logger
-
-
- logger = setup_logger(add_rank=True)
-
-
- def estimate_total_steps(
-     trainer: L.Trainer, process_group: ProcessGroup | None = None
- ) -> int:
-     """Estimates the total number of training steps for a given PyTorch Lightning Trainer."""
-     if dist.is_initialized():
-         world_size = dist.get_world_size(process_group)
-     else:
-         world_size = 1
-
-     datamodule = trainer.datamodule  # type: ignore
-     if datamodule is None:
-         raise ValueError("Trainer must have a datamodule to estimate total steps.")
-     datamodule = cast(L.LightningDataModule, datamodule)
-
-     logger.info("Loading `train_dataloader` to estimate number of stepping batches.")
-     datamodule.setup("fit")
-
-     dataloader_len = len(datamodule.train_dataloader())
-     steps_per_epoch = dataloader_len // trainer.accumulate_grad_batches // world_size
-
-     if trainer.max_epochs is None:
-         raise ValueError("Trainer must have `max_epochs` set to estimate total steps.")
-     total_steps = steps_per_epoch * trainer.max_epochs
-
-     logger.info(
-         f"Total steps: {total_steps} (per-epoch: {steps_per_epoch})\n"
-         f"-> Dataloader len: {dataloader_len}\n"
-         f"-> Accumulate grad batches: {trainer.accumulate_grad_batches}\n"
-         f"-> Epochs: {trainer.max_epochs}\n "
-         f"-> World size: {world_size}"
-     )
-     return total_steps
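Apart from a reworded docstring, the body of `estimate_total_steps` is unchanged; it moved into the new training_utils module shown in the `@@ -0,0 +1,241 @@` hunk above, so the only downstream change should be the import path:

    # 0.1.30
    from kostyl.ml.lightning.steps_estimation import estimate_total_steps

    # 0.1.32
    from kostyl.ml.lightning.training_utils import estimate_total_steps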