kostyl-toolkit 0.1.19__tar.gz → 0.1.20__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/PKG-INFO +1 -1
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/dist_utils.py +13 -13
- kostyl_toolkit-0.1.20/kostyl/ml/lightning/callbacks/registry_uploading.py +138 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/utils/logging.py +22 -3
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/pyproject.toml +1 -1
- kostyl_toolkit-0.1.19/kostyl/ml/lightning/callbacks/registry_uploading.py +0 -118
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/README.md +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/__init__.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/__init__.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/clearml/__init__.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/clearml/dataset_utils.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/clearml/logging_utils.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/clearml/pulling_utils.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/configs/__init__.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/configs/base_model.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/configs/hyperparams.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/configs/training_settings.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/__init__.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/callbacks/__init__.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/callbacks/checkpoint.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/extenstions/__init__.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/extenstions/custom_module.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/extenstions/pretrained_model.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/loggers/__init__.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/loggers/tb_logger.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/steps_estimation.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/metrics_formatting.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/params_groups.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/schedulers/__init__.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/schedulers/base.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/schedulers/composite.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/schedulers/cosine.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/utils/__init__.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/utils/dict_manipulations.py +0 -0
- {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/utils/fs.py +0 -0
{kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/dist_utils.py

@@ -86,22 +86,22 @@ def scale_lrs_by_world_size(
     return lrs
 
 
-def
+def get_rank() -> int:
+    """Gets the rank of the current process in a distributed setting."""
     if dist.is_initialized():
-
-
-
-
+        return dist.get_rank()
+    if "RANK" in os.environ:
+        return int(os.environ["RANK"])
+    if "SLURM_PROCID" in os.environ:
+        return int(os.environ["SLURM_PROCID"])
+    if "LOCAL_RANK" in os.environ:
+        return int(os.environ["LOCAL_RANK"])
+    return 0
 
 
 def is_main_process() -> bool:
     """Checks if the current process is the main process (rank 0) in a distributed setting."""
-    if dist.is_initialized():
-        return dist.get_rank() == 0
-    if "RANK" in os.environ:
-        return int(os.environ["RANK"]) == 0
-    if "SLURM_PROCID" in os.environ:
-        return int(os.environ["SLURM_PROCID"]) == 0
-    if "LOCAL_RANK" in os.environ:
-        return int(os.environ["LOCAL_RANK"]) == 0
+    rank = get_rank()
+    if rank != 0:
+        return False
    return True
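
For orientation, a minimal usage sketch of the reworked helpers (hypothetical calling code, not part of the package): is_main_process() now delegates to get_rank(), which prefers an initialized torch.distributed process group and falls back to the RANK, SLURM_PROCID, and LOCAL_RANK environment variables before defaulting to 0.

from kostyl.ml.dist_utils import get_rank, is_main_process

# Guard side effects so they run once per job rather than once per process.
if is_main_process():
    print(f"checkpointing from rank {get_rank()}")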
kostyl_toolkit-0.1.20/kostyl/ml/lightning/callbacks/registry_uploading.py

@@ -0,0 +1,138 @@
+from typing import Literal
+from typing import override
+
+from clearml import OutputModel
+from clearml import Task
+from lightning import Trainer
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+from kostyl.ml.clearml.logging_utils import find_version_in_tags
+from kostyl.ml.clearml.logging_utils import increment_version
+from kostyl.ml.lightning import KostylLightningModule
+from kostyl.utils.logging import setup_logger
+
+
+logger = setup_logger()
+
+
+class ClearMLRegistryUploaderCallback(Callback):
+    """PyTorch Lightning callback to upload the best model checkpoint to ClearML."""
+
+    def __init__(
+        self,
+        task: Task,
+        ckpt_callback: ModelCheckpoint,
+        output_model_name: str,
+        output_model_tags: list[str] | None = None,
+        verbose: bool = True,
+        enable_tag_versioning: bool = True,
+        uploading_frequency: Literal[
+            "after-every-eval", "on-train-end"
+        ] = "on-train-end",
+    ) -> None:
+        """
+        Initializes the ClearMLRegistryUploaderCallback.
+
+        Args:
+            task: ClearML task.
+            ckpt_callback: ModelCheckpoint instance used by Trainer.
+            output_model_name: Name for the ClearML output model.
+            output_model_tags: Tags for the output model.
+            verbose: Whether to log messages.
+            enable_tag_versioning: Whether to enable versioning in tags. If True,
+                the version tag (e.g., "v1.0") will be automatically incremented or, if not present, added as "v1.0".
+            uploading_frequency: When to upload:
+                - "after-every-eval": after each validation phase.
+                - "on-train-end": once at the end of training.
+
+        """
+        super().__init__()
+        if output_model_tags is None:
+            output_model_tags = []
+
+        self.task = task
+        self.ckpt_callback = ckpt_callback
+        self.output_model_name = output_model_name
+        self.output_model_tags = output_model_tags
+        self.verbose = verbose
+        self.uploading_frequency = uploading_frequency
+        self.enable_tag_versioning = enable_tag_versioning
+
+        self._output_model: OutputModel | None = None
+        self._last_best_model_path: str = ""
+        return
+
+    def _create_output_model(self, pl_module: "KostylLightningModule") -> OutputModel:
+        if self.enable_tag_versioning:
+            version = find_version_in_tags(self.output_model_tags)
+            if version is None:
+                self.output_model_tags.append("v1.0")
+            else:
+                new_version = increment_version(version)
+                self.output_model_tags.remove(version)
+                self.output_model_tags.append(new_version)
+
+        if "LightningCheckpoint" not in self.output_model_tags:
+            self.output_model_tags.append("LightningCheckpoint")
+        config = pl_module.model_config
+        if config is not None:
+            config = config.to_dict()
+
+        return OutputModel(
+            task=self.task,
+            name=self.output_model_name,
+            framework="PyTorch",
+            tags=self.output_model_tags,
+            config_dict=config,
+        )
+
+    def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
+        current_best = self.ckpt_callback.best_model_path
+
+        if not current_best:
+            if self.verbose:
+                logger.info("No best model found yet to upload")
+            return
+
+        if current_best == self._last_best_model_path:
+            if self.verbose:
+                logger.info("Best model unchanged since last upload")
+            return
+
+        if self._output_model is None:
+            self._output_model = self._create_output_model(pl_module)
+
+        if self.verbose:
+            logger.info(f"Uploading best model from {current_best}")
+
+        self._output_model.update_weights(
+            current_best,
+            auto_delete_file=False,
+            async_enable=False,
+        )
+
+        self._last_best_model_path = current_best
+        return
+
+    @override
+    def on_validation_end(
+        self, trainer: Trainer, pl_module: "KostylLightningModule"
+    ) -> None:
+        if self.uploading_frequency != "after-every-eval":
+            return
+        if not trainer.is_global_zero:
+            return
+
+        self._upload_best_checkpoint(pl_module)
+        return
+
+    @override
+    def on_train_end(
+        self, trainer: Trainer, pl_module: "KostylLightningModule"
+    ) -> None:
+        if not trainer.is_global_zero:
+            return
+
+        self._upload_best_checkpoint(pl_module)
+        return
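
A hypothetical wiring sketch for the new callback (the Task.init arguments, metric name, and model name below are illustrative, not taken from the package): the callback reads best_model_path from the ModelCheckpoint it is given, so both must be registered with the Trainer.

from clearml import Task
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

from kostyl.ml.lightning.callbacks.registry_uploading import (
    ClearMLRegistryUploaderCallback,
)

task = Task.init(project_name="demo", task_name="train")  # illustrative
ckpt = ModelCheckpoint(monitor="val_loss", mode="min")
uploader = ClearMLRegistryUploaderCallback(
    task=task,
    ckpt_callback=ckpt,
    output_model_name="demo-model",
    enable_tag_versioning=True,  # adds "v1.0" or bumps an existing version tag
    uploading_frequency="after-every-eval",  # upload after each validation phase
)
trainer = Trainer(callbacks=[ckpt, uploader])

Only rank 0 uploads (trainer.is_global_zero), and a checkpoint path that has not changed since the last upload is skipped.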
{kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/utils/logging.py

@@ -1,9 +1,9 @@
 from __future__ import annotations
 
 import inspect
-import os
 import sys
 import uuid
+from collections import namedtuple
 from copy import deepcopy
 from functools import partialmethod
 from pathlib import Path
@@ -13,7 +13,6 @@ from typing import Literal
 from typing import cast
 
 from loguru import logger as _base_logger
-from torch.nn.modules.module import _IncompatibleKeys
 
 
 if TYPE_CHECKING:
@@ -27,6 +26,9 @@ else:
 
 try:
     import torch.distributed as dist
+    from torch.nn.modules.module import (
+        _IncompatibleKeys,  # pyright: ignore[reportAssignmentType]
+    )
 except Exception:
 
     class _Dummy:
@@ -38,7 +40,24 @@ except Exception:
         def is_initialized() -> bool:
             return False
 
+        @staticmethod
+        def get_rank() -> int:
+            return 0
+
+    class _IncompatibleKeys(
+        namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
+    ):
+        __slots__ = ()
+
+        def __repr__(self) -> str:
+            if not self.missing_keys and not self.unexpected_keys:
+                return "<All keys matched successfully>"
+            return super().__repr__()
+
+        __str__ = __repr__
+
     dist = _Dummy()
+    _IncompatibleKeys = _IncompatibleKeys
 
 _once_lock = Lock()
 _once_keys: set[tuple[str, str]] = set()
@@ -106,7 +125,7 @@ def setup_logger(
         add_rank = False
 
     if add_rank:
-        rank =
+        rank = dist.get_rank()
         channel = f"rank:{rank} - {base}"
     else:
        channel = base
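
The net effect is that kostyl.utils.logging no longer hard-requires torch: when the import fails, a namedtuple-based _IncompatibleKeys stand-in and a _Dummy dist stub (is_initialized() returning False, get_rank() returning 0) take over, and setup_logger's rank channel is now resolved through dist.get_rank() instead of reading the environment directly. A small hypothetical check of that behavior:

# Runs with or without torch installed; without it, the stub reports rank 0.
from kostyl.utils.logging import setup_logger

logger = setup_logger()
logger.info("logger works even when torch is absent")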
kostyl_toolkit-0.1.19/kostyl/ml/lightning/callbacks/registry_uploading.py

@@ -1,118 +0,0 @@
-from typing import Literal
-from typing import override
-
-from clearml import OutputModel
-from clearml import Task
-from lightning import Trainer
-from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.callbacks import ModelCheckpoint
-
-from kostyl.ml.clearml.logging_utils import find_version_in_tags
-from kostyl.ml.clearml.logging_utils import increment_version
-from kostyl.ml.lightning import KostylLightningModule
-from kostyl.utils.logging import setup_logger
-
-
-logger = setup_logger()
-
-
-class ClearMLRegistryUploaderCallback(Callback):
-    """PyTorch Lightning callback to upload the best model checkpoint to ClearML."""
-
-    def __init__(
-        self,
-        task: Task,
-        ckpt_callback: ModelCheckpoint,
-        output_model_name: str,
-        output_model_tags: list[str] | None = None,
-        verbose: bool = True,
-        uploading_frequency: Literal[
-            "after-every-eval", "on-train-end"
-        ] = "on-train-end",
-    ) -> None:
-        """
-        Initialize the callback.
-
-        Args:
-            task (Task): The ClearML task object.
-            ckpt_callback (ModelCheckpoint): The model checkpoint callback.
-            output_model_name (str): The name for the output model.
-            output_model_tags (list[str] | None, optional): Tags for the output model. Defaults to None, which is converted to an empty list.
-            verbose (bool, optional): Whether to log verbose messages. Defaults to True.
-            uploading_frequency (Literal["after-every-eval", "on-train-end"]): Frequency of uploading the model. Defaults to "on-train-end".
-
-        """
-        super().__init__()
-        if output_model_tags is None:
-            output_model_tags = []
-        self.task = task
-        self.ckpt_callback = ckpt_callback
-        self.output_model_name = output_model_name
-        self.output_model_tags = output_model_tags
-        self.verbose = verbose
-        self.uploading_frequency = uploading_frequency
-
-        self._output_model: OutputModel | None = None
-        self._last_best_model_path: str = ""
-        return
-
-    def _create_output_model(self, pl_module: KostylLightningModule) -> OutputModel:
-        version = find_version_in_tags(self.output_model_tags)
-        if version is None:
-            self.output_model_tags.append("v1.0")
-        else:
-            new_version = increment_version(version)
-            self.output_model_tags.remove(version)
-            self.output_model_tags.append(new_version)
-
-        config = pl_module.model_config
-        if config is not None:
-            config = config.to_dict()
-
-        output_model = OutputModel(
-            task=self.task,
-            name=self.output_model_name,
-            framework="PyTorch",
-            tags=self.output_model_tags,
-            config_dict=config,
-        )
-        return output_model
-
-    def _upload_best_checkpoint(self, pl_module: KostylLightningModule) -> None:
-        if self._output_model is None:
-            self._output_model = self._create_output_model(pl_module)
-
-        if self.ckpt_callback.best_model_path == self._last_best_model_path:
-            if self.verbose and (self._last_best_model_path != ""):
-                logger.info("Best model unchanged since last upload")
-            elif self.verbose:
-                logger.info("No best model found yet to upload")
-        else:
-            if self.verbose:
-                logger.info(
-                    f"Uploading best model from {self.ckpt_callback.best_model_path}"
-                )
-            self._output_model.update_weights(
-                self.ckpt_callback.best_model_path,
-                auto_delete_file=False,
-                async_enable=False,
-            )
-        return
-
-    @override
-    def on_validation_epoch_end(
-        self, trainer: Trainer, pl_module: KostylLightningModule
-    ) -> None:
-        if (not trainer.is_global_zero) or (
-            self.uploading_frequency != "after-every-eval"
-        ):
-            return
-        self._upload_best_checkpoint(pl_module)
-        return
-
-    @override
-    def on_train_end(self, trainer: Trainer, pl_module: KostylLightningModule) -> None:
-        if not trainer.is_global_zero:
-            return
-        self._upload_best_checkpoint(pl_module)
-        return