kostyl-toolkit 0.1.36__py3-none-any.whl → 0.1.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kostyl/ml/base_uploader.py +17 -0
- kostyl/ml/configs/__init__.py +2 -2
- kostyl/ml/configs/mixins.py +50 -0
- kostyl/ml/{data_processing_utils.py → data_collator.py} +6 -3
- kostyl/ml/dist_utils.py +53 -33
- kostyl/ml/integrations/clearml/__init__.py +7 -0
- kostyl/ml/{registry_uploader.py → integrations/clearml/checkpoint_uploader.py} +3 -13
- kostyl/ml/{configs/base_model.py → integrations/clearml/config_mixin.py} +7 -63
- kostyl/ml/{clearml/pulling_utils.py → integrations/clearml/loading_utils.py} +32 -5
- kostyl/ml/integrations/lightning/__init__.py +14 -0
- kostyl/ml/{lightning → integrations/lightning}/callbacks/checkpoint.py +27 -42
- kostyl/ml/{lightning → integrations/lightning}/loggers/tb_logger.py +2 -2
- kostyl/ml/{lightning/extensions/pretrained_model.py → integrations/lightning/mixins.py} +6 -4
- kostyl/ml/{lightning/extensions/custom_module.py → integrations/lightning/module.py} +2 -38
- kostyl/ml/{lightning → integrations/lightning}/utils.py +1 -1
- kostyl/ml/schedulers/__init__.py +4 -4
- kostyl/ml/schedulers/{cosine_with_plateu.py → plateau.py} +59 -36
- kostyl/utils/logging.py +67 -52
- {kostyl_toolkit-0.1.36.dist-info → kostyl_toolkit-0.1.38.dist-info}/METADATA +1 -1
- kostyl_toolkit-0.1.38.dist-info/RECORD +40 -0
- {kostyl_toolkit-0.1.36.dist-info → kostyl_toolkit-0.1.38.dist-info}/WHEEL +2 -2
- kostyl/ml/lightning/__init__.py +0 -5
- kostyl/ml/lightning/extensions/__init__.py +0 -5
- kostyl_toolkit-0.1.36.dist-info/RECORD +0 -38
- /kostyl/ml/{clearml → integrations}/__init__.py +0 -0
- /kostyl/ml/{clearml → integrations/clearml}/dataset_utils.py +0 -0
- /kostyl/ml/{clearml/logging_utils.py → integrations/clearml/version_utils.py} +0 -0
- /kostyl/ml/{lightning → integrations/lightning}/callbacks/__init__.py +0 -0
- /kostyl/ml/{lightning → integrations/lightning}/callbacks/early_stopping.py +0 -0
- /kostyl/ml/{lightning → integrations/lightning}/loggers/__init__.py +0 -0
- /kostyl/ml/{metrics_formatting.py → integrations/lightning/metrics_formatting.py} +0 -0
kostyl/ml/base_uploader.py
ADDED
@@ -0,0 +1,17 @@
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+
+from kostyl.utils.logging import setup_logger
+
+
+logger = setup_logger()
+
+
+class ModelCheckpointUploader(ABC):
+    """Abstract base class for uploading model checkpoints to a registry backend."""
+
+    @abstractmethod
+    def upload_checkpoint(self, path: str | Path) -> None:
+        """Upload the checkpoint located at the given path to the configured registry backend."""
+        raise NotImplementedError
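For orientation, a minimal sketch of how a concrete uploader could subclass the new base; the S3CheckpointUploader name, its constructor, and its body are illustrative assumptions, not part of the package:

from pathlib import Path

from kostyl.ml.base_uploader import ModelCheckpointUploader


class S3CheckpointUploader(ModelCheckpointUploader):
    """Hypothetical uploader that would push checkpoints to object storage."""

    def __init__(self, bucket: str) -> None:
        self.bucket = bucket

    def upload_checkpoint(self, path: str | Path) -> None:
        # Placeholder: a real implementation would stream the file to the bucket.
        print(f"Would upload {path} to s3://{self.bucket}/")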
kostyl/ml/configs/__init__.py
CHANGED
@@ -1,8 +1,8 @@
-from .base_model import KostylBaseModel
 from .hyperparams import HyperparamsConfig
 from .hyperparams import Lr
 from .hyperparams import Optimizer
 from .hyperparams import WeightDecay
+from .mixins import ConfigLoadingMixin
 from .training_settings import CheckpointConfig
 from .training_settings import DataConfig
 from .training_settings import DDPStrategyConfig
@@ -15,12 +15,12 @@ from .training_settings import TrainingSettings
 
 __all__ = [
     "CheckpointConfig",
+    "ConfigLoadingMixin",
     "DDPStrategyConfig",
     "DataConfig",
     "EarlyStoppingConfig",
     "FSDP1StrategyConfig",
     "HyperparamsConfig",
-    "KostylBaseModel",
     "LightningTrainerParameters",
     "Lr",
     "Optimizer",
kostyl/ml/configs/mixins.py
ADDED
@@ -0,0 +1,50 @@
+from pathlib import Path
+
+from pydantic import BaseModel as PydanticBaseModel
+
+from kostyl.utils.fs import load_config
+
+
+class ConfigLoadingMixin[TConfig: PydanticBaseModel]:
+    """Mixin providing configuration loading functionality for Pydantic models."""
+
+    @classmethod
+    def from_file(
+        cls: type[TConfig],  # pyright: ignore
+        path: str | Path,
+    ) -> TConfig:
+        """
+        Create an instance of the class from a configuration file.
+
+        Args:
+            cls_: The class type to instantiate.
+            path (str | Path): Path to the configuration file.
+
+        Returns:
+            An instance of the class created from the configuration file.
+
+        """
+        config = load_config(path)
+        instance = cls.model_validate(config)
+        return instance
+
+    @classmethod
+    def from_dict(
+        cls: type[TConfig],  # pyright: ignore
+        state_dict: dict,
+    ) -> TConfig:
+        """
+        Creates an instance from a dictionary.
+
+        Args:
+            cls_: The class type to instantiate.
+            state_dict (dict): A dictionary representing the state of the
+                class that must be validated and used for initialization.
+
+        Returns:
+            An initialized instance of the class based on the
+                provided state dictionary.
+
+        """
+        instance = cls.model_validate(state_dict)
+        return instance
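Since the mixin does not itself inherit from pydantic.BaseModel, a config class would presumably combine the two; a minimal sketch, with the TrainConfig name and its fields being illustrative assumptions:

from pydantic import BaseModel

from kostyl.ml.configs import ConfigLoadingMixin


class TrainConfig(BaseModel, ConfigLoadingMixin):
    """Hypothetical config combining Pydantic validation with file/dict loading."""

    lr: float = 3e-4
    batch_size: int = 32


cfg = TrainConfig.from_dict({"lr": 1e-4, "batch_size": 64})
# cfg = TrainConfig.from_file("configs/train.yaml")  # parsed via kostyl.utils.fs.load_config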
kostyl/ml/{data_processing_utils.py → data_collator.py}
CHANGED
@@ -36,6 +36,7 @@ class BatchCollatorWithKeyAlignment:
         keys_mapping: A dictionary mapping original keys to new keys.
         keys_to_keep: A set of keys to retain as-is from the original items.
         max_length: If provided, truncates "input_ids" and "attention_mask" to this length.
+            Only 1D tensors/lists are supported.
 
     Raises:
         ValueError: If both `keys_mapping` and `keys_to_keep` are None.
@@ -59,14 +60,16 @@ class BatchCollatorWithKeyAlignment:
     def _truncate_data(self, key: str, value: Any) -> Any:
         match value:
             case torch.Tensor():
-                if value.dim()
+                if value.dim() >= 2:
                     raise ValueError(
-                        f"Expected
+                        f"Expected tensor with dim < 2 for key {key}, got {value.dim()}. "
+                        "Check your data or disable truncation with `max_length=None`."
                     )
             case list():
                 if isinstance(value[0], list):
                     raise ValueError(
-                        f"Expected value with dim <= 2 for key {key}, got nested lists"
+                        f"Expected value with dim <= 2 for key {key}, got nested lists. "
+                        "Check your data or disable truncation with `max_length=None`."
                     )
                 value = value[: self.max_length]
         return value
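A rough usage sketch, assuming the collator is constructed with the documented arguments and used as a standard collate_fn; the key names and max_length value are illustrative, not taken from the package:

from kostyl.ml.data_collator import BatchCollatorWithKeyAlignment

# Rename "text_input_ids" to "input_ids", keep "labels" as-is, truncate 1D sequences to 512 items.
collator = BatchCollatorWithKeyAlignment(
    keys_mapping={"text_input_ids": "input_ids"},
    keys_to_keep={"labels"},
    max_length=512,
)
# Typically passed to torch.utils.data.DataLoader(..., collate_fn=collator).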
kostyl/ml/dist_utils.py
CHANGED
@@ -4,47 +4,69 @@ from typing import Literal
 
 import torch.distributed as dist
 
+from kostyl.utils.logging import KostylLogger
 from kostyl.utils.logging import setup_logger
 
 
-
+module_logger = setup_logger()
 
 
-def log_dist(
+def log_dist(
+    msg: str,
+    logger: KostylLogger | None = None,
+    level: Literal["info", "warning", "error", "warning_once", "debug"] = "info",
+    log_scope: Literal["only-zero-rank", "world"] = "world",
+    group: dist.ProcessGroup | None = None,
+) -> None:
     """
     Log a message in a distributed environment based on the specified verbosity level.
 
     Args:
         msg (str): The message to log.
-
+        log_scope (Literal["only-zero-rank", "world"]): The verbosity level for logging.
            - "only-zero-rank": Log only from the main process (rank 0).
            - "world": Log from all processes in the distributed environment.
+        logger (KostylLogger | None): The logger instance to use. If None, the module logger is used.
+        level (Literal["info", "warning", "error", "warning_once", "debug"]): The logging level.
+        group (dist.ProcessGroup | None): Optional process group used to determine ranks. Defaults to the global process group.
 
     """
-
-
-
-
-
-
+    if logger is None:
+        logger = module_logger
+
+    log_attr = getattr(logger, level, None)
+    if log_attr is None:
+        raise ValueError(f"Invalid logging level: {level}")
+
+    if not dist.is_initialized():
+        module_logger.warning_once(
+            "Distributed process group is not initialized. Logging from the current process only."
+        )
+        log_attr(msg)
+        return
+
+    match log_scope:
         case "only-zero-rank":
-            if
-
+            if group is None:
+                module_logger.debug(
+                    "No process group provided; assuming global group for rank check."
+                )
+                group = dist.group.WORLD
+            group_rank = dist.get_rank(group=group)
+            if dist.get_global_rank(group=group, group_rank=group_rank) == 0:  # pyright: ignore[reportArgumentType]
+                log_attr(msg)
         case "world":
-
+            log_attr(msg)
         case _:
-
-                f"Invalid logging verbosity level requested: {how}. Message not logged."
-            )
+            raise ValueError(f"Invalid logging verbosity level: {log_scope}")
     return
 
 
 def scale_lrs_by_world_size(
     lrs: dict[str, float],
     group: dist.ProcessGroup | None = None,
-    config_name: str = "",
     inv_scale: bool = False,
-
+    verbose_level: Literal["only-zero-rank", "world"] | None = None,
 ) -> dict[str, float]:
     """
     Scale learning-rate configuration values to match the active distributed world size.
@@ -56,9 +78,8 @@ def scale_lrs_by_world_size(
         lrs (dict[str, float]): A dictionary of learning rate names and their corresponding values to be scaled.
         group (dist.ProcessGroup | None): Optional process group used to determine
             the target world size. Defaults to the global process group.
-        config_name (str): Human-readable identifier included in log messages.
         inv_scale (bool): If True, use the inverse square-root scale factor.
-
+        verbose_level (Literal["only-zero-rank", "world"] | None): Verbosity level for logging scaled values.
            - "only-zero-rank": Log only from the main process (rank 0).
            - "world": Log from all processes in the distributed environment.
            - None: No logging.
@@ -77,31 +98,30 @@ def scale_lrs_by_world_size(
     for name, value in lrs.items():
         old_value = value
         new_value = value * scale
-        if
+        if verbose_level is not None:
             log_dist(
-                f"
-
+                f"lr {name.upper()}: {new_value}; OLD: {old_value}",
+                log_scope=verbose_level,
+                group=group,
             )
         lrs[name] = new_value
     return lrs
 
 
-def
-    """Gets the rank of the current process in a distributed setting."""
-    if dist.is_initialized():
-        return dist.get_rank()
-    if "
-        return int(os.environ["
-    if "SLURM_PROCID" in os.environ:
-        return int(os.environ["SLURM_PROCID"])
+def get_local_rank(group: dist.ProcessGroup | None = None) -> int:
+    """Gets the local rank of the current process in a distributed setting."""
+    if dist.is_initialized() and group is not None:
+        return dist.get_rank(group=group)
+    if "SLURM_LOCALID" in os.environ:
+        return int(os.environ["SLURM_LOCALID"])
     if "LOCAL_RANK" in os.environ:
         return int(os.environ["LOCAL_RANK"])
     return 0
 
 
-def
-    """Checks if the current process is the main process (rank 0) in a distributed setting."""
-    rank =
+def is_local_zero_rank() -> bool:
+    """Checks if the current process is the main process (rank 0) for the local node in a distributed setting."""
+    rank = get_local_rank()
     if rank != 0:
         return False
     return True
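To show the reworked signature in context, a small hedged sketch of calling the helpers; the message text and learning-rate names are illustrative, and the commented-out call assumes torch.distributed has already been initialized:

from kostyl.ml.dist_utils import is_local_zero_rank, log_dist, scale_lrs_by_world_size

# Logs from every rank by default; here we restrict output to the rank-0 process.
log_dist("starting training", level="info", log_scope="only-zero-rank")

if is_local_zero_rank():
    print("running on the first process of this node")

# Illustrative only: requires an initialized process group to read the world size.
# lrs = scale_lrs_by_world_size({"encoder": 3e-4, "head": 1e-3}, verbose_level="only-zero-rank")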
kostyl/ml/{registry_uploader.py → integrations/clearml/checkpoint_uploader.py}
CHANGED
@@ -1,5 +1,3 @@
-from abc import ABC
-from abc import abstractmethod
 from collections.abc import Callable
 from functools import partial
 from pathlib import Path
@@ -7,22 +5,14 @@ from typing import override
 
 from clearml import OutputModel
 
+from kostyl.ml.base_uploader import ModelCheckpointUploader
 from kostyl.utils.logging import setup_logger
 
 
 logger = setup_logger()
 
 
-class
-    """Abstract Lightning callback responsible for tracking and uploading the best-performing model checkpoint."""
-
-    @abstractmethod
-    def upload_checkpoint(self, path: str | Path) -> None:
-        """Upload the checkpoint located at the given path to the configured registry backend."""
-        raise NotImplementedError
-
-
-class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
+class ClearMLCheckpointUploader(ModelCheckpointUploader):
     """PyTorch Lightning callback to upload the best model checkpoint to ClearML."""
 
     def __init__(
@@ -38,7 +28,7 @@ class ClearMLRegistryUploaderCallback(RegistryUploaderCallback):
         verbose: bool = True,
     ) -> None:
         """
-        Initializes the
+        Initializes the ClearMLRegistryUploader.
 
         Args:
             model_name: The name for the newly created model.
kostyl/ml/{configs/base_model.py → integrations/clearml/config_mixin.py}
CHANGED
@@ -1,75 +1,25 @@
 from pathlib import Path
-from typing import Self
-from typing import TypeVar
 
 from caseconverter import pascalcase
 from caseconverter import snakecase
 from clearml import Task
-from pydantic import BaseModel as PydanticBaseModel
 
+from kostyl.ml.configs import ConfigLoadingMixin
 from kostyl.utils.dict_manipulations import convert_to_flat_dict
 from kostyl.utils.dict_manipulations import flattened_dict_to_nested
 from kostyl.utils.fs import load_config
 
 
-
-
-
-class BaseModelWithConfigLoading(PydanticBaseModel):
-    """Pydantic class providing basic configuration loading functionality."""
-
-    @classmethod
-    def from_file(
-        cls: type[Self],  # pyright: ignore
-        path: str | Path,
-    ) -> Self:
-        """
-        Create an instance of the class from a configuration file.
-
-        Args:
-            cls_: The class type to instantiate.
-            path (str | Path): Path to the configuration file.
-
-        Returns:
-            An instance of the class created from the configuration file.
-
-        """
-        config = load_config(path)
-        instance = cls.model_validate(config)
-        return instance
-
-    @classmethod
-    def from_dict(
-        cls: type[Self],  # pyright: ignore
-        state_dict: dict,
-    ) -> Self:
-        """
-        Creates an instance from a dictionary.
-
-        Args:
-            cls_: The class type to instantiate.
-            state_dict (dict): A dictionary representing the state of the
-                class that must be validated and used for initialization.
-
-        Returns:
-            An initialized instance of the class based on the
-                provided state dictionary.
-
-        """
-        instance = cls.model_validate(state_dict)
-        return instance
-
-
-class BaseModelWithClearmlSyncing(BaseModelWithConfigLoading):
-    """Pydantic class providing ClearML configuration loading and syncing functionality."""
+class BaseModelWithClearmlSyncing[TConfig: ConfigLoadingMixin]:
+    """Mixin providing ClearML task configuration syncing functionality for Pydantic models."""
 
     @classmethod
     def connect_as_file(
-        cls: type[
+        cls: type[TConfig],  # pyright: ignore
         task: Task,
         path: str | Path,
         alias: str | None = None,
-    ) ->
+    ) -> TConfig:
         """
         Connects the configuration file to a ClearML task and creates an instance of the class from it.
 
@@ -104,11 +54,11 @@ class BaseModelWithClearmlSyncing(BaseModelWithConfigLoading):
 
     @classmethod
     def connect_as_dict(
-        cls: type[
+        cls: type[TConfig],  # pyright: ignore
         task: Task,
         path: str | Path,
         alias: str | None = None,
-    ) ->
+    ) -> TConfig:
         """
         Connects configuration from a file as a dictionary to a ClearML task and creates an instance of the class.
 
@@ -135,9 +85,3 @@ class BaseModelWithClearmlSyncing(BaseModelWithConfigLoading):
 
         model = cls.from_dict(state_dict=config)
         return model
-
-
-class KostylBaseModel(BaseModelWithClearmlSyncing):
-    """A Pydantic model class with basic configuration loading functionality."""
-
-    pass
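As a rough sketch of how the split classes might be combined after this change; the ExperimentConfig class, its fields, the task names, and the YAML path are illustrative assumptions, and a configured ClearML environment is assumed:

from clearml import Task
from pydantic import BaseModel

from kostyl.ml.configs import ConfigLoadingMixin
from kostyl.ml.integrations.clearml.config_mixin import BaseModelWithClearmlSyncing


class ExperimentConfig(BaseModel, ConfigLoadingMixin, BaseModelWithClearmlSyncing):
    """Hypothetical config: Pydantic validation + file loading + ClearML syncing."""

    lr: float = 3e-4
    epochs: int = 10


task = Task.init(project_name="demo", task_name="train")  # illustrative ClearML task
cfg = ExperimentConfig.connect_as_dict(task=task, path="configs/train.yaml")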
kostyl/ml/{clearml/pulling_utils.py → integrations/clearml/loading_utils.py}
CHANGED
@@ -9,9 +9,26 @@ from transformers import AutoTokenizer
 from transformers import PreTrainedModel
 from transformers import PreTrainedTokenizerBase
 
-
-
-
+
+try:
+    from kostyl.ml.integrations.lightning import (
+        LightningCheckpointLoaderMixin,  # pyright: ignore[reportAssignmentType]
+    )
+
+    LIGHTING_MIXIN_AVAILABLE = True
+except ImportError:
+
+    class LightningCheckpointLoaderMixin(PreTrainedModel):  # noqa: D101
+        pass  # type: ignore
+
+        @classmethod
+        def from_lightning_checkpoint(cls, *args: Any, **kwargs: Any) -> Any:  # noqa: D103
+            raise ImportError(
+                "Loading from Lightning checkpoints requires lightning integration. "
+                "Please package install via 'pip install lightning' to enable this functionality."
+            )
+
+    LIGHTING_MIXIN_AVAILABLE = False
 
 
 def get_tokenizer_from_clearml(
@@ -89,13 +106,23 @@ def get_model_from_clearml[
     local_path = Path(input_model.get_local_copy(raise_on_error=True))
 
     if local_path.is_dir() and input_model._is_package():
+        if not issubclass(model, (PreTrainedModel, AutoModel)):
+            raise ValueError(
+                f"Model class {model.__name__} must be a subclass of PreTrainedModel or AutoModel for directory loads."
+            )
         model_instance = model.from_pretrained(local_path, **kwargs)
     elif local_path.suffix == ".ckpt":
+        if not LIGHTING_MIXIN_AVAILABLE:
+            raise ImportError(
+                "Loading from Lightning checkpoints requires lightning integration. "
+                "Please package install via 'pip install lightning' to enable this functionality."
+            )
         if not issubclass(model, LightningCheckpointLoaderMixin):
             raise ValueError(
-                f"Model class {model.__name__} is not compatible with Lightning checkpoints
+                f"Model class {model.__name__} is not compatible with Lightning checkpoints "
+                "(must inherit from LightningCheckpointLoaderMixin)."
             )
-        model_instance = model.from_lightning_checkpoint(local_path, **kwargs)
+        model_instance = model.from_lightning_checkpoint(local_path, **kwargs)  # type: ignore
     else:
         raise ValueError(
             f"Unsupported model format for path: {local_path}. "
kostyl/ml/integrations/lightning/__init__.py
ADDED
@@ -0,0 +1,14 @@
+try:
+    import lightning  # noqa: F401
+except ImportError as e:
+    raise ImportError(
+        "Lightning integration requires the 'lightning' package. "
+        "Please install it via 'pip install lightning'."
+    ) from e
+
+
+from .mixins import LightningCheckpointLoaderMixin
+from .module import KostylLightningModule
+
+
+__all__ = ["KostylLightningModule", "LightningCheckpointLoaderMixin"]
kostyl/ml/{lightning → integrations/lightning}/callbacks/checkpoint.py
CHANGED
@@ -9,17 +9,16 @@ import torch.distributed as dist
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.callbacks import ModelCheckpoint
 
+from kostyl.ml.base_uploader import ModelCheckpointUploader
 from kostyl.ml.configs import CheckpointConfig
-from kostyl.ml.dist_utils import
-from kostyl.ml.lightning import KostylLightningModule
-from kostyl.ml.registry_uploader import RegistryUploaderCallback
+from kostyl.ml.dist_utils import is_local_zero_rank
 from kostyl.utils import setup_logger
 
 
 logger = setup_logger("callbacks/checkpoint.py")
 
 
-class
+class ModelCheckpointWithCheckpointUploader(ModelCheckpoint):
     r"""
     Save the model after every epoch by monitoring a quantity. Every logged metrics are passed to the
     :class:`~lightning.pytorch.loggers.logger.Logger` for the version it gets saved in the same directory as the
@@ -229,8 +228,8 @@ class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
 
     def __init__(  # noqa: D107
         self,
-
-
+        checkpoint_uploader: ModelCheckpointUploader,
+        upload_strategy: Literal["only-best", "every-checkpoint"] = "only-best",
         dirpath: _PATH | None = None,
         filename: str | None = None,
         monitor: str | None = None,
@@ -247,9 +246,9 @@ class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
         save_on_train_epoch_end: bool | None = None,
         enable_version_counter: bool = True,
     ) -> None:
-        self.
+        self.registry_uploader = checkpoint_uploader
         self.process_group: dist.ProcessGroup | None = None
-        self.
+        self.upload_strategy = upload_strategy
         super().__init__(
             dirpath=dirpath,
             filename=filename,
@@ -269,40 +268,26 @@ class ModelCheckpointWithRegistryUploader(ModelCheckpoint):
         )
         return
 
-    @override
-    def setup(
-        self,
-        trainer: pl.Trainer,
-        pl_module: pl.LightningModule | KostylLightningModule,
-        stage: str,
-    ) -> None:
-        super().setup(trainer, pl_module, stage)
-        if isinstance(pl_module, KostylLightningModule):
-            self.process_group = pl_module.get_process_group()
-        return
-
     @override
     def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
         super()._save_checkpoint(trainer, filepath)
-        if
-
-        if trainer.is_global_zero and self.registry_uploader_callback is not None:
-            match self.uploading_mode:
+        if trainer.is_global_zero and self.registry_uploader is not None:
+            match self.upload_strategy:
                 case "every-checkpoint":
-                    self.
+                    self.registry_uploader.upload_checkpoint(filepath)
                 case "only-best":
                     if filepath == self.best_model_path:
-                        self.
+                        self.registry_uploader.upload_checkpoint(filepath)
         return
 
 
 def setup_checkpoint_callback(
     dirpath: Path,
     ckpt_cfg: CheckpointConfig,
-
-
+    checkpoint_uploader: ModelCheckpointUploader | None = None,
+    upload_strategy: Literal["only-best", "every-checkpoint"] | None = None,
     remove_folder_if_exists: bool = True,
-) ->
+) -> ModelCheckpointWithCheckpointUploader | ModelCheckpoint:
     """
     Create and configure a checkpoint callback for model saving.
 
@@ -313,33 +298,33 @@ def setup_checkpoint_callback(
     Args:
         dirpath: Path to the directory for saving checkpoints.
         ckpt_cfg: Checkpoint configuration (filename, monitor, mode, save_top_k).
-
-
-
+        checkpoint_uploader: Optional checkpoint uploader instance. If provided, enables
+            uploading of checkpoints to a remote registry.
+        upload_strategy: Checkpoint upload mode:
            - "only-best": only the best checkpoint is uploaded
            - "every-checkpoint": every saved checkpoint is uploaded
-            Must be specified together with
+            Must be specified together with checkpoint_uploader.
        remove_folder_if_exists: If True, removes existing checkpoint directory before creating a new one.
 
     Returns:
-
+        ModelCheckpointWithCheckpointUploader if checkpoint_uploader is provided,
        otherwise standard ModelCheckpoint.
 
     Raises:
-        ValueError: If only one of
+        ValueError: If only one of checkpoint_uploader or uploading_mode is None.
 
     Note:
        If the dirpath directory already exists, it will be removed and recreated
        (only on the main process in distributed training) if remove_folder_if_exists is True.
 
     """
-    if (
+    if (checkpoint_uploader is None) != (upload_strategy is None):
        raise ValueError(
-            "Both
+            "Both checkpoint_uploader and upload_strategy must be provided or neither."
        )
 
     if dirpath.exists():
-        if
+        if is_local_zero_rank():
            logger.warning(f"Checkpoint directory {dirpath} already exists.")
            if remove_folder_if_exists:
                rmtree(dirpath)
@@ -348,8 +333,8 @@ def setup_checkpoint_callback(
     logger.info(f"Creating checkpoint directory {dirpath}.")
     dirpath.mkdir(parents=True, exist_ok=True)
 
-    if (
-        checkpoint_callback =
+    if (checkpoint_uploader is not None) and (upload_strategy is not None):
+        checkpoint_callback = ModelCheckpointWithCheckpointUploader(
            dirpath=dirpath,
            filename=ckpt_cfg.filename,
            save_top_k=ckpt_cfg.save_top_k,
@@ -357,8 +342,8 @@ def setup_checkpoint_callback(
            mode=ckpt_cfg.mode,
            verbose=True,
            save_weights_only=ckpt_cfg.save_weights_only,
-
-
+            checkpoint_uploader=checkpoint_uploader,
+            upload_strategy=upload_strategy,
        )
     else:
        checkpoint_callback = ModelCheckpoint(
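To tie the renamed pieces together, a hedged sketch of configuring the callback; the MyUploader stand-in, the checkpoint directory, and default-constructing CheckpointConfig are assumptions for illustration, and the packaged ClearMLCheckpointUploader would normally fill the uploader role:

from pathlib import Path

from kostyl.ml.base_uploader import ModelCheckpointUploader
from kostyl.ml.configs import CheckpointConfig
from kostyl.ml.integrations.lightning.callbacks.checkpoint import setup_checkpoint_callback


class MyUploader(ModelCheckpointUploader):
    """Stand-in uploader used only for this example."""

    def upload_checkpoint(self, path: str | Path) -> None:
        print(f"would upload {path}")


ckpt_cfg = CheckpointConfig()  # assumes the config's fields all have defaults
callback = setup_checkpoint_callback(
    dirpath=Path("checkpoints/run-1"),
    ckpt_cfg=ckpt_cfg,
    checkpoint_uploader=MyUploader(),
    upload_strategy="only-best",
)
# The returned callback can then be passed to lightning.pytorch.Trainer(callbacks=[callback]).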
kostyl/ml/{lightning → integrations/lightning}/loggers/tb_logger.py
CHANGED
@@ -3,7 +3,7 @@ from shutil import rmtree
 
 from lightning.pytorch.loggers import TensorBoardLogger
 
-from kostyl.ml.dist_utils import
+from kostyl.ml.dist_utils import is_local_zero_rank
 from kostyl.utils.logging import setup_logger
 
 
@@ -15,7 +15,7 @@ def setup_tb_logger(
 ) -> TensorBoardLogger:
     """Sets up a TensorBoardLogger for PyTorch Lightning."""
     if runs_dir.exists():
-        if
+        if is_local_zero_rank():
            logger.warning(f"TensorBoard log directory {runs_dir} already exists.")
            rmtree(runs_dir)
            logger.warning(f"Removed existing TensorBoard log directory {runs_dir}.")