kostyl-toolkit 0.1.35__py3-none-any.whl → 0.1.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kostyl/ml/configs/hyperparams.py +21 -5
- kostyl/ml/configs/training_settings.py +17 -6
- kostyl/ml/dist_utils.py +52 -30
- kostyl/ml/lightning/callbacks/checkpoint.py +10 -10
- kostyl/ml/lightning/extensions/custom_module.py +0 -5
- kostyl/ml/lightning/extensions/pretrained_model.py +6 -4
- kostyl/ml/lightning/loggers/tb_logger.py +2 -2
- kostyl/ml/lightning/utils.py +58 -0
- kostyl/ml/registry_uploader.py +56 -29
- kostyl/ml/schedulers/__init__.py +13 -1
- kostyl/ml/schedulers/base.py +9 -7
- kostyl/ml/schedulers/cosine.py +53 -24
- kostyl/ml/schedulers/cosine_with_plateu.py +277 -0
- kostyl/ml/schedulers/linear.py +36 -11
- kostyl/utils/logging.py +68 -53
- {kostyl_toolkit-0.1.35.dist-info → kostyl_toolkit-0.1.37.dist-info}/METADATA +1 -1
- {kostyl_toolkit-0.1.35.dist-info → kostyl_toolkit-0.1.37.dist-info}/RECORD +18 -17
- {kostyl_toolkit-0.1.35.dist-info → kostyl_toolkit-0.1.37.dist-info}/WHEEL +1 -1
- kostyl/ml/lightning/training_utils.py +0 -241
kostyl/ml/configs/hyperparams.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import Literal
+
 from pydantic import BaseModel
 from pydantic import Field
 from pydantic import model_validator
@@ -8,11 +10,25 @@ from kostyl.utils.logging import setup_logger
 logger = setup_logger(fmt="only_message")
 
 
-class
-    """
+class AdamConfig(BaseModel):
+    """AdamW optimizer hyperparameters configuration."""
+
+    type: Literal["AdamW"] = "AdamW"
+    betas: tuple[float, float] = (0.9, 0.999)
+    is_adamw: bool = True
+
+
+class AdamWithPrecisionConfig(BaseModel):
+    """Adam optimizer with low-precision hyperparameters configuration."""
+
+    type: Literal["Adam8bit", "Adam4bit", "AdamFp8"]
+    betas: tuple[float, float] = (0.9, 0.999)
+    block_size: int
+    bf16_stochastic_round: bool = False
+    is_adamw: bool = True
+
 
-
-    adamw_beta2: float = 0.999
+Optimizer = AdamConfig | AdamWithPrecisionConfig
 
 
 class Lr(BaseModel):
@@ -73,6 +89,6 @@ class HyperparamsConfig(BaseModel):
     """Model training hyperparameters configuration."""
 
     grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
-    optimizer: Optimizer
+    optimizer: Optimizer
     lr: Lr
     weight_decay: WeightDecay
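The optimizer hyperparameters are now split into two pydantic models joined by the `Optimizer` union. Below is a minimal usage sketch (not part of the package) showing how pydantic picks the matching model from the `type` literal; the `TypeAdapter` wrapper and the sample payloads are illustrative assumptions.

```python
# Sketch only: the two config classes are copied from the hunk above,
# the TypeAdapter usage is illustrative and not taken from the package.
from typing import Literal

from pydantic import BaseModel, TypeAdapter


class AdamConfig(BaseModel):
    type: Literal["AdamW"] = "AdamW"
    betas: tuple[float, float] = (0.9, 0.999)
    is_adamw: bool = True


class AdamWithPrecisionConfig(BaseModel):
    type: Literal["Adam8bit", "Adam4bit", "AdamFp8"]
    betas: tuple[float, float] = (0.9, 0.999)
    block_size: int
    bf16_stochastic_round: bool = False
    is_adamw: bool = True


Optimizer = AdamConfig | AdamWithPrecisionConfig

adapter = TypeAdapter(Optimizer)

# The `type` literal decides which model validates the payload.
full_precision = adapter.validate_python({"type": "AdamW"})
low_precision = adapter.validate_python(
    {"type": "Adam8bit", "block_size": 256, "bf16_stochastic_round": True}
)
print(type(full_precision).__name__)  # AdamConfig
print(type(low_precision).__name__)   # AdamWithPrecisionConfig
```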
kostyl/ml/configs/training_settings.py
CHANGED

@@ -25,21 +25,31 @@ PRECISION = Literal[
     "16",
     "bf16",
 ]
+DTYPE = Literal["float32", "float16", "bfloat16", "float64"]
+
+
+class SingleDeviceStrategyConfig(BaseModel):
+    """Single device strategy configuration."""
+
+    type: Literal["single_device"]
 
 
 class FSDP1StrategyConfig(BaseModel):
     """Fully Sharded Data Parallel (FSDP) strategy configuration."""
 
     type: Literal["fsdp1"]
-    param_dtype:
-    reduce_dtype:
-    buffer_dtype:
+    param_dtype: DTYPE | None
+    reduce_dtype: DTYPE | None
+    buffer_dtype: DTYPE | None
 
 
-class
-    """
+class FSDP2StrategyConfig(BaseModel):
+    """Fully Sharded Data Parallel (FSDP) strategy configuration."""
 
-    type: Literal["
+    type: Literal["fsdp2"]
+    param_dtype: DTYPE | None
+    reduce_dtype: DTYPE | None
+    buffer_dtype: DTYPE | None
 
 
 class DDPStrategyConfig(BaseModel):
@@ -82,6 +92,7 @@ class CheckpointConfig(BaseModel):
     monitor: str = "val_loss"
     mode: str = "min"
     filename: str = "{epoch:02d}-{val_loss:.2f}"
+    save_weights_only: bool = True
 
 
 class DataConfig(BaseModel):
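The strategy configs now type their mixed-precision fields with the new `DTYPE` literal. A hedged sketch of how a validated FSDP config might be mapped onto torch dtypes downstream; the mapping itself is an assumption and is not shown in the diff.

```python
# FSDP1StrategyConfig and DTYPE are copied from the hunk above; the
# torch-dtype mapping at the bottom is an illustrative assumption.
from typing import Literal

import torch
from pydantic import BaseModel

DTYPE = Literal["float32", "float16", "bfloat16", "float64"]


class FSDP1StrategyConfig(BaseModel):
    type: Literal["fsdp1"]
    param_dtype: DTYPE | None
    reduce_dtype: DTYPE | None
    buffer_dtype: DTYPE | None


cfg = FSDP1StrategyConfig(
    type="fsdp1",
    param_dtype="bfloat16",
    reduce_dtype="float32",
    buffer_dtype=None,
)

# The literal strings line up with torch dtype attribute names, so a
# downstream consumer could resolve them like this:
param_dtype = getattr(torch, cfg.param_dtype) if cfg.param_dtype else None
print(param_dtype)  # torch.bfloat16
```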
kostyl/ml/dist_utils.py
CHANGED
@@ -4,38 +4,61 @@ from typing import Literal
 
 import torch.distributed as dist
 
+from kostyl.utils.logging import KostylLogger
 from kostyl.utils.logging import setup_logger
 
 
-
+module_logger = setup_logger()
 
 
-def log_dist(
+def log_dist(
+    msg: str,
+    logger: KostylLogger | None = None,
+    level: Literal["info", "warning", "error", "warning_once", "debug"] = "info",
+    log_scope: Literal["only-zero-rank", "world"] = "world",
+    group: dist.ProcessGroup | None = None,
+) -> None:
     """
     Log a message in a distributed environment based on the specified verbosity level.
 
     Args:
         msg (str): The message to log.
-
+        log_scope (Literal["only-zero-rank", "world"]): The verbosity level for logging.
             - "only-zero-rank": Log only from the main process (rank 0).
             - "world": Log from all processes in the distributed environment.
+        logger (KostylLogger | None): The logger instance to use. If None, the module logger is used.
+        level (Literal["info", "warning", "error", "warning_once", "debug"]): The logging level.
+        group (dist.ProcessGroup | None): Optional process group used to determine ranks. Defaults to the global process group.
 
     """
-
-
-
-
-
-
+    if logger is None:
+        logger = module_logger
+
+    log_attr = getattr(logger, level, None)
+    if log_attr is None:
+        raise ValueError(f"Invalid logging level: {level}")
+
+    if not dist.is_initialized():
+        module_logger.warning_once(
+            "Distributed process group is not initialized; logging from all ranks."
+        )
+        log_attr(msg)
+        return
+
+    match log_scope:
         case "only-zero-rank":
-            if
-
+            if group is None:
+                module_logger.debug(
+                    "No process group provided; assuming global group for rank check."
+                )
+                group = dist.group.WORLD
+            group_rank = dist.get_rank(group=group)
+            if dist.get_global_rank(group=group, group_rank=group_rank) == 0:  # pyright: ignore[reportArgumentType]
+                log_attr(msg)
         case "world":
-
+            log_attr(msg)
        case _:
-
-                f"Invalid logging verbosity level requested: {how}. Message not logged."
-            )
+            raise ValueError(f"Invalid logging verbosity level: {log_scope}")
     return
 
 
@@ -44,7 +67,7 @@ def scale_lrs_by_world_size(
     group: dist.ProcessGroup | None = None,
     config_name: str = "",
     inv_scale: bool = False,
-
+    verbose_level: Literal["only-zero-rank", "world"] | None = None,
 ) -> dict[str, float]:
     """
     Scale learning-rate configuration values to match the active distributed world size.
@@ -58,7 +81,7 @@ def scale_lrs_by_world_size(
             the target world size. Defaults to the global process group.
         config_name (str): Human-readable identifier included in log messages.
         inv_scale (bool): If True, use the inverse square-root scale factor.
-
+        verbose_level (Literal["only-zero-rank", "world"] | None): Verbosity level for logging scaled values.
             - "only-zero-rank": Log only from the main process (rank 0).
             - "world": Log from all processes in the distributed environment.
            - None: No logging.
@@ -77,31 +100,30 @@ def scale_lrs_by_world_size(
     for name, value in lrs.items():
         old_value = value
         new_value = value * scale
-        if
+        if verbose_level is not None:
             log_dist(
                 f"New {config_name} lr {name.upper()}: {new_value}; OLD: {old_value}",
-
+                log_scope=verbose_level,
+                group=group,
             )
         lrs[name] = new_value
     return lrs
 
 
-def
-    """Gets the rank of the current process in a distributed setting."""
-    if dist.is_initialized():
-        return dist.get_rank()
-    if "
-        return int(os.environ["
-    if "SLURM_PROCID" in os.environ:
-        return int(os.environ["SLURM_PROCID"])
+def get_local_rank(group: dist.ProcessGroup | None = None) -> int:
+    """Gets the local rank of the current process in a distributed setting."""
+    if dist.is_initialized() and group is not None:
+        return dist.get_rank(group=group)
+    if "SLURM_LOCALID" in os.environ:
+        return int(os.environ["SLURM_LOCALID"])
     if "LOCAL_RANK" in os.environ:
         return int(os.environ["LOCAL_RANK"])
     return 0
 
 
-def
-    """Checks if the current process is the main process (rank 0) in a distributed setting."""
-    rank =
+def is_local_zero_rank() -> bool:
+    """Checks if the current process is the main process (rank 0) for the local node in a distributed setting."""
+    rank = get_local_rank()
     if rank != 0:
         return False
     return True
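For reference, a hedged sketch of how the reworked `log_dist` signature is meant to be called; the messages and the surrounding setup are illustrative, only the function signature comes from the hunk above.

```python
# Illustrative call sites only; log_dist's signature is taken from the hunk above.
import torch.distributed as dist

from kostyl.ml.dist_utils import log_dist
from kostyl.utils.logging import setup_logger

logger = setup_logger()

# Without an initialized process group, log_dist warns once and logs on every rank.
log_dist("loading tokenizer", logger=logger)

# With an initialized group, restrict a message to global rank 0 of that group.
if dist.is_initialized():
    log_dist(
        "dataset cache ready",
        level="info",
        log_scope="only-zero-rank",
        group=dist.group.WORLD,
    )
```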
kostyl/ml/lightning/callbacks/checkpoint.py
CHANGED

@@ -10,7 +10,7 @@ from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.callbacks import ModelCheckpoint
 
 from kostyl.ml.configs import CheckpointConfig
-from kostyl.ml.dist_utils import
+from kostyl.ml.dist_utils import is_local_zero_rank
 from kostyl.ml.lightning import KostylLightningModule
 from kostyl.ml.registry_uploader import RegistryUploaderCallback
 from kostyl.utils import setup_logger
@@ -299,9 +299,9 @@ class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
 def setup_checkpoint_callback(
     dirpath: Path,
     ckpt_cfg: CheckpointConfig,
-    save_weights_only: bool = True,
     registry_uploader_callback: RegistryUploaderCallback | None = None,
     uploading_strategy: Literal["only-best", "every-checkpoint"] | None = None,
+    remove_folder_if_exists: bool = True,
 ) -> ModelCheckpointWithRegistryUploader | ModelCheckpoint:
     """
     Create and configure a checkpoint callback for model saving.
@@ -313,14 +313,13 @@ def setup_checkpoint_callback(
     Args:
         dirpath: Path to the directory for saving checkpoints.
         ckpt_cfg: Checkpoint configuration (filename, monitor, mode, save_top_k).
-        save_weights_only: If True, only model weights are saved without optimizer and lr-scheduler state.
-            Defaults to True.
         registry_uploader_callback: Optional callback for uploading checkpoints to a remote registry.
             Must be specified together with uploading_strategy.
         uploading_strategy: Checkpoint upload mode:
             - "only-best": only the best checkpoint is uploaded
            - "every-checkpoint": every saved checkpoint is uploaded
            Must be specified together with registry_uploader_callback.
+        remove_folder_if_exists: If True, removes existing checkpoint directory before creating a new one.
 
     Returns:
         ModelCheckpointWithRegistryUploader if registry_uploader_callback is provided,
@@ -331,7 +330,7 @@ def setup_checkpoint_callback(
 
     Note:
         If the dirpath directory already exists, it will be removed and recreated
-        (only on the main process in distributed training).
+        (only on the main process in distributed training) if remove_folder_if_exists is True.
 
     """
     if (registry_uploader_callback is None) != (uploading_strategy is None):
@@ -340,10 +339,11 @@ def setup_checkpoint_callback(
         )
 
     if dirpath.exists():
-        if
+        if is_local_zero_rank():
             logger.warning(f"Checkpoint directory {dirpath} already exists.")
-
-
+            if remove_folder_if_exists:
+                rmtree(dirpath)
+                logger.warning(f"Removed existing checkpoint directory {dirpath}.")
     else:
         logger.info(f"Creating checkpoint directory {dirpath}.")
         dirpath.mkdir(parents=True, exist_ok=True)
@@ -356,7 +356,7 @@ def setup_checkpoint_callback(
             monitor=ckpt_cfg.monitor,
             mode=ckpt_cfg.mode,
             verbose=True,
-            save_weights_only=save_weights_only,
+            save_weights_only=ckpt_cfg.save_weights_only,
             registry_uploader_callback=registry_uploader_callback,
             uploading_mode=uploading_strategy,
         )
@@ -368,6 +368,6 @@ def setup_checkpoint_callback(
             monitor=ckpt_cfg.monitor,
             mode=ckpt_cfg.mode,
             verbose=True,
-            save_weights_only=save_weights_only,
+            save_weights_only=ckpt_cfg.save_weights_only,
         )
     return checkpoint_callback
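With `save_weights_only` moved into `CheckpointConfig` and `remove_folder_if_exists` added to the signature, a call might look like the sketch below. Field values and paths are placeholders, and `save_top_k` is assumed from the docstring; `CheckpointConfig` may carry further fields not shown in this diff.

```python
# Placeholder values; the parameter names follow the hunks above.
from pathlib import Path

from kostyl.ml.configs import CheckpointConfig
from kostyl.ml.lightning.callbacks.checkpoint import setup_checkpoint_callback

ckpt_cfg = CheckpointConfig(
    monitor="val_loss",
    mode="min",
    filename="{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    save_weights_only=True,  # new field in 0.1.37, forwarded to ModelCheckpoint
)

callback = setup_checkpoint_callback(
    dirpath=Path("checkpoints/run-001"),
    ckpt_cfg=ckpt_cfg,
    remove_folder_if_exists=False,  # keep an existing directory instead of wiping it
)
```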
kostyl/ml/lightning/extensions/custom_module.py
CHANGED

@@ -26,11 +26,6 @@ module_logger = setup_logger(fmt="only_message")
 class KostylLightningModule(L.LightningModule):
     """Custom PyTorch Lightning Module with logging, checkpointing, and distributed training utilities."""
 
-    @property
-    def process_group(self) -> ProcessGroup | None:
-        """Returns the data parallel process group for distributed training."""
-        return self.get_process_group()
-
     def get_process_group(self) -> ProcessGroup | None:
         """
         Retrieves the data parallel process group for distributed training.
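Since the `process_group` property is gone, callers use the method directly. A hypothetical migration snippet:

```python
# Hypothetical caller; `module` is any KostylLightningModule instance.
from kostyl.ml.lightning import KostylLightningModule


def data_parallel_group(module: KostylLightningModule):
    # 0.1.35: module.process_group (property, removed in 0.1.37)
    # 0.1.37: call the method directly
    return module.get_process_group()
```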
kostyl/ml/lightning/extensions/pretrained_model.py
CHANGED

@@ -12,12 +12,12 @@ from kostyl.utils.logging import setup_logger
 logger = setup_logger("LightningPretrainedModelMixin", fmt="only_message")
 
 
-class LightningCheckpointLoaderMixin
+class LightningCheckpointLoaderMixin:
     """A mixin class for loading pretrained models from PyTorch Lightning checkpoints."""
 
     @classmethod
-    def from_lightning_checkpoint[TModelInstance:
-        cls: type[TModelInstance],
+    def from_lightning_checkpoint[TModelInstance: PreTrainedModel](  # noqa: C901
+        cls: type[TModelInstance],  # pyright: ignore[reportGeneralTypeIssues]
         checkpoint_path: str | Path,
         config_key: str = "config",
         weights_prefix: str | None = "model.",
@@ -78,7 +78,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
             mmap=True,
         )
 
-        #
+        # Load config
         config_cls = cast(type[PretrainedConfig], cls.config_class)
         config_dict = checkpoint_dict[config_key]
         config_dict.update(kwargs)
@@ -91,6 +91,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
 
         raw_state_dict: dict[str, torch.Tensor] = checkpoint_dict["state_dict"]
 
+        # Handle weights prefix
         if weights_prefix:
             if not weights_prefix.endswith("."):
                 weights_prefix = weights_prefix + "."
@@ -117,6 +118,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
         else:
             state_dict = raw_state_dict
 
+        # Instantiate model and load state dict
         model = cls.from_pretrained(
             pretrained_model_name_or_path=None,
             config=config,
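After this change the mixin no longer inherits from `PreTrainedModel`, so a concrete model combines the two bases itself. A hedged sketch with hypothetical `MyConfig`/`MyModel` classes:

```python
# MyConfig and MyModel are hypothetical; a real model would also define its
# architecture and forward pass.
from transformers import PretrainedConfig, PreTrainedModel

from kostyl.ml.lightning.extensions.pretrained_model import LightningCheckpointLoaderMixin


class MyConfig(PretrainedConfig):
    model_type = "my_model"


class MyModel(LightningCheckpointLoaderMixin, PreTrainedModel):
    config_class = MyConfig


# The loader strips the "model." prefix written by the LightningModule wrapper:
# model = MyModel.from_lightning_checkpoint("path/to/last.ckpt", weights_prefix="model.")
```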
kostyl/ml/lightning/loggers/tb_logger.py
CHANGED

@@ -3,7 +3,7 @@ from shutil import rmtree
 
 from lightning.pytorch.loggers import TensorBoardLogger
 
-from kostyl.ml.dist_utils import
+from kostyl.ml.dist_utils import is_local_zero_rank
 from kostyl.utils.logging import setup_logger
 
 
@@ -15,7 +15,7 @@ def setup_tb_logger(
 ) -> TensorBoardLogger:
     """Sets up a TensorBoardLogger for PyTorch Lightning."""
     if runs_dir.exists():
-        if
+        if is_local_zero_rank():
             logger.warning(f"TensorBoard log directory {runs_dir} already exists.")
             rmtree(runs_dir)
             logger.warning(f"Removed existing TensorBoard log directory {runs_dir}.")
kostyl/ml/lightning/utils.py
ADDED

@@ -0,0 +1,58 @@
+from typing import cast
+
+import lightning as L
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from kostyl.ml.configs import DDPStrategyConfig
+from kostyl.ml.configs import FSDP1StrategyConfig
+from kostyl.ml.configs import SingleDeviceStrategyConfig
+from kostyl.utils.logging import setup_logger
+
+
+TRAINING_STRATEGIES = (
+    FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
+)
+
+logger = setup_logger()
+
+
+def estimate_total_steps(
+    trainer: L.Trainer, dp_process_group: ProcessGroup | None = None
+) -> int:
+    """
+    Estimates the total number of training steps with respect to data parallelism and gradient accumulation.
+
+    Args:
+        trainer: The PyTorch Lightning Trainer instance.
+        dp_process_group: The data parallel process group. If None, the world process group will be used.
+
+    """
+    if dist.is_initialized():
+        world_size = dist.get_world_size(dp_process_group)
+    else:
+        world_size = 1
+
+    datamodule = trainer.datamodule  # type: ignore
+    if datamodule is None:
+        raise ValueError("Trainer must have a datamodule to estimate total steps.")
+    datamodule = cast(L.LightningDataModule, datamodule)
+
+    logger.info("Loading `train_dataloader` to estimate number of stepping batches.")
+    datamodule.setup("fit")
+
+    dataloader_len = len(datamodule.train_dataloader())
+    steps_per_epoch = dataloader_len // trainer.accumulate_grad_batches // world_size
+
+    if trainer.max_epochs is None:
+        raise ValueError("Trainer must have `max_epochs` set to estimate total steps.")
+    total_steps = steps_per_epoch * trainer.max_epochs
+
+    logger.info(
+        f"Total steps: {total_steps} (per-epoch: {steps_per_epoch}) "
+        f"-> Dataloader len: {dataloader_len} "
+        f"-> Accumulate grad batches: {trainer.accumulate_grad_batches} "
+        f"-> Epochs: {trainer.max_epochs} "
+        f"-> DataParallel size: {world_size}"
+    )
+    return total_steps
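A hedged sketch of a typical consumer: sizing an LR scheduler inside `configure_optimizers` from the estimated step count. Everything except `estimate_total_steps` itself is illustrative.

```python
# Illustrative module; only estimate_total_steps comes from the new file above.
import lightning as L
import torch

from kostyl.ml.lightning.utils import estimate_total_steps


class MyModule(L.LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4)
        # dp_process_group defaults to the world group when omitted.
        total_steps = estimate_total_steps(self.trainer)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
```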
kostyl/ml/registry_uploader.py
CHANGED
@@ -1,13 +1,12 @@
 from abc import ABC
 from abc import abstractmethod
 from collections.abc import Callable
+from functools import partial
 from pathlib import Path
 from typing import override
 
 from clearml import OutputModel
 
-from kostyl.ml.clearml.logging_utils import find_version_in_tags
-from kostyl.ml.clearml.logging_utils import increment_version
 from kostyl.utils.logging import setup_logger
 
 
@@ -28,51 +27,79 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
 
     def __init__(
         self,
-
+        model_name: str,
         config_dict: dict[str, str] | None = None,
+        label_enumeration: dict[str, int] | None = None,
+        tags: list[str] | None = None,
+        comment: str | None = None,
+        framework: str | None = None,
+        base_model_id: str | None = None,
+        new_model_per_upload: bool = True,
         verbose: bool = True,
-        enable_tag_versioning: bool = False,
     ) -> None:
         """
         Initializes the ClearMLRegistryUploaderCallback.
 
         Args:
-
-
+            model_name: The name for the newly created model.
+            label_enumeration: The label enumeration dictionary of string (label) to integer (value) pairs.
             config_dict: Optional configuration dictionary to associate with the model.
-
-
+            tags: A list of strings which are tags for the model.
+            comment: A comment / description for the model.
+            framework: The framework of the model (e.g., "PyTorch", "TensorFlow").
+            base_model_id: Optional ClearML model ID to use as a base for the new model
+            new_model_per_upload: Whether to create a new ClearML model
+                for every upload or update weights of the same model. When updating weights,
+                the last uploaded checkpoint will be replaced (and deleted).
+            verbose: Whether to log messages during upload.
 
         """
         super().__init__()
-
-
-
-
+        if base_model_id is not None and new_model_per_upload:
+            raise ValueError(
+                "Cannot set base_model_id when new_model_per_upload is True."
+            )
 
+        self.verbose = verbose
+        self.new_model_per_upload = new_model_per_upload
         self.best_model_path: str = ""
-
+        self.config_dict = config_dict
+        self._output_model: OutputModel | None = None
         self._last_uploaded_model_path: str = ""
         self._upload_callback: Callable | None = None
 
-        self._validate_tags()
+        self._validate_tags(tags)
+        self.model_fabric = partial(
+            OutputModel,
+            name=model_name,
+            label_enumeration=label_enumeration,
+            tags=tags,
+            comment=comment,
+            framework=framework,
+            base_model_id=base_model_id,
+        )
         return
 
-
-
-        if
-
-
-
-        else:
-            new_version = increment_version(version)
-            output_model_tags.remove(version)
-            output_model_tags.append(new_version)
-        if "LightningCheckpoint" not in output_model_tags:
-            output_model_tags.append("LightningCheckpoint")
-        self.output_model.tags = output_model_tags
+    @staticmethod
+    def _validate_tags(tags: list[str] | None) -> None:
+        if tags is None:
+            return
+        if "LightningCheckpoint" not in tags:
+            tags.append("LightningCheckpoint")
         return None
 
+    @property
+    def output_model_(self) -> OutputModel:
+        """Returns the OutputModel instance based on `new_model_per_upload` setting."""
+        if self.new_model_per_upload:
+            model = self.model_fabric()
+            self._output_model = self.model_fabric()
+        else:
+            if self._output_model is None:
+                self._output_model = self.model_fabric()
+            model = self._output_model
+        return model
+
     @override
     def upload_checkpoint(
         self,
@@ -88,12 +115,12 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
         if self.verbose:
             logger.info(f"Uploading model from {path}")
 
-        self.
+        self.output_model_.update_weights(
             path,
             auto_delete_file=False,
             async_enable=False,
         )
-        self.
+        self.output_model_.update_design(config_dict=self.config_dict)
 
         self._last_uploaded_model_path = path
         return
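A construction sketch using the new `__init__` arguments; model name, config values, and tags are placeholders.

```python
# Placeholder values; the argument names follow the hunk above.
from kostyl.ml.registry_uploader import ClearMLRegistryUploaderCallback

uploader = ClearMLRegistryUploaderCallback(
    model_name="my-encoder",
    config_dict={"hidden_size": "768"},
    tags=["prod-candidate"],     # "LightningCheckpoint" is appended automatically
    framework="PyTorch",
    new_model_per_upload=False,  # reuse one ClearML model and replace its weights
)
# base_model_id may only be combined with new_model_per_upload=False;
# setting both raises ValueError in the new constructor.
```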
kostyl/ml/schedulers/__init__.py
CHANGED
@@ -1,6 +1,18 @@
 from .composite import CompositeScheduler
 from .cosine import CosineParamScheduler
 from .cosine import CosineScheduler
+from .cosine_with_plateu import CosineWithPlateauParamScheduler
+from .cosine_with_plateu import CosineWithPlateuScheduler
+from .linear import LinearParamScheduler
+from .linear import LinearScheduler
 
 
-__all__ = [
+__all__ = [
+    "CompositeScheduler",
+    "CosineParamScheduler",
+    "CosineScheduler",
+    "CosineWithPlateauParamScheduler",
+    "CosineWithPlateuScheduler",
+    "LinearParamScheduler",
+    "LinearScheduler",
+]
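The expanded `__all__` means the linear and cosine-with-plateau schedulers can now be imported from the package root:

```python
# Import paths follow the new __all__; constructor arguments are not shown here.
from kostyl.ml.schedulers import (
    CosineWithPlateauParamScheduler,
    CosineWithPlateuScheduler,
    LinearParamScheduler,
    LinearScheduler,
)
```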
kostyl/ml/schedulers/base.py
CHANGED
@@ -6,18 +6,20 @@ from typing import Any
 class BaseScheduler(ABC):
     """Base class for learning rate schedulers."""
 
+    @abstractmethod
     def state_dict(self) -> dict[str, Any]:
         """Get the state as a state dictionary."""
-
-            key: value
-            for key, value in self.__dict__.items()
-            if key not in ["optimizer", "scheduler_values"]
-        }
+        raise NotImplementedError
 
+    @abstractmethod
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         """Load the state from a state dictionary."""
-
-
+        raise NotImplementedError
+
+    @abstractmethod
+    def _verify(self) -> None:
+        """Verify the scheduler configuration."""
+        raise NotImplementedError
 
     def __getstate__(self) -> dict[str, Any]:
         """Get the state for pickling."""
|