kostyl-toolkit 0.1.19__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kostyl/ml/dist_utils.py CHANGED
@@ -86,22 +86,22 @@ def scale_lrs_by_world_size(
     return lrs


-def _get_rank() -> int:
+def get_rank() -> int:
+    """Gets the rank of the current process in a distributed setting."""
     if dist.is_initialized():
-        rank = dist.get_rank()
-    else:
-        rank = int(os.environ.get("RANK", 0))
-    return rank
+        return dist.get_rank()
+    if "RANK" in os.environ:
+        return int(os.environ["RANK"])
+    if "SLURM_PROCID" in os.environ:
+        return int(os.environ["SLURM_PROCID"])
+    if "LOCAL_RANK" in os.environ:
+        return int(os.environ["LOCAL_RANK"])
+    return 0


 def is_main_process() -> bool:
     """Checks if the current process is the main process (rank 0) in a distributed setting."""
-    if dist.is_initialized():
-        return dist.get_rank() == 0
-    if "RANK" in os.environ:
-        return int(os.environ["RANK"]) == 0
-    if "SLURM_PROCID" in os.environ:
-        return int(os.environ["SLURM_PROCID"]) == 0
-    if "LOCAL_RANK" in os.environ:
-        return int(os.environ["LOCAL_RANK"]) == 0
+    rank = get_rank()
+    if rank != 0:
+        return False
     return True
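For reference, the renamed get_rank() (formerly the private _get_rank()) now resolves the rank from an initialized torch.distributed process group first, then from the RANK, SLURM_PROCID, and LOCAL_RANK environment variables, falling back to 0; is_main_process() is now a thin wrapper over it. A minimal usage sketch, not part of the diff, assuming both functions are imported directly from kostyl.ml.dist_utils:

    # Usage sketch (not from the package): assumes get_rank() and is_main_process()
    # are importable from kostyl.ml.dist_utils as shown in the diff above.
    from kostyl.ml.dist_utils import get_rank, is_main_process

    def log_once(message: str) -> None:
        # Only rank 0 prints, so output is not duplicated when the script is
        # launched with torchrun or under SLURM.
        if is_main_process():
            print(f"[rank {get_rank()}] {message}")

    log_once("starting training")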
kostyl/ml/lightning/callbacks/registry_uploading.py CHANGED
@@ -26,93 +26,113 @@ class ClearMLRegistryUploaderCallback(Callback):
         output_model_name: str,
         output_model_tags: list[str] | None = None,
         verbose: bool = True,
+        enable_tag_versioning: bool = True,
         uploading_frequency: Literal[
             "after-every-eval", "on-train-end"
         ] = "on-train-end",
     ) -> None:
         """
-        Initialize the callback.
+        Initializes the ClearMLRegistryUploaderCallback.

         Args:
-            task (Task): The ClearML task object.
-            ckpt_callback (ModelCheckpoint): The model checkpoint callback.
-            output_model_name (str): The name for the output model.
-            output_model_tags (list[str] | None, optional): Tags for the output model. Defaults to None, which is converted to an empty list.
-            verbose (bool, optional): Whether to log verbose messages. Defaults to True.
-            uploading_frequency (Literal["after-every-eval", "on-train-end"]): Frequency of uploading the model. Defaults to "on-train-end".
+            task: ClearML task.
+            ckpt_callback: ModelCheckpoint instance used by Trainer.
+            output_model_name: Name for the ClearML output model.
+            output_model_tags: Tags for the output model.
+            verbose: Whether to log messages.
+            enable_tag_versioning: Whether to enable versioning in tags. If True,
+                the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".
+            uploading_frequency: When to upload:
+                - "after-every-eval": after each validation phase.
+                - "on-train-end": once at the end of training.

         """
         super().__init__()
         if output_model_tags is None:
             output_model_tags = []
+
         self.task = task
         self.ckpt_callback = ckpt_callback
         self.output_model_name = output_model_name
         self.output_model_tags = output_model_tags
         self.verbose = verbose
         self.uploading_frequency = uploading_frequency
+        self.enable_tag_versioning = enable_tag_versioning

         self._output_model: OutputModel | None = None
         self._last_best_model_path: str = ""
         return

-    def _create_output_model(self, pl_module: KostylLightningModule) -> OutputModel:
-        version = find_version_in_tags(self.output_model_tags)
-        if version is None:
-            self.output_model_tags.append("v1.0")
-        else:
-            new_version = increment_version(version)
-            self.output_model_tags.remove(version)
-            self.output_model_tags.append(new_version)
-
+    def _create_output_model(self, pl_module: "KostylLightningModule") -> OutputModel:
+        if self.enable_tag_versioning:
+            version = find_version_in_tags(self.output_model_tags)
+            if version is None:
+                self.output_model_tags.append("v1.0")
+            else:
+                new_version = increment_version(version)
+                self.output_model_tags.remove(version)
+                self.output_model_tags.append(new_version)
+
+        if "LightningCheckpoint" not in self.output_model_tags:
+            self.output_model_tags.append("LightningCheckpoint")
         config = pl_module.model_config
         if config is not None:
             config = config.to_dict()

-        output_model = OutputModel(
+        return OutputModel(
             task=self.task,
             name=self.output_model_name,
             framework="PyTorch",
             tags=self.output_model_tags,
             config_dict=config,
         )
-        return output_model

-    def _upload_best_checkpoint(self, pl_module: KostylLightningModule) -> None:
-        if self._output_model is None:
-            self._output_model = self._create_output_model(pl_module)
+    def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
+        current_best = self.ckpt_callback.best_model_path

-        if self.ckpt_callback.best_model_path == self._last_best_model_path:
-            if self.verbose and (self._last_best_model_path != ""):
-                logger.info("Best model unchanged since last upload")
-            elif self.verbose:
+        if not current_best:
+            if self.verbose:
                 logger.info("No best model found yet to upload")
-        else:
+            return
+
+        if current_best == self._last_best_model_path:
             if self.verbose:
-                logger.info(
-                    f"Uploading best model from {self.ckpt_callback.best_model_path}"
-                )
-            self._output_model.update_weights(
-                self.ckpt_callback.best_model_path,
-                auto_delete_file=False,
-                async_enable=False,
-            )
+                logger.info("Best model unchanged since last upload")
+            return
+
+        if self._output_model is None:
+            self._output_model = self._create_output_model(pl_module)
+
+        if self.verbose:
+            logger.info(f"Uploading best model from {current_best}")
+
+        self._output_model.update_weights(
+            current_best,
+            auto_delete_file=False,
+            async_enable=False,
+        )
+
+        self._last_best_model_path = current_best
         return

     @override
-    def on_validation_epoch_end(
-        self, trainer: Trainer, pl_module: KostylLightningModule
+    def on_validation_end(
+        self, trainer: Trainer, pl_module: "KostylLightningModule"
     ) -> None:
-        if (not trainer.is_global_zero) or (
-            self.uploading_frequency != "after-every-eval"
-        ):
+        if self.uploading_frequency != "after-every-eval":
             return
+        if not trainer.is_global_zero:
+            return
+
         self._upload_best_checkpoint(pl_module)
         return

     @override
-    def on_train_end(self, trainer: Trainer, pl_module: KostylLightningModule) -> None:
+    def on_train_end(
+        self, trainer: Trainer, pl_module: "KostylLightningModule"
+    ) -> None:
         if not trainer.is_global_zero:
             return
+
         self._upload_best_checkpoint(pl_module)
         return
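Taken together, the callback now skips the upload when no best checkpoint exists yet, remembers the last uploaded path, tags the output model with "LightningCheckpoint", and reacts to on_validation_end instead of on_validation_epoch_end. A wiring sketch, not part of the diff: the module path follows the RECORD at the end of this diff, while the lightning 2.x imports, ClearML project/task names, and monitored metric are placeholders:

    # Wiring sketch (not from the package): module path follows the RECORD entry for
    # registry_uploading.py; project/task names, metric, and lightning 2.x imports
    # are assumptions.
    from clearml import Task
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    from kostyl.ml.lightning.callbacks.registry_uploading import (
        ClearMLRegistryUploaderCallback,
    )

    task = Task.init(project_name="demo", task_name="train")
    ckpt = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

    uploader = ClearMLRegistryUploaderCallback(
        task=task,
        ckpt_callback=ckpt,
        output_model_name="demo-model",
        enable_tag_versioning=True,  # new in 0.1.20: bumps or adds the "vX.Y" tag
        uploading_frequency="after-every-eval",
    )

    trainer = Trainer(callbacks=[ckpt, uploader])
    # trainer.fit(model, datamodule=datamodule)  # model/datamodule supplied by the user

With uploading_frequency="after-every-eval", an upload is attempted on every on_validation_end call on rank 0, but the _last_best_model_path guard means weights are only re-uploaded when ModelCheckpoint reports a new best path.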
kostyl/utils/logging.py CHANGED
@@ -1,9 +1,9 @@
 from __future__ import annotations

 import inspect
-import os
 import sys
 import uuid
+from collections import namedtuple
 from copy import deepcopy
 from functools import partialmethod
 from pathlib import Path
@@ -13,7 +13,6 @@ from typing import Literal
 from typing import cast

 from loguru import logger as _base_logger
-from torch.nn.modules.module import _IncompatibleKeys


 if TYPE_CHECKING:
@@ -27,6 +26,9 @@ else:

 try:
     import torch.distributed as dist
+    from torch.nn.modules.module import (
+        _IncompatibleKeys,  # pyright: ignore[reportAssignmentType]
+    )
 except Exception:

     class _Dummy:
@@ -38,7 +40,24 @@ except Exception:
         def is_initialized() -> bool:
             return False

+        @staticmethod
+        def get_rank() -> int:
+            return 0
+
+    class _IncompatibleKeys(
+        namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
+    ):
+        __slots__ = ()
+
+        def __repr__(self) -> str:
+            if not self.missing_keys and not self.unexpected_keys:
+                return "<All keys matched successfully>"
+            return super().__repr__()
+
+        __str__ = __repr__
+
     dist = _Dummy()
+    _IncompatibleKeys = _IncompatibleKeys

 _once_lock = Lock()
 _once_keys: set[tuple[str, str]] = set()
@@ -106,7 +125,7 @@ def setup_logger(
         add_rank = False

     if add_rank:
-        rank = int(os.environ.get("RANK", "0"))
+        rank = dist.get_rank()
         channel = f"rank:{rank} - {base}"
     else:
         channel = base
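The logging module no longer reads RANK from the environment: the rank prefix now comes from dist.get_rank(), and when torch is unavailable the dummy stand-in (together with the namedtuple-based _IncompatibleKeys) keeps the module importable. A standalone restatement of that fallback pattern, for illustration only:

    # Standalone sketch of the fallback pattern from the diff above; torch is
    # optional, and a dummy object stands in for torch.distributed when missing.
    try:
        import torch.distributed as dist
    except Exception:

        class _Dummy:
            @staticmethod
            def is_initialized() -> bool:
                return False

            @staticmethod
            def get_rank() -> int:
                return 0

        dist = _Dummy()

    # Without torch installed, the dummy reports rank 0, so a rank-prefixed logger
    # channel such as "rank:0 - <name>" stays well-defined.
    rank = dist.get_rank() if dist.is_initialized() else 0
    print(f"rank:{rank}")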
kostyl_toolkit-0.1.19.dist-info/METADATA → kostyl_toolkit-0.1.20.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: kostyl-toolkit
-Version: 0.1.19
+Version: 0.1.20
 Summary: Kickass Orchestration System for Training, Yielding & Logging
 Requires-Dist: case-converter>=1.2.0
 Requires-Dist: loguru>=0.7.3
kostyl_toolkit-0.1.19.dist-info/RECORD → kostyl_toolkit-0.1.20.dist-info/RECORD CHANGED
@@ -8,12 +8,12 @@ kostyl/ml/configs/__init__.py,sha256=IetcivbqYGutowLqxdKp7QR4tkXKBr4m8t4Zkk9jHZU
 kostyl/ml/configs/base_model.py,sha256=Eofn14J9RsjpVx_J4rp6C19pDDCANU4hr3JtX-d0FpQ,4820
 kostyl/ml/configs/hyperparams.py,sha256=CaVNEvpW4LvlHhLsbe2FockIGI1mJufCqjH298nYgKE,2971
 kostyl/ml/configs/training_settings.py,sha256=0cyKF6EuTv6KgXC1g0Oy9zbwnMkDP4uXTJJO1TRQ0aY,2556
-kostyl/ml/dist_utils.py,sha256=XKshB_9j4G8BLN_fU2sZtvLk4jsgjPb8z_XOhrEGAaI,3502
+kostyl/ml/dist_utils.py,sha256=Onf0KHVLA8oeUgZTcTdmR9qiM22f2uYLoNwgLbMGJWk,3495
 kostyl/ml/lightning/__init__.py,sha256=-F3JAyq8KU1d-nACWryGu8d1CbvWbQ1rXFdeRwfE2X8,175
 kostyl/ml/lightning/callbacks/__init__.py,sha256=Vd-rozY4T9Prr3IMqbliXxj6sC6y9XsovHQqRwzc2HI,297
 kostyl/ml/lightning/callbacks/checkpoint.py,sha256=FooGeeUz6TtoXQglpcK16NWAmSX3fbu6wntRtK3a_Io,1936
 kostyl/ml/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
-kostyl/ml/lightning/callbacks/registry_uploading.py,sha256=AAohqTGigWMGLHt-yITAznIVsaDVXR_jisLDzFNzlu8,4275
+kostyl/ml/lightning/callbacks/registry_uploading.py,sha256=jJdSoFIkTcGLnZIKBzbAkt9MAgpZURLCQdd0DwAC5gk,4659
 kostyl/ml/lightning/extenstions/__init__.py,sha256=OY6QGv1agYgqqKf1xJBrxgp_i8FunVfPzYezfaRrGXU,182
 kostyl/ml/lightning/extenstions/custom_module.py,sha256=nB5jW7cqRD1tyh-q5LD2EtiFQwFkLXpnS9Yu6c5xMRg,5987
 kostyl/ml/lightning/extenstions/pretrained_model.py,sha256=ZOKtrVl095cwvI43wAz-Xdzu4l0v0lHH2mfh4WXwxKQ,5059
@@ -29,7 +29,7 @@ kostyl/ml/schedulers/cosine.py,sha256=jufULVHn_L_ZZEc3ZTG3QCY_pc0jlAMH5Aw496T31j
 kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
 kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
 kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
-kostyl/utils/logging.py,sha256=3MvfDPArZhwakHu5nMlp_LpOsWg0E0SP26y41clsBtA,5232
-kostyl_toolkit-0.1.19.dist-info/WHEEL,sha256=3id4o64OvRm9dUknh3mMJNcfoTRK08ua5cU6DFyVy-4,79
-kostyl_toolkit-0.1.19.dist-info/METADATA,sha256=zAq1MkJ8Wt88R_Zlv5O_pTNNib80qVuKC4rZjtCFhO8,4269
-kostyl_toolkit-0.1.19.dist-info/RECORD,,
+kostyl/utils/logging.py,sha256=CQmPZ1x7yiVz56OkK6IZCtdoHs_Owo7fxJ03oOct-Qc,5782
+kostyl_toolkit-0.1.20.dist-info/WHEEL,sha256=3id4o64OvRm9dUknh3mMJNcfoTRK08ua5cU6DFyVy-4,79
+kostyl_toolkit-0.1.20.dist-info/METADATA,sha256=HaAS7ZSdaBNwoyKc6juh3nTDfEgQBPbWtMhAZg7Xg3s,4269
+kostyl_toolkit-0.1.20.dist-info/RECORD,,