kostyl-toolkit 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kostyl/ml/dist_utils.py +14 -6
- kostyl/ml/lightning/callbacks/registry_uploading.py +60 -40
- kostyl/ml/lightning/loggers/tb_logger.py +1 -1
- kostyl/utils/logging.py +22 -3
- {kostyl_toolkit-0.1.18.dist-info → kostyl_toolkit-0.1.20.dist-info}/METADATA +1 -1
- {kostyl_toolkit-0.1.18.dist-info → kostyl_toolkit-0.1.20.dist-info}/RECORD +7 -7
- {kostyl_toolkit-0.1.18.dist-info → kostyl_toolkit-0.1.20.dist-info}/WHEEL +0 -0
kostyl/ml/dist_utils.py
CHANGED
@@ -86,14 +86,22 @@ def scale_lrs_by_world_size(
     return lrs
 
 
-def …
+def get_rank() -> int:
+    """Gets the rank of the current process in a distributed setting."""
     if dist.is_initialized():
-        …
-        …
-        …
-        …
+        return dist.get_rank()
+    if "RANK" in os.environ:
+        return int(os.environ["RANK"])
+    if "SLURM_PROCID" in os.environ:
+        return int(os.environ["SLURM_PROCID"])
+    if "LOCAL_RANK" in os.environ:
+        return int(os.environ["LOCAL_RANK"])
+    return 0
 
 
 def is_main_process() -> bool:
     """Checks if the current process is the main process (rank 0) in a distributed setting."""
-    …
+    rank = get_rank()
+    if rank != 0:
+        return False
+    return True
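
The new get_rank() falls back from the initialized process group to common launcher environment variables. A minimal standalone sketch of the same lookup order (illustrative only; resolve_rank is a hypothetical name, not part of the package):

import os

import torch.distributed as dist


def resolve_rank() -> int:
    # Same fallback chain as kostyl/ml/dist_utils.py in 0.1.20:
    # process group -> RANK -> SLURM_PROCID -> LOCAL_RANK -> 0.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    for var in ("RANK", "SLURM_PROCID", "LOCAL_RANK"):
        if var in os.environ:
            return int(os.environ[var])
    return 0


if __name__ == "__main__":
    # Under `torchrun --nproc_per_node=2 script.py` each worker prints its own rank;
    # a plain `python script.py` run prints 0.
    print(resolve_rank())

With this ordering, is_main_process() can identify non-zero ranks even before the process group is initialized, provided the launcher exported one of those variables.
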
kostyl/ml/lightning/callbacks/registry_uploading.py
CHANGED
@@ -26,93 +26,113 @@ class ClearMLRegistryUploaderCallback(Callback):
         output_model_name: str,
         output_model_tags: list[str] | None = None,
         verbose: bool = True,
+        enable_tag_versioning: bool = True,
         uploading_frequency: Literal[
             "after-every-eval", "on-train-end"
         ] = "on-train-end",
     ) -> None:
         """
-        …
+        Initializes the ClearMLRegistryUploaderCallback.

         Args:
-            task …
-            ckpt_callback …
-            output_model_name …
-            output_model_tags …
-            verbose …
-            …
+            task: ClearML task.
+            ckpt_callback: ModelCheckpoint instance used by Trainer.
+            output_model_name: Name for the ClearML output model.
+            output_model_tags: Tags for the output model.
+            verbose: Whether to log messages.
+            enable_tag_versioning: Whether to enable versioning in tags. If True,
+                the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".
+            uploading_frequency: When to upload:
+                - "after-every-eval": after each validation phase.
+                - "on-train-end": once at the end of training.

         """
         super().__init__()
         if output_model_tags is None:
             output_model_tags = []
+
         self.task = task
         self.ckpt_callback = ckpt_callback
         self.output_model_name = output_model_name
         self.output_model_tags = output_model_tags
         self.verbose = verbose
         self.uploading_frequency = uploading_frequency
+        self.enable_tag_versioning = enable_tag_versioning

         self._output_model: OutputModel | None = None
         self._last_best_model_path: str = ""
         return

-    def _create_output_model(self, pl_module: KostylLightningModule) -> OutputModel:
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+    def _create_output_model(self, pl_module: "KostylLightningModule") -> OutputModel:
+        if self.enable_tag_versioning:
+            version = find_version_in_tags(self.output_model_tags)
+            if version is None:
+                self.output_model_tags.append("v1.0")
+            else:
+                new_version = increment_version(version)
+                self.output_model_tags.remove(version)
+                self.output_model_tags.append(new_version)
+
+        if "LightningCheckpoint" not in self.output_model_tags:
+            self.output_model_tags.append("LightningCheckpoint")
         config = pl_module.model_config
         if config is not None:
             config = config.to_dict()

-        …
+        return OutputModel(
             task=self.task,
             name=self.output_model_name,
             framework="PyTorch",
             tags=self.output_model_tags,
             config_dict=config,
         )
-        return output_model

-    def _upload_best_checkpoint(self, pl_module: KostylLightningModule) -> None:
-        …
-        self._output_model = self._create_output_model(pl_module)
+    def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
+        current_best = self.ckpt_callback.best_model_path

-        if …
-            if self.verbose …
-                logger.info("Best model unchanged since last upload")
-        elif self.verbose:
+        if not current_best:
+            if self.verbose:
                 logger.info("No best model found yet to upload")
-        …
+            return
+
+        if current_best == self._last_best_model_path:
             if self.verbose:
-                logger.info(
-                    …
-                    …
-                    …
-                    …
-                    …
-                    …
-                )
+                logger.info("Best model unchanged since last upload")
+            return
+
+        if self._output_model is None:
+            self._output_model = self._create_output_model(pl_module)
+
+        if self.verbose:
+            logger.info(f"Uploading best model from {current_best}")
+
+        self._output_model.update_weights(
+            current_best,
+            auto_delete_file=False,
+            async_enable=False,
+        )
+
+        self._last_best_model_path = current_best
         return

     @override
-    def …
-        self, trainer: Trainer, pl_module: KostylLightningModule
+    def on_validation_end(
+        self, trainer: Trainer, pl_module: "KostylLightningModule"
     ) -> None:
-        if …
-            self.uploading_frequency != "after-every-eval"
-        ):
+        if self.uploading_frequency != "after-every-eval":
             return
+        if not trainer.is_global_zero:
+            return
+
         self._upload_best_checkpoint(pl_module)
         return

     @override
-    def on_train_end(
+    def on_train_end(
+        self, trainer: Trainer, pl_module: "KostylLightningModule"
+    ) -> None:
         if not trainer.is_global_zero:
             return
+
         self._upload_best_checkpoint(pl_module)
         return
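
A hedged usage sketch of the updated callback with a Lightning Trainer. The keyword arguments mirror the diff above; the import path follows the package layout in the RECORD, while the project and model names, the Lightning 2.x imports, and the Task.init call are illustrative assumptions. find_version_in_tags and increment_version are helpers referenced by _create_output_model whose definitions are not part of this diff.

from clearml import Task
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

from kostyl.ml.lightning.callbacks.registry_uploading import ClearMLRegistryUploaderCallback

task = Task.init(project_name="demo", task_name="train-run")  # hypothetical names
ckpt = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

uploader = ClearMLRegistryUploaderCallback(
    task=task,
    ckpt_callback=ckpt,
    output_model_name="my-model",            # hypothetical
    output_model_tags=["baseline"],          # "v1.0" and "LightningCheckpoint" get appended
    verbose=True,
    enable_tag_versioning=True,              # new in 0.1.20
    uploading_frequency="after-every-eval",  # upload the best checkpoint after each validation
)

# The module passed to fit() is expected to be a KostylLightningModule exposing model_config.
trainer = Trainer(max_epochs=3, callbacks=[ckpt, uploader])
# trainer.fit(model, datamodule=datamodule)

Per the diff, uploads now run only on the global-zero rank for both frequencies, and the OutputModel is created lazily on the first upload rather than on every call.
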
kostyl/utils/logging.py
CHANGED
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
 import inspect
-import os
 import sys
 import uuid
+from collections import namedtuple
 from copy import deepcopy
 from functools import partialmethod
 from pathlib import Path
@@ -13,7 +13,6 @@ from typing import Literal
 from typing import cast
 
 from loguru import logger as _base_logger
-from torch.nn.modules.module import _IncompatibleKeys
 
 
 if TYPE_CHECKING:
@@ -27,6 +26,9 @@ else:
 
 try:
     import torch.distributed as dist
+    from torch.nn.modules.module import (
+        _IncompatibleKeys,  # pyright: ignore[reportAssignmentType]
+    )
 except Exception:
 
     class _Dummy:
@@ -38,7 +40,24 @@ except Exception:
         def is_initialized() -> bool:
             return False
 
+        @staticmethod
+        def get_rank() -> int:
+            return 0
+
+    class _IncompatibleKeys(
+        namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
+    ):
+        __slots__ = ()
+
+        def __repr__(self) -> str:
+            if not self.missing_keys and not self.unexpected_keys:
+                return "<All keys matched successfully>"
+            return super().__repr__()
+
+        __str__ = __repr__
+
     dist = _Dummy()
+    _IncompatibleKeys = _IncompatibleKeys
 
 _once_lock = Lock()
 _once_keys: set[tuple[str, str]] = set()
@@ -106,7 +125,7 @@ def setup_logger(
         add_rank = False
 
     if add_rank:
-        rank = …
+        rank = dist.get_rank()
         channel = f"rank:{rank} - {base}"
     else:
         channel = base
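
kostyl/utils/logging.py now degrades gracefully when torch is not installed: the torch-only import of _IncompatibleKeys moves inside the existing try/except, and the fallback branch gains a get_rank() stub plus a namedtuple-based _IncompatibleKeys with torch's "<All keys matched successfully>" repr. A condensed standalone sketch of that pattern (illustrative, not the module's exact code):

from collections import namedtuple

try:
    import torch.distributed as dist
    from torch.nn.modules.module import _IncompatibleKeys
except Exception:

    class _Dummy:
        """Stand-in so rank lookups keep working without torch."""

        @staticmethod
        def is_initialized() -> bool:
            return False

        @staticmethod
        def get_rank() -> int:
            return 0

    class _IncompatibleKeys(
        namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])
    ):
        __slots__ = ()

        def __repr__(self) -> str:
            if not self.missing_keys and not self.unexpected_keys:
                return "<All keys matched successfully>"
            return super().__repr__()

        __str__ = __repr__

    dist = _Dummy()


# Works the same with or without torch installed:
rank = dist.get_rank() if dist.is_initialized() else 0
print(f"rank:{rank}", _IncompatibleKeys(missing_keys=[], unexpected_keys=[]))

setup_logger() takes the rank for its channel prefix from dist.get_rank(), which the dummy satisfies by returning 0 when torch is absent.
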
{kostyl_toolkit-0.1.18.dist-info → kostyl_toolkit-0.1.20.dist-info}/RECORD
CHANGED
@@ -8,17 +8,17 @@ kostyl/ml/configs/__init__.py,sha256=IetcivbqYGutowLqxdKp7QR4tkXKBr4m8t4Zkk9jHZU
 kostyl/ml/configs/base_model.py,sha256=Eofn14J9RsjpVx_J4rp6C19pDDCANU4hr3JtX-d0FpQ,4820
 kostyl/ml/configs/hyperparams.py,sha256=CaVNEvpW4LvlHhLsbe2FockIGI1mJufCqjH298nYgKE,2971
 kostyl/ml/configs/training_settings.py,sha256=0cyKF6EuTv6KgXC1g0Oy9zbwnMkDP4uXTJJO1TRQ0aY,2556
-kostyl/ml/dist_utils.py,sha256=…
+kostyl/ml/dist_utils.py,sha256=Onf0KHVLA8oeUgZTcTdmR9qiM22f2uYLoNwgLbMGJWk,3495
 kostyl/ml/lightning/__init__.py,sha256=-F3JAyq8KU1d-nACWryGu8d1CbvWbQ1rXFdeRwfE2X8,175
 kostyl/ml/lightning/callbacks/__init__.py,sha256=Vd-rozY4T9Prr3IMqbliXxj6sC6y9XsovHQqRwzc2HI,297
 kostyl/ml/lightning/callbacks/checkpoint.py,sha256=FooGeeUz6TtoXQglpcK16NWAmSX3fbu6wntRtK3a_Io,1936
 kostyl/ml/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
-kostyl/ml/lightning/callbacks/registry_uploading.py,sha256=…
+kostyl/ml/lightning/callbacks/registry_uploading.py,sha256=jJdSoFIkTcGLnZIKBzbAkt9MAgpZURLCQdd0DwAC5gk,4659
 kostyl/ml/lightning/extenstions/__init__.py,sha256=OY6QGv1agYgqqKf1xJBrxgp_i8FunVfPzYezfaRrGXU,182
 kostyl/ml/lightning/extenstions/custom_module.py,sha256=nB5jW7cqRD1tyh-q5LD2EtiFQwFkLXpnS9Yu6c5xMRg,5987
 kostyl/ml/lightning/extenstions/pretrained_model.py,sha256=ZOKtrVl095cwvI43wAz-Xdzu4l0v0lHH2mfh4WXwxKQ,5059
 kostyl/ml/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
-kostyl/ml/lightning/loggers/tb_logger.py,sha256=…
+kostyl/ml/lightning/loggers/tb_logger.py,sha256=j02HK5ue8yzXXV8FWKmmXyHkFlIxgHx-ahHWk_rFCZs,893
 kostyl/ml/lightning/steps_estimation.py,sha256=fTZ0IrUEZV3H6VYlx4GYn56oco56mMiB7FO9F0Z7qc4,1511
 kostyl/ml/metrics_formatting.py,sha256=w0rTz61z0Um_d2pomYLvcQFcZX_C-KolZcIPRsa1efE,1421
 kostyl/ml/params_groups.py,sha256=nUyw5d06Pvy9QPiYtZzLYR87xwXqJLxbHthgQH8oSCM,3583
@@ -29,7 +29,7 @@ kostyl/ml/schedulers/cosine.py,sha256=jufULVHn_L_ZZEc3ZTG3QCY_pc0jlAMH5Aw496T31j
 kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
 kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
 kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
-kostyl/utils/logging.py,sha256=…
-kostyl_toolkit-0.1.…
-kostyl_toolkit-0.1.…
-kostyl_toolkit-0.1.…
+kostyl/utils/logging.py,sha256=CQmPZ1x7yiVz56OkK6IZCtdoHs_Owo7fxJ03oOct-Qc,5782
+kostyl_toolkit-0.1.20.dist-info/WHEEL,sha256=3id4o64OvRm9dUknh3mMJNcfoTRK08ua5cU6DFyVy-4,79
+kostyl_toolkit-0.1.20.dist-info/METADATA,sha256=HaAS7ZSdaBNwoyKc6juh3nTDfEgQBPbWtMhAZg7Xg3s,4269
+kostyl_toolkit-0.1.20.dist-info/RECORD,,

{kostyl_toolkit-0.1.18.dist-info → kostyl_toolkit-0.1.20.dist-info}/WHEEL
File without changes