kostyl-toolkit 0.1.19__tar.gz → 0.1.20__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/PKG-INFO +1 -1
  2. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/dist_utils.py +13 -13
  3. kostyl_toolkit-0.1.20/kostyl/ml/lightning/callbacks/registry_uploading.py +138 -0
  4. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/utils/logging.py +22 -3
  5. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/pyproject.toml +1 -1
  6. kostyl_toolkit-0.1.19/kostyl/ml/lightning/callbacks/registry_uploading.py +0 -118
  7. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/README.md +0 -0
  8. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/__init__.py +0 -0
  9. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/__init__.py +0 -0
  10. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/clearml/__init__.py +0 -0
  11. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/clearml/dataset_utils.py +0 -0
  12. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/clearml/logging_utils.py +0 -0
  13. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/clearml/pulling_utils.py +0 -0
  14. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/configs/__init__.py +0 -0
  15. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/configs/base_model.py +0 -0
  16. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/configs/hyperparams.py +0 -0
  17. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/configs/training_settings.py +0 -0
  18. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/__init__.py +0 -0
  19. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/callbacks/__init__.py +0 -0
  20. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/callbacks/checkpoint.py +0 -0
  21. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/callbacks/early_stopping.py +0 -0
  22. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/extenstions/__init__.py +0 -0
  23. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/extenstions/custom_module.py +0 -0
  24. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/extenstions/pretrained_model.py +0 -0
  25. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/loggers/__init__.py +0 -0
  26. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/loggers/tb_logger.py +0 -0
  27. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/lightning/steps_estimation.py +0 -0
  28. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/metrics_formatting.py +0 -0
  29. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/params_groups.py +0 -0
  30. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/schedulers/__init__.py +0 -0
  31. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/schedulers/base.py +0 -0
  32. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/schedulers/composite.py +0 -0
  33. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/ml/schedulers/cosine.py +0 -0
  34. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/utils/__init__.py +0 -0
  35. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/utils/dict_manipulations.py +0 -0
  36. {kostyl_toolkit-0.1.19 → kostyl_toolkit-0.1.20}/kostyl/utils/fs.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: kostyl-toolkit
3
- Version: 0.1.19
3
+ Version: 0.1.20
4
4
  Summary: Kickass Orchestration System for Training, Yielding & Logging
5
5
  Requires-Dist: case-converter>=1.2.0
6
6
  Requires-Dist: loguru>=0.7.3
