kostyl-toolkit 0.1.34__py3-none-any.whl → 0.1.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kostyl/ml/configs/hyperparams.py +21 -5
- kostyl/ml/configs/training_settings.py +17 -6
- kostyl/ml/lightning/callbacks/checkpoint.py +8 -8
- kostyl/ml/lightning/extensions/pretrained_model.py +27 -5
- kostyl/ml/lightning/utils.py +58 -0
- kostyl/ml/registry_uploader.py +56 -29
- kostyl/ml/schedulers/__init__.py +13 -1
- kostyl/ml/schedulers/base.py +9 -7
- kostyl/ml/schedulers/cosine.py +53 -24
- kostyl/ml/schedulers/cosine_with_plateu.py +277 -0
- kostyl/ml/schedulers/linear.py +36 -11
- kostyl/utils/logging.py +1 -1
- {kostyl_toolkit-0.1.34.dist-info → kostyl_toolkit-0.1.36.dist-info}/METADATA +1 -1
- {kostyl_toolkit-0.1.34.dist-info → kostyl_toolkit-0.1.36.dist-info}/RECORD +15 -14
- {kostyl_toolkit-0.1.34.dist-info → kostyl_toolkit-0.1.36.dist-info}/WHEEL +1 -1
- kostyl/ml/lightning/training_utils.py +0 -241
kostyl/ml/configs/hyperparams.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import Literal
+
 from pydantic import BaseModel
 from pydantic import Field
 from pydantic import model_validator
@@ -8,11 +10,25 @@ from kostyl.utils.logging import setup_logger
 logger = setup_logger(fmt="only_message")


-class
-"""
+class AdamConfig(BaseModel):
+    """AdamW optimizer hyperparameters configuration."""
+
+    type: Literal["AdamW"] = "AdamW"
+    betas: tuple[float, float] = (0.9, 0.999)
+    is_adamw: bool = True
+
+
+class AdamWithPrecisionConfig(BaseModel):
+    """Adam optimizer with low-precision hyperparameters configuration."""
+
+    type: Literal["Adam8bit", "Adam4bit", "AdamFp8"]
+    betas: tuple[float, float] = (0.9, 0.999)
+    block_size: int
+    bf16_stochastic_round: bool = False
+    is_adamw: bool = True
+

-
-    adamw_beta2: float = 0.999
+Optimizer = AdamConfig | AdamWithPrecisionConfig


 class Lr(BaseModel):
@@ -73,6 +89,6 @@ class HyperparamsConfig(BaseModel):
     """Model training hyperparameters configuration."""

     grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
-    optimizer: Optimizer
+    optimizer: Optimizer
     lr: Lr
     weight_decay: WeightDecay
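The single optimizer hyperparameter block is replaced by a tagged union of two pydantic models. A minimal usage sketch (illustrative, not from the package; assumes pydantic v2 and the defaults declared above):

    adam = AdamConfig()  # type="AdamW", betas=(0.9, 0.999), is_adamw=True
    adam_8bit = AdamWithPrecisionConfig(type="Adam8bit", block_size=256)
    # HyperparamsConfig.optimizer accepts either member of the Optimizer union.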
kostyl/ml/configs/training_settings.py
CHANGED
@@ -25,21 +25,31 @@ PRECISION = Literal[
     "16",
     "bf16",
 ]
+DTYPE = Literal["float32", "float16", "bfloat16", "float64"]
+
+
+class SingleDeviceStrategyConfig(BaseModel):
+    """Single device strategy configuration."""
+
+    type: Literal["single_device"]


 class FSDP1StrategyConfig(BaseModel):
     """Fully Sharded Data Parallel (FSDP) strategy configuration."""

     type: Literal["fsdp1"]
-    param_dtype:
-    reduce_dtype:
-    buffer_dtype:
+    param_dtype: DTYPE | None
+    reduce_dtype: DTYPE | None
+    buffer_dtype: DTYPE | None


-class
-"""
+class FSDP2StrategyConfig(BaseModel):
+    """Fully Sharded Data Parallel (FSDP) strategy configuration."""

-    type: Literal["
+    type: Literal["fsdp2"]
+    param_dtype: DTYPE | None
+    reduce_dtype: DTYPE | None
+    buffer_dtype: DTYPE | None


 class DDPStrategyConfig(BaseModel):
@@ -82,6 +92,7 @@ class CheckpointConfig(BaseModel):
     monitor: str = "val_loss"
     mode: str = "min"
     filename: str = "{epoch:02d}-{val_loss:.2f}"
+    save_weights_only: bool = True


 class DataConfig(BaseModel):
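The new DTYPE literals are plain strings; mapping them to torch dtypes is left to the caller. A hypothetical helper (the `to_mixed_precision` name is ours), mirroring what the removed training_utils.setup_strategy used to do:

    import torch
    from torch.distributed.fsdp import MixedPrecision

    def to_mixed_precision(cfg: FSDP1StrategyConfig) -> MixedPrecision:
        # Map "bfloat16" -> torch.bfloat16 etc.; None is passed through unchanged.
        def resolve(name: str | None) -> torch.dtype | None:
            return getattr(torch, name) if name is not None else None

        return MixedPrecision(
            param_dtype=resolve(cfg.param_dtype),
            reduce_dtype=resolve(cfg.reduce_dtype),
            buffer_dtype=resolve(cfg.buffer_dtype),
        )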
kostyl/ml/lightning/callbacks/checkpoint.py
CHANGED
@@ -299,9 +299,9 @@ class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
 def setup_checkpoint_callback(
     dirpath: Path,
     ckpt_cfg: CheckpointConfig,
-    save_weights_only: bool = True,
     registry_uploader_callback: RegistryUploaderCallback | None = None,
     uploading_strategy: Literal["only-best", "every-checkpoint"] | None = None,
+    remove_folder_if_exists: bool = True,
 ) -> ModelCheckpointWithRegistryUploader | ModelCheckpoint:
     """
     Create and configure a checkpoint callback for model saving.
@@ -313,14 +313,13 @@ def setup_checkpoint_callback(
     Args:
         dirpath: Path to the directory for saving checkpoints.
         ckpt_cfg: Checkpoint configuration (filename, monitor, mode, save_top_k).
-        save_weights_only: If True, only model weights are saved without optimizer and lr-scheduler state.
-            Defaults to True.
         registry_uploader_callback: Optional callback for uploading checkpoints to a remote registry.
             Must be specified together with uploading_strategy.
         uploading_strategy: Checkpoint upload mode:
            - "only-best": only the best checkpoint is uploaded
            - "every-checkpoint": every saved checkpoint is uploaded
            Must be specified together with registry_uploader_callback.
+        remove_folder_if_exists: If True, removes existing checkpoint directory before creating a new one.

     Returns:
         ModelCheckpointWithRegistryUploader if registry_uploader_callback is provided,
@@ -331,7 +330,7 @@ def setup_checkpoint_callback(

     Note:
         If the dirpath directory already exists, it will be removed and recreated
-        (only on the main process in distributed training).
+        (only on the main process in distributed training) if remove_folder_if_exists is True.

     """
     if (registry_uploader_callback is None) != (uploading_strategy is None):
@@ -342,8 +341,9 @@ def setup_checkpoint_callback(
     if dirpath.exists():
         if is_main_process():
             logger.warning(f"Checkpoint directory {dirpath} already exists.")
-
-
+            if remove_folder_if_exists:
+                rmtree(dirpath)
+                logger.warning(f"Removed existing checkpoint directory {dirpath}.")
     else:
         logger.info(f"Creating checkpoint directory {dirpath}.")
         dirpath.mkdir(parents=True, exist_ok=True)
@@ -356,7 +356,7 @@ def setup_checkpoint_callback(
             monitor=ckpt_cfg.monitor,
             mode=ckpt_cfg.mode,
             verbose=True,
-            save_weights_only=save_weights_only,
+            save_weights_only=ckpt_cfg.save_weights_only,
             registry_uploader_callback=registry_uploader_callback,
             uploading_mode=uploading_strategy,
         )
@@ -368,6 +368,6 @@ def setup_checkpoint_callback(
             monitor=ckpt_cfg.monitor,
             mode=ckpt_cfg.mode,
             verbose=True,
-            save_weights_only=save_weights_only,
+            save_weights_only=ckpt_cfg.save_weights_only,
         )
     return checkpoint_callback
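save_weights_only now comes from CheckpointConfig rather than being a separate argument, and the new remove_folder_if_exists flag controls whether an existing directory is wiped. A hedged usage sketch (assumes the remaining CheckpointConfig fields have defaults):

    from pathlib import Path

    ckpt_cfg = CheckpointConfig(save_weights_only=False)  # also keep optimizer/lr-scheduler state
    callback = setup_checkpoint_callback(
        dirpath=Path("checkpoints/run-1"),
        ckpt_cfg=ckpt_cfg,
        remove_folder_if_exists=False,  # warn about an existing folder instead of removing it
    )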
kostyl/ml/lightning/extensions/pretrained_model.py
CHANGED
@@ -20,7 +20,8 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
         cls: type[TModelInstance],
         checkpoint_path: str | Path,
         config_key: str = "config",
-        weights_prefix: str = "model.",
+        weights_prefix: str | None = "model.",
+        strict_prefix: bool = False,
         **kwargs: Any,
     ) -> TModelInstance:
         """
@@ -39,8 +40,10 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
             checkpoint_path (str | Path): Path to the checkpoint file. Must be a .ckpt file.
             config_key (str, optional): Key in the checkpoint dictionary where the config is stored.
                 Defaults to "config".
-            weights_prefix (str, optional): Prefix to strip from state dict keys. Defaults to "model.".
-                If not empty and doesn't end with ".", a "." is appended.
+            weights_prefix (str | None, optional): Prefix to strip from state dict keys. Defaults to "model.".
+                If not empty and doesn't end with ".", a "." is appended. If empty or None, no prefix stripping will be skipped.
+            strict_prefix (bool, optional): If True, drop tensors those keys that do not start with the
+                specified prefix. Defaults to False.
             kwargs: Additional keyword arguments to pass to the model's `from_pretrained` method.

             Returns:
@@ -53,6 +56,13 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
         """
         if isinstance(checkpoint_path, str):
             checkpoint_path = Path(checkpoint_path)
+        if weights_prefix is None:
+            weights_prefix = ""
+        weights_prefix = cast(str, weights_prefix)
+        if weights_prefix == "" and strict_prefix:
+            logger.warning(
+                "strict_prefix=True has no effect when weights_prefix is empty or None."
+            )

         if checkpoint_path.is_dir():
             raise ValueError(f"{checkpoint_path} is a directory")
@@ -85,13 +95,25 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
             if not weights_prefix.endswith("."):
                 weights_prefix = weights_prefix + "."
             state_dict: dict[str, torch.Tensor] = {}
-
+            matched_keys_counter = 0
             for key, value in raw_state_dict.items():
                 if key.startswith(weights_prefix):
                     new_key = key[len(weights_prefix) :]
                     state_dict[new_key] = value
-
+                    matched_keys_counter += 1
+                elif not strict_prefix:
                     state_dict[key] = value
+
+            if matched_keys_counter == 0:
+                if strict_prefix:
+                    raise ValueError(
+                        f"No keys in the checkpoint start with the specified prefix '{weights_prefix}'. "
+                        "Try to load with `strict_prefix=False` or verify the prefix."
+                    )
+                else:
+                    logger.warning(
+                        f"No keys in the checkpoint start with the specified prefix '{weights_prefix}'. "
+                    )
         else:
             state_dict = raw_state_dict
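A self-contained toy illustration of the new strict_prefix behaviour, following the loop above (not package code):

    import torch

    raw_state_dict = {
        "model.linear.weight": torch.zeros(2, 2),
        "ema.linear.weight": torch.ones(2, 2),
    }
    weights_prefix, strict_prefix = "model.", True

    state_dict, matched = {}, 0
    for key, value in raw_state_dict.items():
        if key.startswith(weights_prefix):
            state_dict[key[len(weights_prefix):]] = value
            matched += 1
        elif not strict_prefix:
            state_dict[key] = value  # non-matching keys survive only in non-strict mode

    assert list(state_dict) == ["linear.weight"]  # "ema.*" dropped because strict_prefix=True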
kostyl/ml/lightning/utils.py
ADDED
@@ -0,0 +1,58 @@
+from typing import cast
+
+import lightning as L
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from kostyl.ml.configs import DDPStrategyConfig
+from kostyl.ml.configs import FSDP1StrategyConfig
+from kostyl.ml.configs import SingleDeviceStrategyConfig
+from kostyl.utils.logging import setup_logger
+
+
+TRAINING_STRATEGIES = (
+    FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
+)
+
+logger = setup_logger(add_rank=True)
+
+
+def estimate_total_steps(
+    trainer: L.Trainer, dp_process_group: ProcessGroup | None = None
+) -> int:
+    """
+    Estimates the total number of training steps with respect to data parallelism and gradient accumulation.
+
+    Args:
+        trainer: The PyTorch Lightning Trainer instance.
+        dp_process_group: The data parallel process group. If None, the world process group will be used.
+
+    """
+    if dist.is_initialized():
+        world_size = dist.get_world_size(dp_process_group)
+    else:
+        world_size = 1
+
+    datamodule = trainer.datamodule  # type: ignore
+    if datamodule is None:
+        raise ValueError("Trainer must have a datamodule to estimate total steps.")
+    datamodule = cast(L.LightningDataModule, datamodule)
+
+    logger.info("Loading `train_dataloader` to estimate number of stepping batches.")
+    datamodule.setup("fit")
+
+    dataloader_len = len(datamodule.train_dataloader())
+    steps_per_epoch = dataloader_len // trainer.accumulate_grad_batches // world_size
+
+    if trainer.max_epochs is None:
+        raise ValueError("Trainer must have `max_epochs` set to estimate total steps.")
+    total_steps = steps_per_epoch * trainer.max_epochs
+
+    logger.info(
+        f"Total steps: {total_steps} (per-epoch: {steps_per_epoch}) "
+        f"-> Dataloader len: {dataloader_len} "
+        f"-> Accumulate grad batches: {trainer.accumulate_grad_batches} "
+        f"-> Epochs: {trainer.max_epochs} "
+        f"-> DataParallel size: {world_size}"
+    )
+    return total_steps
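A worked example of the arithmetic estimate_total_steps performs (numbers are made up):

    # 10_000 batches per epoch, gradient accumulation of 4, 8 data-parallel ranks, 3 epochs
    steps_per_epoch = 10_000 // 4 // 8   # 312 optimizer steps per epoch
    total_steps = steps_per_epoch * 3    # 936 total steps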
kostyl/ml/registry_uploader.py
CHANGED
@@ -1,13 +1,12 @@
 from abc import ABC
 from abc import abstractmethod
 from collections.abc import Callable
+from functools import partial
 from pathlib import Path
 from typing import override

 from clearml import OutputModel

-from kostyl.ml.clearml.logging_utils import find_version_in_tags
-from kostyl.ml.clearml.logging_utils import increment_version
 from kostyl.utils.logging import setup_logger


@@ -28,51 +27,79 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):

     def __init__(
         self,
-
+        model_name: str,
         config_dict: dict[str, str] | None = None,
+        label_enumeration: dict[str, int] | None = None,
+        tags: list[str] | None = None,
+        comment: str | None = None,
+        framework: str | None = None,
+        base_model_id: str | None = None,
+        new_model_per_upload: bool = True,
         verbose: bool = True,
-        enable_tag_versioning: bool = False,
     ) -> None:
         """
         Initializes the ClearMLRegistryUploaderCallback.

         Args:
-
-
+            model_name: The name for the newly created model.
+            label_enumeration: The label enumeration dictionary of string (label) to integer (value) pairs.
             config_dict: Optional configuration dictionary to associate with the model.
-
-
+            tags: A list of strings which are tags for the model.
+            comment: A comment / description for the model.
+            framework: The framework of the model (e.g., "PyTorch", "TensorFlow").
+            base_model_id: Optional ClearML model ID to use as a base for the new model
+            new_model_per_upload: Whether to create a new ClearML model
+                for every upload or update weights of the same model. When updating weights,
+                the last uploaded checkpoint will be replaced (and deleted).
+            verbose: Whether to log messages during upload.

         """
         super().__init__()
-
-
-
-
+        if base_model_id is not None and new_model_per_upload:
+            raise ValueError(
+                "Cannot set base_model_id when new_model_per_upload is True."
+            )

+        self.verbose = verbose
+        self.new_model_per_upload = new_model_per_upload
         self.best_model_path: str = ""
-
+        self.config_dict = config_dict
+        self._output_model: OutputModel | None = None
         self._last_uploaded_model_path: str = ""
         self._upload_callback: Callable | None = None

-        self._validate_tags()
+        self._validate_tags(tags)
+        self.model_fabric = partial(
+            OutputModel,
+            name=model_name,
+            label_enumeration=label_enumeration,
+            tags=tags,
+            comment=comment,
+            framework=framework,
+            base_model_id=base_model_id,
+        )
         return

-
-
-        if
-
-
-
-        else:
-            new_version = increment_version(version)
-            output_model_tags.remove(version)
-            output_model_tags.append(new_version)
-        if "LightningCheckpoint" not in output_model_tags:
-            output_model_tags.append("LightningCheckpoint")
-        self.output_model.tags = output_model_tags
+    @staticmethod
+    def _validate_tags(tags: list[str] | None) -> None:
+        if tags is None:
+            return
+        if "LightningCheckpoint" not in tags:
+            tags.append("LightningCheckpoint")
         return None

+    @property
+    def output_model_(self) -> OutputModel:
+        """Returns the OutputModel instance based on `new_model_per_upload` setting."""
+        if self.new_model_per_upload:
+            model = self.model_fabric()
+            self._output_model = self.model_fabric()
+        else:
+            if self._output_model is None:
+                self._output_model = self.model_fabric()
+            model = self._output_model
+        return model
+
     @override
     def upload_checkpoint(
         self,
@@ -88,12 +115,12 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
         if self.verbose:
             logger.info(f"Uploading model from {path}")

-        self.
+        self.output_model_.update_weights(
             path,
             auto_delete_file=False,
             async_enable=False,
         )
-        self.
+        self.output_model_.update_design(config_dict=self.config_dict)

         self._last_uploaded_model_path = path
         return
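The callback now builds its ClearML OutputModel internally from the constructor arguments instead of receiving one. A hypothetical instantiation (argument values are illustrative):

    uploader = ClearMLRegistryUploaderCallback(
        model_name="my-model",
        tags=["v1.0"],               # "LightningCheckpoint" is appended automatically
        framework="PyTorch",
        new_model_per_upload=False,  # keep updating the weights of a single ClearML model
    )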
kostyl/ml/schedulers/__init__.py
CHANGED
@@ -1,6 +1,18 @@
 from .composite import CompositeScheduler
 from .cosine import CosineParamScheduler
 from .cosine import CosineScheduler
+from .cosine_with_plateu import CosineWithPlateauParamScheduler
+from .cosine_with_plateu import CosineWithPlateuScheduler
+from .linear import LinearParamScheduler
+from .linear import LinearScheduler


-__all__ = [
+__all__ = [
+    "CompositeScheduler",
+    "CosineParamScheduler",
+    "CosineScheduler",
+    "CosineWithPlateauParamScheduler",
+    "CosineWithPlateuScheduler",
+    "LinearParamScheduler",
+    "LinearScheduler",
+]
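With the expanded __all__, the new schedulers can be imported directly from the subpackage (the "Plateu" spelling follows the package's own identifiers):

    from kostyl.ml.schedulers import CosineWithPlateuScheduler, LinearParamScheduler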
kostyl/ml/schedulers/base.py
CHANGED
@@ -6,18 +6,20 @@ from typing import Any
 class BaseScheduler(ABC):
     """Base class for learning rate schedulers."""

+    @abstractmethod
     def state_dict(self) -> dict[str, Any]:
         """Get the state as a state dictionary."""
-
-            key: value
-            for key, value in self.__dict__.items()
-            if key not in ["optimizer", "scheduler_values"]
-        }
+        raise NotImplementedError

+    @abstractmethod
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         """Load the state from a state dictionary."""
-
-
+        raise NotImplementedError
+
+    @abstractmethod
+    def _verify(self) -> None:
+        """Verify the scheduler configuration."""
+        raise NotImplementedError

     def __getstate__(self) -> dict[str, Any]:
         """Get the state for pickling."""
kostyl/ml/schedulers/cosine.py
CHANGED
@@ -2,7 +2,6 @@ from typing import Any
 from typing import override

 import numpy as np
-import numpy.typing as npt
 import torch

 from .base import BaseScheduler
@@ -29,18 +28,24 @@ class _CosineSchedulerCore(BaseScheduler):
         if freeze_ratio is not None:
             if not (0 < freeze_ratio < 1):
                 raise ValueError(f"Freeze ratio must be in (0, 1), got {freeze_ratio}.")
+        pre_annealing_ratio = (warmup_ratio if warmup_ratio is not None else 0) + (
+            freeze_ratio if freeze_ratio is not None else 0
+        )
+        if pre_annealing_ratio > 1:
+            raise ValueError(
+                "The sum of warmup_ratio and freeze_ratio must <= 1, got "
+                f"{pre_annealing_ratio}."
+            )

         self.param_name = param_name
         self.num_iters = num_iters
         self.base_value = base_value
         self.final_value = final_value
-
         self.warmup_ratio = warmup_ratio
         self.warmup_value = warmup_value
-
         self.freeze_ratio = freeze_ratio

-        self.
+        self.scheduled_values: np.ndarray = np.array([], dtype=np.float64)
         self.current_value_ = self.base_value
         return

@@ -63,31 +68,29 @@ class _CosineSchedulerCore(BaseScheduler):
             warmup_iters = 0
             warmup_schedule = np.array([], dtype=np.float64)

+        # Create cosine annealing schedule
         cosine_annealing_iters = self.num_iters - warmup_iters - freeze_iters
-        if cosine_annealing_iters
-
-
-
-
-
-        )
+        if cosine_annealing_iters > 0:
+            iters = np.arange(cosine_annealing_iters)
+            cosine_annealing_schedule = self.final_value + 0.5 * (
+                self.base_value - self.final_value
+            ) * (1 + np.cos(np.pi * iters / len(iters)))
+        else:
+            cosine_annealing_schedule = np.array([], dtype=np.float64)

         # Concatenate all parts of the schedule
-        self.
-            (freeze_schedule, warmup_schedule,
+        self.scheduled_values = np.concatenate(
+            (freeze_schedule, warmup_schedule, cosine_annealing_schedule)
         )
-
-        if len(self.scheduler_values) != self.num_iters:
-            raise ValueError(
-                f"Scheduler length ({len(self.scheduler_values)}) does not match num_iters ({self.num_iters})."
-            )
+        self._verify()
         return

     @override
-    def
-
-
+    def _verify(self) -> None:
+        if len(self.scheduled_values) != self.num_iters:
+            raise ValueError(
+                f"Scheduler length ({len(self.scheduled_values)}) does not match num_iters ({self.num_iters})."
+            )
         return

     @override
@@ -95,13 +98,13 @@ class _CosineSchedulerCore(BaseScheduler):
         raise NotImplementedError

     def _get_value(self, it: int) -> float:
-        if len(self.
+        if len(self.scheduled_values) == 0:
             self._create_scheduler()

         if it >= self.num_iters:
             value: float = self.final_value
         else:
-            value: float = self.
+            value: float = self.scheduled_values[it]
         self.current_value_ = value
         return value

@@ -163,6 +166,21 @@ class CosineScheduler(_CosineSchedulerCore):
         self.param_group_field = param_group_field
         return

+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.__dict__.update(state_dict)
+        self.scheduled_values = np.array([], dtype=np.float64)
+        return
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        state = {
+            k: v
+            for k, v in self.__dict__.items()
+            if k not in ["scheduled_values", "optimizer"]
+        }
+        return state
+
     @override
     def step(self, it: int) -> None:
         value = self._get_value(it)
@@ -209,3 +227,14 @@ class CosineParamScheduler(_CosineSchedulerCore):
         """
         value = self._get_value(it)
         return value
+
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.__dict__.update(state_dict)
+        self.scheduled_values = np.array([], dtype=np.float64)
+        return
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        state = {k: v for k, v in self.__dict__.items() if k != "scheduled_values"}
+        return state
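A quick numeric check of the annealing formula used above, value(t) = final + 0.5 * (base - final) * (1 + cos(pi * t / T)):

    import numpy as np

    base, final, T = 1e-3, 1e-5, 100
    t = np.array([0, 50, 99])
    vals = final + 0.5 * (base - final) * (1 + np.cos(np.pi * t / T))
    # vals ≈ [1.0e-3, 5.05e-4, 1.02e-5]: starts at base_value and decays toward final_value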
kostyl/ml/schedulers/cosine_with_plateu.py
ADDED
@@ -0,0 +1,277 @@
+from typing import Any
+from typing import override
+
+import numpy as np
+import torch
+
+from .base import BaseScheduler
+
+
+class _CosineWithPlateauSchedulerCore(BaseScheduler):
+    """Core cosine with plateau scheduler logic."""
+
+    def __init__(
+        self,
+        param_name: str,
+        num_iters: int,
+        base_value: float,
+        final_value: float,
+        plateau_ratio: float,
+        warmup_value: float | None = None,
+        warmup_ratio: float | None = None,
+        freeze_ratio: float | None = None,
+    ) -> None:
+        if warmup_ratio is not None:
+            if not (0 < warmup_ratio < 1):
+                raise ValueError(f"Warmup ratio must be in (0, 1), got {warmup_ratio}.")
+        if (warmup_value is None) != (warmup_ratio is None):
+            raise ValueError(
+                "Both warmup_ratio and warmup_value must be provided or neither."
+            )
+        if freeze_ratio is not None:
+            if not (0 < freeze_ratio < 1):
+                raise ValueError(f"Freeze ratio must be in (0, 1), got {freeze_ratio}.")
+        if not (0 < plateau_ratio < 1):
+            raise ValueError(f"Plateau ratio must be in (0, 1), got {plateau_ratio}.")
+
+        pre_annealing_ratio = (
+            plateau_ratio
+            + (warmup_ratio if warmup_ratio is not None else 0)
+            + (freeze_ratio if freeze_ratio is not None else 0)
+        )
+        if pre_annealing_ratio > 1:
+            raise ValueError(
+                "The sum of plateau_ratio, warmup_ratio, and freeze_ratio must <= 1, got "
+                f"{pre_annealing_ratio}."
+            )
+
+        self.param_name = param_name
+        self.num_iters = num_iters
+        self.base_value = base_value
+        self.final_value = final_value
+        self.cosine_annealing_ratio = 1 - pre_annealing_ratio
+        self.plateau_ratio = plateau_ratio
+        self.warmup_ratio = warmup_ratio
+        self.warmup_value = warmup_value
+        self.freeze_ratio = freeze_ratio
+
+        self.scheduled_values: np.ndarray = np.array([], dtype=np.float64)
+        self.current_value_ = self.base_value
+        return
+
+    def _create_scheduler(self) -> None:
+        # Create freeze schedule
+        if self.freeze_ratio is not None:
+            freeze_iters = int(self.num_iters * self.freeze_ratio)
+            freeze_schedule = np.zeros(freeze_iters, dtype=np.float64)
+        else:
+            freeze_iters = 0
+            freeze_schedule = np.array([], dtype=np.float64)
+
+        # Create linear warmup schedule
+        if self.warmup_ratio is not None and self.warmup_value is not None:
+            warmup_iters = int(self.num_iters * self.warmup_ratio)
+            warmup_schedule = np.linspace(
+                self.warmup_value, self.base_value, warmup_iters, dtype=np.float64
+            )
+        else:
+            warmup_iters = 0
+            warmup_schedule = np.array([], dtype=np.float64)
+
+        # Create cosine annealing schedule
+        if self.cosine_annealing_ratio > 0:
+            cosine_annealing_iters = int(self.num_iters * self.cosine_annealing_ratio)
+            iters = np.arange(cosine_annealing_iters)
+            cosine_annealing_schedule = self.final_value + 0.5 * (
+                self.base_value - self.final_value
+            ) * (1 + np.cos(np.pi * iters / len(iters)))
+        else:
+            cosine_annealing_iters = 0
+            cosine_annealing_schedule = np.array([], dtype=np.float64)
+
+        plateau_iters = (
+            self.num_iters - warmup_iters - freeze_iters - cosine_annealing_iters
+        )
+        if plateau_iters > 0:
+            plateau_schedule = np.full(plateau_iters, self.base_value, dtype=np.float64)
+        else:
+            plateau_schedule = np.array([], dtype=np.float64)
+
+        # Concatenate all parts of the schedule
+        self.scheduled_values = np.concatenate(
+            (
+                freeze_schedule,
+                warmup_schedule,
+                plateau_schedule,
+                cosine_annealing_schedule,
+            )
+        )
+        self._verify()
+        return
+
+    @override
+    def _verify(self) -> None:
+        if len(self.scheduled_values) != self.num_iters:
+            raise ValueError(
+                f"Scheduler length ({len(self.scheduled_values)}) does not match num_iters ({self.num_iters})."
+            )
+        return
+
+    @override
+    def step(self, it: int) -> None | float:
+        raise NotImplementedError
+
+    def _get_value(self, it: int) -> float:
+        if len(self.scheduled_values) == 0:
+            self._create_scheduler()
+
+        if it >= self.num_iters:
+            value: float = self.final_value
+        else:
+            value: float = self.scheduled_values[it]
+        self.current_value_ = value
+        return value
+
+    @override
+    def current_value(self) -> dict[str, float]:
+        return {self.param_name: self.current_value_}
+
+
+class CosineWithPlateuScheduler(_CosineWithPlateauSchedulerCore):
+    """
+    Applies a cosine schedule with plateau to an optimizer param-group field.
+
+    Schedule phases: freeze (0) → warmup → plateau (base_value) → cosine annealing to final_value.
+    The plateau phase maintains the base_value before cosine annealing begins.
+    """
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        param_group_field: str,
+        num_iters: int,
+        base_value: float,
+        final_value: float,
+        plateau_ratio: float,
+        warmup_value: float | None = None,
+        warmup_ratio: float | None = None,
+        freeze_ratio: float | None = None,
+        multiplier_field: str | None = None,
+        skip_if_zero: bool = False,
+        apply_if_field: str | None = None,
+        ignore_if_field: str | None = None,
+    ) -> None:
+        """
+        Configure cosine scheduling for matching optimizer groups.
+
+        Args:
+            optimizer: Optimizer whose param groups are updated in-place.
+            param_group_field: Name of the field that receives the scheduled value.
+            num_iters: Number of scheduler iterations before clamping at ``final_value``.
+            base_value: Value maintained during plateau phase and used as cosine start.
+            final_value: Value approached as iterations progress during cosine annealing.
+            plateau_ratio: Fraction of iterations to maintain ``base_value`` before cosine annealing.
+            warmup_ratio: Optional fraction of iterations to linearly ramp from ``warmup_value`` to ``base_value``.
+            warmup_value: Starting value for the warmup ramp.
+            freeze_ratio: Optional fraction of iterations to keep the value frozen at zero at the beginning.
+            multiplier_field: Optional per-group multiplier applied to the scheduled value.
+            skip_if_zero: Leave groups untouched when their target field equals zero.
+            apply_if_field: Require this flag to be present in a param group before updating.
+            ignore_if_field: Skip groups that declare this flag.
+
+        """
+        self.apply_if_field = apply_if_field
+        self.ignore_if_field = ignore_if_field
+        self.optimizer = optimizer
+        self.multiplier_field = multiplier_field
+        self.skip_if_zero = skip_if_zero
+        super().__init__(
+            param_name=param_group_field,
+            num_iters=num_iters,
+            base_value=base_value,
+            final_value=final_value,
+            plateau_ratio=plateau_ratio,
+            warmup_ratio=warmup_ratio,
+            warmup_value=warmup_value,
+            freeze_ratio=freeze_ratio,
+        )
+        self.param_group_field = param_group_field
+        return
+
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.__dict__.update(state_dict)
+        self.scheduled_values = np.array([], dtype=np.float64)
+        return
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        state = {
+            k: v
+            for k, v in self.__dict__.items()
+            if k not in ["scheduled_values", "optimizer"]
+        }
+        return state
+
+    @override
+    def step(self, it: int) -> None:
+        value = self._get_value(it)
+        for pg in self.optimizer.param_groups:
+            if self.param_group_field not in pg:
+                raise ValueError(
+                    f"Parameter group field '{self.param_group_field}' not found in optimizer parameter groups."
+                )
+
+            if (self.apply_if_field is not None) and (self.apply_if_field not in pg):
+                continue
+
+            if (self.ignore_if_field is not None) and (self.ignore_if_field in pg):
+                continue
+
+            if self.skip_if_zero and pg[self.param_group_field] == 0:
+                continue
+
+            if self.multiplier_field is not None:
+                if self.multiplier_field not in pg:
+                    multiplier = 1.0
+                else:
+                    multiplier = pg[self.multiplier_field]
+                pg[self.param_group_field] = value * multiplier
+            else:
+                pg[self.param_group_field] = value
+        return
+
+
+class CosineWithPlateauParamScheduler(_CosineWithPlateauSchedulerCore):
+    """
+    Standalone cosine scheduler with plateau for non-optimizer parameters.
+
+    Schedule phases: freeze (0) → warmup → plateau (base_value) → cosine annealing to final_value.
+    The plateau phase maintains the base_value before cosine annealing begins.
+    """
+
+    @override
+    def step(self, it: int) -> float:
+        """
+        Computes the value corresponding to the given iteration step.
+
+        Args:
+            it: The current iteration index used for value computation.
+
+        Returns:
+            The computed value for the provided iteration step as a float.
+
+        """
+        value = self._get_value(it)
+        return value
+
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.__dict__.update(state_dict)
+        self.scheduled_values = np.array([], dtype=np.float64)
+        return
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        state = {k: v for k, v in self.__dict__.items() if k != "scheduled_values"}
+        return state
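A usage sketch of the new parameter-only variant (values are illustrative; the constructor is the one defined in the file above):

    sched = CosineWithPlateauParamScheduler(
        param_name="lr",
        num_iters=100,
        base_value=1e-3,
        final_value=1e-5,
        plateau_ratio=0.2,
        warmup_value=1e-6,
        warmup_ratio=0.1,
    )
    values = [sched.step(i) for i in range(100)]
    # values[:10] ramp 1e-6 -> 1e-3, values[10:30] hold 1e-3, values[30:] anneal toward 1e-5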
kostyl/ml/schedulers/linear.py
CHANGED
@@ -21,24 +21,23 @@ class _LinearScheduleBase(BaseScheduler):
         self.start_value = start_value
         self.final_value = final_value

-        self.
+        self.scheduled_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
         self.current_value_ = self.start_value
         return

     def _create_scheduler(self) -> None:
-        self.
+        self.scheduled_values = np.linspace(
             self.start_value, self.final_value, num=self.num_iters, dtype=np.float64
         )
-
-            raise ValueError(
-                f"Scheduler length ({len(self.scheduler_values)}) does not match total_iters ({self.num_iters})."
-            )
+        self._verify()
         return

     @override
-    def
-
-
+    def _verify(self) -> None:
+        if len(self.scheduled_values) != self.num_iters:
+            raise ValueError(
+                f"Scheduler length ({len(self.scheduled_values)}) does not match total_iters ({self.num_iters})."
+            )
         return

     @override
@@ -46,13 +45,13 @@ class _LinearScheduleBase(BaseScheduler):
         raise NotImplementedError

     def _get_value(self, it: int) -> float:
-        if len(self.
+        if len(self.scheduled_values) == 0:
             self._create_scheduler()

         if it >= self.num_iters:
             value: float = self.final_value
         else:
-            value: float = self.
+            value: float = self.scheduled_values[it]
         self.current_value_ = value
         return value

@@ -105,6 +104,21 @@ class LinearScheduler(_LinearScheduleBase):
         self.param_group_field = param_group_field
         return

+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.__dict__.update(state_dict)
+        self.scheduled_values = np.array([], dtype=np.float64)
+        return
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        state = {
+            k: v
+            for k, v in self.__dict__.items()
+            if k not in ["scheduled_values", "optimizer"]
+        }
+        return state
+
     @override
     def step(self, it: int) -> None:
         value = self._get_value(it)
@@ -137,6 +151,17 @@ class LinearScheduler(_LinearScheduleBase):
 class LinearParamScheduler(_LinearScheduleBase):
     """LinearParamScheduler adjusts a parameter value using a linear scheduler."""

+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.__dict__.update(state_dict)
+        self.scheduled_values = np.array([], dtype=np.float64)
+        return
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        state = {k: v for k, v in self.__dict__.items() if k != "scheduled_values"}
+        return state
+
     @override
     def step(self, it: int) -> float:
         """
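All schedulers now share the same checkpointing pattern: state_dict() drops the precomputed schedule array and load_state_dict() resets it so it is lazily rebuilt on the next step. A round-trip sketch using the plateau scheduler from this diff (its constructor is fully shown above):

    sched = CosineWithPlateauParamScheduler(
        param_name="lr", num_iters=100, base_value=1e-3, final_value=1e-5, plateau_ratio=0.2
    )
    sched.step(0)                        # builds scheduled_values on first use
    state = sched.state_dict()
    assert "scheduled_values" not in state

    restored = CosineWithPlateauParamScheduler(
        param_name="lr", num_iters=100, base_value=1e-3, final_value=1e-5, plateau_ratio=0.2
    )
    restored.load_state_dict(state)      # schedule array is rebuilt on the next step()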
kostyl/utils/logging.py
CHANGED
@@ -94,7 +94,7 @@ _PRESETS = {"default": _DEFAULT_FMT, "only_message": _ONLY_MESSAGE_FMT}

 def setup_logger(
     name: str | None = None,
-    fmt: Literal["default", "only_message"] | str = "
+    fmt: Literal["default", "only_message"] | str = "only_message",
     level: str = "INFO",
     add_rank: bool | None = None,
     sink=sys.stdout,
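setup_logger() now defaults to the "only_message" preset; the "default" preset from _PRESETS can still be requested explicitly:

    logger = setup_logger(fmt="default", level="INFO")
    logger.info("hello")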
{kostyl_toolkit-0.1.34.dist-info → kostyl_toolkit-0.1.36.dist-info}/RECORD
CHANGED
@@ -6,32 +6,33 @@ kostyl/ml/clearml/logging_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWa
 kostyl/ml/clearml/pulling_utils.py,sha256=jMlVXcYRumwWnPlELRlgEdfq5L6Wir_EcfTmOoWBLTA,4077
 kostyl/ml/configs/__init__.py,sha256=IetcivbqYGutowLqxdKp7QR4tkXKBr4m8t4Zkk9jHZU,911
 kostyl/ml/configs/base_model.py,sha256=Eofn14J9RsjpVx_J4rp6C19pDDCANU4hr3JtX-d0FpQ,4820
-kostyl/ml/configs/hyperparams.py,sha256=
-kostyl/ml/configs/training_settings.py,sha256=
+kostyl/ml/configs/hyperparams.py,sha256=lvtbvOFEoTBAJug7FR35xMQdPLgDQjRoP2fyDP-jD7E,3305
+kostyl/ml/configs/training_settings.py,sha256=wT9CHuLaKrLwonsc87Ee421EyFis_c9fqOgn9bSClm8,2747
 kostyl/ml/data_processing_utils.py,sha256=jjEjV0S0wREgZkzg27ip0LpI8cQqkwe2QwATmAqm9-g,3832
 kostyl/ml/dist_utils.py,sha256=Onf0KHVLA8oeUgZTcTdmR9qiM22f2uYLoNwgLbMGJWk,3495
 kostyl/ml/lightning/__init__.py,sha256=R36PImjVvzBF9t_z9u6RYVnUFJJ-sNDUOdboWUojHmM,173
 kostyl/ml/lightning/callbacks/__init__.py,sha256=EnKkNwwNDZnEqKRlpY4FVrqP88ECPF6nlT2bSLUIKRk,194
-kostyl/ml/lightning/callbacks/checkpoint.py,sha256=
+kostyl/ml/lightning/callbacks/checkpoint.py,sha256=COW7WErj4EMxJNMn97WQO-G2A3LbI6GQOCpIZu3Cblk,19060
 kostyl/ml/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
 kostyl/ml/lightning/extensions/__init__.py,sha256=OY6QGv1agYgqqKf1xJBrxgp_i8FunVfPzYezfaRrGXU,182
 kostyl/ml/lightning/extensions/custom_module.py,sha256=iQrnPz-WTmRfvLo94C5fQc2Qwa1IpHtUh1sCpVwTSFM,6602
-kostyl/ml/lightning/extensions/pretrained_model.py,sha256=
+kostyl/ml/lightning/extensions/pretrained_model.py,sha256=eRfQBzAjVernHl9A4PP5uTLvjjmcNKPdTu7ABFLq7HI,5196
 kostyl/ml/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
 kostyl/ml/lightning/loggers/tb_logger.py,sha256=j02HK5ue8yzXXV8FWKmmXyHkFlIxgHx-ahHWk_rFCZs,893
-kostyl/ml/lightning/
+kostyl/ml/lightning/utils.py,sha256=imvMbgOKRtCUiiRGEcVtN-hxw-aEFKHdCWc0J_CIZp4,1980
 kostyl/ml/metrics_formatting.py,sha256=U6vdNENZLvp2dT1L3HqFKtXrHwGKoDXN93hvamPGHjM,1341
 kostyl/ml/params_groups.py,sha256=nUyw5d06Pvy9QPiYtZzLYR87xwXqJLxbHthgQH8oSCM,3583
-kostyl/ml/registry_uploader.py,sha256=
-kostyl/ml/schedulers/__init__.py,sha256=
-kostyl/ml/schedulers/base.py,sha256=
+kostyl/ml/registry_uploader.py,sha256=BbyLXvF8AL145k7g6MRkJ7gf_3Um53p3Pn5280vVD9U,4384
+kostyl/ml/schedulers/__init__.py,sha256=_EtZu8DwTCSv4-eR84kRstEZblHylVqda7WQUOXIKfw,534
+kostyl/ml/schedulers/base.py,sha256=bjmwgdZpnSqpCnHPnKC6MEiRO79cwxMJpZq-eQVNs2M,1353
 kostyl/ml/schedulers/composite.py,sha256=ee4xlMDMMtjKPkbTF2ue9GTr9DuGCGjZWf11mHbi6aE,2387
-kostyl/ml/schedulers/cosine.py,sha256=
-kostyl/ml/schedulers/
+kostyl/ml/schedulers/cosine.py,sha256=y8ylrgVOkVcr2-ExoqqNW--tdDX88TBYPQCOppIf2_M,8685
+kostyl/ml/schedulers/cosine_with_plateu.py,sha256=0-X6wl3HgsTiLIbISb9lOxIVWXHDEND7rILitMWtIiM,10195
+kostyl/ml/schedulers/linear.py,sha256=RnnnblRuRXP3LT03QVIHUaK2kNsiMP1AedrMoeyh3qk,5843
 kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
 kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
 kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
-kostyl/utils/logging.py,sha256=
-kostyl_toolkit-0.1.
-kostyl_toolkit-0.1.
-kostyl_toolkit-0.1.
+kostyl/utils/logging.py,sha256=LSbyQFLAIa89xPb4tcobE2BwVIHHUSaDXqOIKVzLoWs,5801
+kostyl_toolkit-0.1.36.dist-info/WHEEL,sha256=eycQt0QpYmJMLKpE3X9iDk8R04v2ZF0x82ogq-zP6bQ,79
+kostyl_toolkit-0.1.36.dist-info/METADATA,sha256=Lfyx6u3LKZ6co4s7GZgJp31zoy-NViriSGqwjIzOQFA,4269
+kostyl_toolkit-0.1.36.dist-info/RECORD,,
kostyl/ml/lightning/training_utils.py
REMOVED
@@ -1,241 +0,0 @@
-from dataclasses import dataclass
-from dataclasses import fields
-from pathlib import Path
-from typing import Literal
-from typing import cast
-
-import lightning as L
-import torch
-import torch.distributed as dist
-from clearml import OutputModel
-from clearml import Task
-from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.callbacks import EarlyStopping
-from lightning.pytorch.callbacks import LearningRateMonitor
-from lightning.pytorch.callbacks import ModelCheckpoint
-from lightning.pytorch.loggers import TensorBoardLogger
-from lightning.pytorch.strategies import DDPStrategy
-from lightning.pytorch.strategies import FSDPStrategy
-from torch.distributed import ProcessGroup
-from torch.distributed.fsdp import MixedPrecision
-from torch.nn import Module
-
-from kostyl.ml.configs import CheckpointConfig
-from kostyl.ml.configs import DDPStrategyConfig
-from kostyl.ml.configs import EarlyStoppingConfig
-from kostyl.ml.configs import FSDP1StrategyConfig
-from kostyl.ml.configs import SingleDeviceStrategyConfig
-from kostyl.ml.lightning.callbacks import setup_checkpoint_callback
-from kostyl.ml.lightning.callbacks import setup_early_stopping_callback
-from kostyl.ml.lightning.loggers import setup_tb_logger
-from kostyl.ml.registry_uploader import ClearMLRegistryUploaderCallback
-from kostyl.utils.logging import setup_logger
-
-
-TRAINING_STRATEGIES = (
-    FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
-)
-
-logger = setup_logger(add_rank=True)
-
-
-def estimate_total_steps(
-    trainer: L.Trainer, process_group: ProcessGroup | None = None
-) -> int:
-    """
-    Estimates the total number of training steps based on the
-    dataloader length, accumulation steps, and distributed world size.
-    """  # noqa: D205
-    if dist.is_initialized():
-        world_size = dist.get_world_size(process_group)
-    else:
-        world_size = 1
-
-    datamodule = trainer.datamodule  # type: ignore
-    if datamodule is None:
-        raise ValueError("Trainer must have a datamodule to estimate total steps.")
-    datamodule = cast(L.LightningDataModule, datamodule)
-
-    logger.info("Loading `train_dataloader` to estimate number of stepping batches.")
-    datamodule.setup("fit")
-
-    dataloader_len = len(datamodule.train_dataloader())
-    steps_per_epoch = dataloader_len // trainer.accumulate_grad_batches // world_size
-
-    if trainer.max_epochs is None:
-        raise ValueError("Trainer must have `max_epochs` set to estimate total steps.")
-    total_steps = steps_per_epoch * trainer.max_epochs
-
-    logger.info(
-        f"Total steps: {total_steps} (per-epoch: {steps_per_epoch})\n"
-        f"-> Dataloader len: {dataloader_len}\n"
-        f"-> Accumulate grad batches: {trainer.accumulate_grad_batches}\n"
-        f"-> Epochs: {trainer.max_epochs}\n "
-        f"-> World size: {world_size}"
-    )
-    return total_steps
-
-
-@dataclass
-class Callbacks:
-    """Dataclass to hold PyTorch Lightning callbacks."""
-
-    checkpoint: ModelCheckpoint
-    lr_monitor: LearningRateMonitor
-    early_stopping: EarlyStopping | None = None
-
-    def to_list(self) -> list[Callback]:
-        """Convert dataclass fields to a list of Callbacks. None values are omitted."""
-        callbacks: list[Callback] = [
-            getattr(self, field.name)
-            for field in fields(self)
-            if getattr(self, field.name) is not None
-        ]
-        return callbacks
-
-
-def setup_callbacks(
-    task: Task,
-    root_path: Path,
-    checkpoint_cfg: CheckpointConfig,
-    early_stopping_cfg: EarlyStoppingConfig | None,
-    output_model: OutputModel,
-    checkpoint_upload_strategy: Literal["only-best", "every-checkpoint"],
-    config_dict: dict[str, str] | None = None,
-    enable_tag_versioning: bool = False,
-) -> Callbacks:
-    """
-    Set up PyTorch Lightning callbacks for training.
-
-    Creates and configures a set of callbacks including checkpoint saving,
-    learning rate monitoring, model registry uploading, and optional early stopping.
-
-    Args:
-        task: ClearML task for organizing checkpoints by task name and ID.
-        root_path: Root directory for saving checkpoints.
-        checkpoint_cfg: Configuration for checkpoint saving behavior.
-        checkpoint_upload_strategy: Model upload strategy:
-            - `"only-best"`: Upload only the best checkpoint based on monitored metric.
-            - `"every-checkpoint"`: Upload every saved checkpoint.
-        output_model: ClearML OutputModel instance for model registry integration.
-        early_stopping_cfg: Configuration for early stopping. If None, early stopping
-            is disabled.
-        config_dict: Optional configuration dictionary to store with the model
-            in the registry.
-        enable_tag_versioning: Whether to auto-increment version tags (e.g., "v1.0")
-            on the uploaded model.
-
-    Returns:
-        Callbacks dataclass containing configured checkpoint, lr_monitor,
-        and optionally early_stopping callbacks.
-
-    """
-    lr_monitor = LearningRateMonitor(
-        logging_interval="step", log_weight_decay=True, log_momentum=False
-    )
-    model_uploader = ClearMLRegistryUploaderCallback(
-        output_model=output_model,
-        config_dict=config_dict,
-        verbose=True,
-        enable_tag_versioning=enable_tag_versioning,
-    )
-    checkpoint_callback = setup_checkpoint_callback(
-        root_path / "checkpoints" / task.name / task.id,
-        checkpoint_cfg,
-        registry_uploader_callback=model_uploader,
-        uploading_strategy=checkpoint_upload_strategy,
-    )
-    if early_stopping_cfg is not None:
-        early_stopping_callback = setup_early_stopping_callback(early_stopping_cfg)
-    else:
-        early_stopping_callback = None
-
-    callbacks = Callbacks(
-        checkpoint=checkpoint_callback,
-        lr_monitor=lr_monitor,
-        early_stopping=early_stopping_callback,
-    )
-    return callbacks
-
-
-def setup_loggers(task: Task, root_path: Path) -> list[TensorBoardLogger]:
-    """
-    Set up PyTorch Lightning loggers for training.
-
-    Args:
-        task: ClearML task used to organize log directories by task name and ID.
-        root_path: Root directory for storing TensorBoard logs.
-
-    Returns:
-        List of configured TensorBoard loggers.
-
-    """
-    loggers = [
-        setup_tb_logger(root_path / "runs" / task.name / task.id),
-    ]
-    return loggers
-
-
-def setup_strategy(
-    strategy_settings: TRAINING_STRATEGIES,
-    devices: list[int] | int,
-    auto_wrap_policy: set[type[Module]] | None = None,
-) -> Literal["auto"] | FSDPStrategy | DDPStrategy:
-    """
-    Configure and return a PyTorch Lightning training strategy.
-
-    Args:
-        strategy_settings: Strategy configuration object. Must be one of:
-            - `FSDP1StrategyConfig`: Fully Sharded Data Parallel strategy (requires 2+ devices).
-            - `DDPStrategyConfig`: Distributed Data Parallel strategy (requires 2+ devices).
-            - `SingleDeviceStrategyConfig`: Single device training (requires exactly 1 device).
-        devices: Device(s) to use for training. Either a list of device IDs or
-            a single integer representing the number of devices.
-        auto_wrap_policy: Set of module types that should be wrapped for FSDP.
-            Required when using `FSDP1StrategyConfig`, ignored otherwise.
-
-    Returns:
-        Configured strategy: `FSDPStrategy`, `DDPStrategy`, or `"auto"` for single device.
-
-    Raises:
-        ValueError: If device count doesn't match strategy requirements or
-            if `auto_wrap_policy` is missing for FSDP.
-
-    """
-    if isinstance(devices, list):
-        num_devices = len(devices)
-    else:
-        num_devices = devices
-
-    match strategy_settings:
-        case FSDP1StrategyConfig():
-            if num_devices < 2:
-                raise ValueError("FSDP strategy requires multiple devices.")
-
-            if auto_wrap_policy is None:
-                raise ValueError("auto_wrap_policy must be provided for FSDP strategy.")
-
-            mixed_precision_config = MixedPrecision(
-                param_dtype=getattr(torch, strategy_settings.param_dtype),
-                reduce_dtype=getattr(torch, strategy_settings.reduce_dtype),
-                buffer_dtype=getattr(torch, strategy_settings.buffer_dtype),
-            )
-            strategy = FSDPStrategy(
-                auto_wrap_policy=auto_wrap_policy,
-                mixed_precision=mixed_precision_config,
-            )
-        case DDPStrategyConfig():
-            if num_devices < 2:
-                raise ValueError("DDP strategy requires at least two devices.")
-            strategy = DDPStrategy(
-                find_unused_parameters=strategy_settings.find_unused_parameters
-            )
-        case SingleDeviceStrategyConfig():
-            if num_devices != 1:
-                raise ValueError("SingleDevice strategy requires exactly one device.")
-            strategy = "auto"
-        case _:
-            raise ValueError(
-                f"Unsupported strategy type: {type(strategy_settings.trainer.strategy)}"
-            )
-    return strategy