@@ -86,22 +86,22 @@ def scale_lrs_by_world_size(
86
86
  return lrs
87
87
 
88
88
 
89
- def _get_rank() -> int:
89
+ def get_rank() -> int:
90
+ """Gets the rank of the current process in a distributed setting."""
90
91
  if dist.is_initialized():
91
- rank = dist.get_rank()
92
- else:
93
- rank = int(os.environ.get("RANK", 0))
94
- return rank
92
+ return dist.get_rank()
93
+ if "RANK" in os.environ:
94
+ return int(os.environ["RANK"])
95
+ if "SLURM_PROCID" in os.environ:
96
+ return int(os.environ["SLURM_PROCID"])
97
+ if "LOCAL_RANK" in os.environ:
98
+ return int(os.environ["LOCAL_RANK"])
99
+ return 0
95
100
 
96
101
 
97
102
  def is_main_process() -> bool:
98
103
  """Checks if the current process is the main process (rank 0) in a distributed setting."""
99
- if dist.is_initialized():
100
- return dist.get_rank() == 0
101
- if "RANK" in os.environ:
102
- return int(os.environ["RANK"]) == 0
103
- if "SLURM_PROCID" in os.environ:
104
- return int(os.environ["SLURM_PROCID"]) == 0
105
- if "LOCAL_RANK" in os.environ:
106
- return int(os.environ["LOCAL_RANK"]) == 0
104
+ rank = get_rank()
105
+ if rank != 0:
106
+ return False
107
107
  return True
@@ -0,0 +1,138 @@
1
+ from typing import Literal
2
+ from typing import override
3
+
4
+ from clearml import OutputModel
5
+ from clearml import Task
6
+ from lightning import Trainer
7
+ from lightning.pytorch.callbacks import Callback
8
+ from lightning.pytorch.callbacks import ModelCheckpoint
9
+
10
+ from kostyl.ml.clearml.logging_utils import find_version_in_tags
11
+ from kostyl.ml.clearml.logging_utils import increment_version
12
+ from kostyl.ml.lightning import KostylLightningModule
13
+ from kostyl.utils.logging import setup_logger
14
+
15
+
16
logger = setup_logger()


class ClearMLRegistryUploaderCallback(Callback):
    """PyTorch Lightning callback to upload the best model checkpoint to ClearML.

    The checkpoint to upload is read from ``ckpt_callback.best_model_path``.
    Uploads run only on the global-zero process and are skipped when there is
    no best checkpoint yet or when it has not changed since the last upload.
    """

    def __init__(
        self,
        task: Task,
        ckpt_callback: ModelCheckpoint,
        output_model_name: str,
        output_model_tags: list[str] | None = None,
        verbose: bool = True,
        enable_tag_versioning: bool = True,
        uploading_frequency: Literal[
            "after-every-eval", "on-train-end"
        ] = "on-train-end",
    ) -> None:
        """
        Initializes the ClearMLRegistryUploaderCallback.

        Args:
            task: ClearML task.
            ckpt_callback: ModelCheckpoint instance used by Trainer.
            output_model_name: Name for the ClearML output model.
            output_model_tags: Tags for the output model. The list is copied,
                so the caller's list is never modified.
            verbose: Whether to log messages.
            enable_tag_versioning: Whether to enable versioning in tags. If True,
                the version tag (e.g., "v1.0") will be automatically incremented or if not present, added as "v1.0".
            uploading_frequency: When to upload:
                - "after-every-eval": after each validation phase (and once more at the end of training).
                - "on-train-end": once at the end of training.

        Raises:
            ValueError: If ``uploading_frequency`` is not a supported value.

        """
        super().__init__()
        allowed_frequencies = ("after-every-eval", "on-train-end")
        if uploading_frequency not in allowed_frequencies:
            raise ValueError(
                f"uploading_frequency must be one of {allowed_frequencies}, "
                f"got {uploading_frequency!r}"
            )

        self.task = task
        self.ckpt_callback = ckpt_callback
        self.output_model_name = output_model_name
        # Copy: ``_create_output_model`` appends/removes tags in place and that
        # must not leak into the list the caller passed in.
        self.output_model_tags = (
            list(output_model_tags) if output_model_tags is not None else []
        )
        self.verbose = verbose
        self.uploading_frequency = uploading_frequency
        self.enable_tag_versioning = enable_tag_versioning

        # Created lazily on the first actual upload.
        self._output_model: OutputModel | None = None
        # Path of the most recently uploaded checkpoint; "" means nothing uploaded yet.
        self._last_best_model_path: str = ""
        return

    def _create_output_model(self, pl_module: "KostylLightningModule") -> OutputModel:
        """Build the ClearML OutputModel for this run.

        When tag versioning is enabled, the existing version tag is replaced by
        its incremented value, or "v1.0" is added if none is found. A
        "LightningCheckpoint" tag is always ensured. The model config is taken
        from ``pl_module.model_config.to_dict()`` when ``model_config`` is not None.
        """
        if self.enable_tag_versioning:
            # NOTE(review): assumes find_version_in_tags returns the tag string
            # itself, since it is passed to list.remove below.
            version = find_version_in_tags(self.output_model_tags)
            if version is None:
                self.output_model_tags.append("v1.0")
            else:
                new_version = increment_version(version)
                self.output_model_tags.remove(version)
                self.output_model_tags.append(new_version)

        if "LightningCheckpoint" not in self.output_model_tags:
            self.output_model_tags.append("LightningCheckpoint")
        config = pl_module.model_config
        if config is not None:
            config = config.to_dict()

        return OutputModel(
            task=self.task,
            name=self.output_model_name,
            framework="PyTorch",
            tags=self.output_model_tags,
            config_dict=config,
        )

    def _upload_best_checkpoint(self, pl_module: "KostylLightningModule") -> None:
        """Upload ``ckpt_callback.best_model_path`` if it exists and changed since the last upload."""
        current_best = self.ckpt_callback.best_model_path

        if not current_best:
            if self.verbose:
                logger.info("No best model found yet to upload")
            return

        if current_best == self._last_best_model_path:
            if self.verbose:
                logger.info("Best model unchanged since last upload")
            return

        # Create the registry entry only once there is something to upload.
        if self._output_model is None:
            self._output_model = self._create_output_model(pl_module)

        if self.verbose:
            logger.info(f"Uploading best model from {current_best}")

        # Synchronous upload; the local checkpoint file is kept.
        self._output_model.update_weights(
            current_best,
            auto_delete_file=False,
            async_enable=False,
        )

        self._last_best_model_path = current_best
        return

    @override
    def on_validation_end(
        self, trainer: Trainer, pl_module: "KostylLightningModule"
    ) -> None:
        """Upload the current best checkpoint after validation ("after-every-eval" only)."""
        if self.uploading_frequency != "after-every-eval":
            return
        # The sanity-check validation pass never has a checkpoint; skipping it
        # avoids a spurious "No best model found yet" message.
        if trainer.sanity_checking or not trainer.is_global_zero:
            return

        # NOTE(review): Lightning may run ModelCheckpoint callbacks after other
        # callbacks in this hook, so best_model_path can still reflect the
        # previous validation here; on_train_end uploads the final best.
        self._upload_best_checkpoint(pl_module)
        return

    @override
    def on_train_end(
        self, trainer: Trainer, pl_module: "KostylLightningModule"
    ) -> None:
        """Upload the final best checkpoint once training finishes (any frequency)."""
        if not trainer.is_global_zero:
            return

        self._upload_best_checkpoint(pl_module)
        return
@@ -1,9 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- import os
5
4
  import sys
6
5
  import uuid
6
+ from collections import namedtuple
7
7
  from copy import deepcopy
8
8
  from functools import partialmethod
9
9
  from pathlib import Path
@@ -13,7 +13,6 @@ from typing import Literal
13
13
  from typing import cast
14
14
 
15
15
  from loguru import logger as _base_logger
16
- from torch.nn.modules.module import _IncompatibleKeys
17
16
 
18
17
 
19
18
  if TYPE_CHECKING:
@@ -27,6 +26,9 @@ else:
27
26
 
28
27
  try:
29
28
  import torch.distributed as dist
29
+ from torch.nn.modules.module import (
30
+ _IncompatibleKeys, # pyright: ignore[reportAssignmentType]
31
+ )
30
32
  except Exception:
31
33
 
32
34
  class _Dummy:
@@ -38,7 +40,24 @@ except Exception:
38
40
  def is_initialized() -> bool:
39
41
  return False
40
42
 
43
+ @staticmethod
44
+ def get_rank() -> int:
45
+ return 0
46
+
47
+ class _IncompatibleKeys(
48
+ namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
49
+ ):
50
+ __slots__ = ()
51
+
52
+ def __repr__(self) -> str:
53
+ if not self.missing_keys and not self.unexpected_keys:
54
+ return "<All keys matched successfully>"
55
+ return super().__repr__()
56
+
57
+ __str__ = __repr__
58
+
41
59
  dist = _Dummy()
60
+ _IncompatibleKeys = _IncompatibleKeys
42
61
 
43
62
  _once_lock = Lock()
44
63
  _once_keys: set[tuple[str, str]] = set()
@@ -106,7 +125,7 @@ def setup_logger(
106
125
  add_rank = False
107
126
 
108
127
  if add_rank:
109
- rank = int(os.environ.get("RANK", "0"))
128
+ rank = dist.get_rank()
110
129
  channel = f"rank:{rank} - {base}"
111
130
  else:
112
131
  channel = base
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "kostyl-toolkit"
3
- version = "0.1.19"
3
+ version = "0.1.20"
4
4
  description = "Kickass Orchestration System for Training, Yielding & Logging "
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -1,118 +0,0 @@
1
- from typing import Literal
2
- from typing import override
3
-
4
- from clearml import OutputModel
5
- from clearml import Task
6
- from lightning import Trainer
7
- from lightning.pytorch.callbacks import Callback
8
- from lightning.pytorch.callbacks import ModelCheckpoint
9
-
10
- from kostyl.ml.clearml.logging_utils import find_version_in_tags
11
- from kostyl.ml.clearml.logging_utils import increment_version
12
- from kostyl.ml.lightning import KostylLightningModule
13
- from kostyl.utils.logging import setup_logger
14
-
15
-
16
logger = setup_logger()


class ClearMLRegistryUploaderCallback(Callback):
    """PyTorch Lightning callback to upload the best model checkpoint to ClearML.

    The checkpoint to upload is read from ``ckpt_callback.best_model_path``.
    Uploads run only on the global-zero process and are skipped when there is
    no best checkpoint yet or when it has not changed since the last upload.
    """

    def __init__(
        self,
        task: Task,
        ckpt_callback: ModelCheckpoint,
        output_model_name: str,
        output_model_tags: list[str] | None = None,
        verbose: bool = True,
        uploading_frequency: Literal[
            "after-every-eval", "on-train-end"
        ] = "on-train-end",
    ) -> None:
        """
        Initialize the callback.

        Args:
            task (Task): The ClearML task object.
            ckpt_callback (ModelCheckpoint): The model checkpoint callback.
            output_model_name (str): The name for the output model.
            output_model_tags (list[str] | None, optional): Tags for the output model. Defaults to None, which is converted to an empty list. The list is copied, so the caller's list is never modified.
            verbose (bool, optional): Whether to log verbose messages. Defaults to True.
            uploading_frequency (Literal["after-every-eval", "on-train-end"]): Frequency of uploading the model. Defaults to "on-train-end".

        """
        super().__init__()
        self.task = task
        self.ckpt_callback = ckpt_callback
        self.output_model_name = output_model_name
        # Copy: version bumping below edits this list in place.
        self.output_model_tags = (
            list(output_model_tags) if output_model_tags is not None else []
        )
        self.verbose = verbose
        self.uploading_frequency = uploading_frequency

        # Created lazily on the first actual upload.
        self._output_model: OutputModel | None = None
        # Path of the most recently uploaded checkpoint; "" means nothing uploaded yet.
        self._last_best_model_path: str = ""
        return

    def _create_output_model(self, pl_module: KostylLightningModule) -> OutputModel:
        """Build the ClearML OutputModel, bumping (or adding "v1.0" as) the version tag."""
        # NOTE(review): assumes find_version_in_tags returns the tag string
        # itself, since it is passed to list.remove below.
        version = find_version_in_tags(self.output_model_tags)
        if version is None:
            self.output_model_tags.append("v1.0")
        else:
            new_version = increment_version(version)
            self.output_model_tags.remove(version)
            self.output_model_tags.append(new_version)

        config = pl_module.model_config
        if config is not None:
            config = config.to_dict()

        return OutputModel(
            task=self.task,
            name=self.output_model_name,
            framework="PyTorch",
            tags=self.output_model_tags,
            config_dict=config,
        )

    def _upload_best_checkpoint(self, pl_module: KostylLightningModule) -> None:
        """Upload ``ckpt_callback.best_model_path`` if it exists and changed since the last upload."""
        current_best = self.ckpt_callback.best_model_path

        if not current_best:
            if self.verbose:
                logger.info("No best model found yet to upload")
            return

        if current_best == self._last_best_model_path:
            if self.verbose:
                logger.info("Best model unchanged since last upload")
            return

        # Create the registry entry (and bump its version tag) only when there
        # is actually something to upload.
        if self._output_model is None:
            self._output_model = self._create_output_model(pl_module)

        if self.verbose:
            logger.info(f"Uploading best model from {current_best}")
        self._output_model.update_weights(
            current_best,
            auto_delete_file=False,
            async_enable=False,
        )
        # Remember what was uploaded so identical checkpoints are not re-sent.
        self._last_best_model_path = current_best
        return

    @override
    def on_validation_end(
        self, trainer: Trainer, pl_module: KostylLightningModule
    ) -> None:
        """Upload the current best checkpoint after validation ("after-every-eval" only)."""
        # ModelCheckpoint saves in on_validation_end, so hooking the earlier
        # on_validation_epoch_end always saw the previous best path.
        if self.uploading_frequency != "after-every-eval":
            return
        # The sanity-check validation pass never has a checkpoint.
        if trainer.sanity_checking or not trainer.is_global_zero:
            return
        self._upload_best_checkpoint(pl_module)
        return

    @override
    def on_train_end(self, trainer: Trainer, pl_module: KostylLightningModule) -> None:
        """Upload the final best checkpoint once training finishes (any frequency)."""
        if not trainer.is_global_zero:
            return
        self._upload_best_checkpoint(pl_module)
        return