kostyl-toolkit 0.1.10__tar.gz → 0.1.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {kostyl_toolkit-0.1.10 → kostyl_toolkit-0.1.12}/PKG-INFO +1 -1
  2. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/clearml/pulling_utils.py +6 -3
  3. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/configs/__init__.py +0 -2
  4. kostyl_toolkit-0.1.10/kostyl/ml_core/clearml/config_mixin.py → kostyl_toolkit-0.1.12/kostyl/ml/configs/base_model.py +63 -11
  5. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/configs/hyperparams.py +3 -2
  6. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/configs/training_settings.py +0 -11
  7. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/callbacks/checkpoint.py +2 -2
  8. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/callbacks/early_stopping.py +1 -1
  9. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/callbacks/registry_uploading.py +3 -3
  10. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/extenstions/custom_module.py +12 -13
  11. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/extenstions/pretrained_model.py +12 -2
  12. kostyl_toolkit-0.1.12/kostyl/ml/lightning/loggers/__init__.py +4 -0
  13. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/loggers/tb_logger.py +1 -1
  14. {kostyl_toolkit-0.1.10 → kostyl_toolkit-0.1.12}/pyproject.toml +1 -1
  15. kostyl_toolkit-0.1.10/kostyl/ml_core/configs/base_model.py +0 -60
  16. kostyl_toolkit-0.1.10/kostyl/ml_core/lightning/loggers/__init__.py +0 -0
  17. {kostyl_toolkit-0.1.10 → kostyl_toolkit-0.1.12}/README.md +0 -0
  18. {kostyl_toolkit-0.1.10 → kostyl_toolkit-0.1.12}/kostyl/__init__.py +0 -0
  19. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/__init__.py +0 -0
  20. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/clearml/__init__.py +0 -0
  21. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/clearml/dataset_utils.py +0 -0
  22. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/clearml/logging_utils.py +0 -0
  23. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/dist_utils.py +0 -0
  24. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/__init__.py +0 -0
  25. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/callbacks/__init__.py +0 -0
  26. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/extenstions/__init__.py +0 -0
  27. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/steps_estimation.py +0 -0
  28. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/metrics_formatting.py +0 -0
  29. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/params_groups.py +0 -0
  30. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/schedulers/__init__.py +0 -0
  31. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/schedulers/base.py +0 -0
  32. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/schedulers/composite.py +0 -0
  33. {kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/schedulers/cosine.py +0 -0
  34. {kostyl_toolkit-0.1.10 → kostyl_toolkit-0.1.12}/kostyl/utils/__init__.py +0 -0
  35. {kostyl_toolkit-0.1.10 → kostyl_toolkit-0.1.12}/kostyl/utils/dict_manipulations.py +0 -0
  36. {kostyl_toolkit-0.1.10 → kostyl_toolkit-0.1.12}/kostyl/utils/fs.py +0 -0
  37. {kostyl_toolkit-0.1.10 → kostyl_toolkit-0.1.12}/kostyl/utils/logging.py +0 -0
{kostyl_toolkit-0.1.10 → kostyl_toolkit-0.1.12}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: kostyl-toolkit
-Version: 0.1.10
+Version: 0.1.12
 Summary: Kickass Orchestration System for Training, Yielding & Logging
 Requires-Dist: case-converter>=1.2.0
 Requires-Dist: loguru>=0.7.3
{kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/clearml/pulling_utils.py
@@ -3,12 +3,13 @@ from typing import cast
 
 from clearml import InputModel
 from clearml import Task
+from transformers import Any
 from transformers import AutoModel
 from transformers import AutoTokenizer
 from transformers import PreTrainedModel
 from transformers import PreTrainedTokenizerBase
 
-from kostyl.ml_core.lightning.extenstions.pretrained_model import (
+from kostyl.ml.lightning.extenstions.pretrained_model import (
     LightningCheckpointLoaderMixin,
 )
 
@@ -48,6 +49,7 @@ def get_model_from_clearml[
     model: type[TModel],
     task: Task | None = None,
     ignore_remote_overrides: bool = True,
+    **kwargs: Any,
 ) -> TModel:
     """
     Retrieve a pretrained model from ClearML and instantiate it using the appropriate loader.
@@ -59,6 +61,7 @@ def get_model_from_clearml[
            will be connected to this task, with remote overrides optionally ignored.
        ignore_remote_overrides: When connecting the input model to the provided task, determines whether
            remote configuration overrides should be ignored.
+        **kwargs: Additional keyword arguments to pass to the model loading method.
 
    Returns:
        An instantiated model loaded either from a ClearML package directory or a Lightning checkpoint.
@@ -72,13 +75,13 @@ def get_model_from_clearml[
     local_path = Path(input_model.get_local_copy(raise_on_error=True))
 
     if local_path.is_dir() and input_model._is_package():
-        model_instance = model.from_pretrained(local_path)
+        model_instance = model.from_pretrained(local_path, **kwargs)
     elif local_path.suffix == ".ckpt":
         if not issubclass(model, LightningCheckpointLoaderMixin):
             raise ValueError(
                 f"Model class {model.__name__} is not compatible with Lightning checkpoints."
             )
-        model_instance = model.from_lighting_checkpoint(local_path)
+        model_instance = model.from_lighting_checkpoint(local_path, **kwargs)
     else:
         raise ValueError(
             f"Unsupported model format for path: {local_path}. "
{kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/configs/__init__.py
@@ -4,7 +4,6 @@ from .hyperparams import Lr
 from .hyperparams import Optimizer
 from .hyperparams import WeightDecay
 from .training_settings import CheckpointConfig
-from .training_settings import ClearMLTrainingSettings
 from .training_settings import DataConfig
 from .training_settings import DDPStrategyConfig
 from .training_settings import EarlyStoppingConfig
@@ -16,7 +15,6 @@ from .training_settings import TrainingSettings
 
 __all__ = [
     "CheckpointConfig",
-    "ClearMLTrainingSettings",
     "DDPStrategyConfig",
     "DataConfig",
     "EarlyStoppingConfig",
kostyl_toolkit-0.1.10/kostyl/ml_core/clearml/config_mixin.py → kostyl_toolkit-0.1.12/kostyl/ml/configs/base_model.py
@@ -1,29 +1,75 @@
 from pathlib import Path
+from typing import Self
 from typing import TypeVar
 
-import clearml
 from caseconverter import pascalcase
 from caseconverter import snakecase
+from clearml import Task
+from pydantic import BaseModel as PydanticBaseModel
 
-from kostyl.ml_core.configs.base_model import KostylBaseModel
 from kostyl.utils.dict_manipulations import convert_to_flat_dict
 from kostyl.utils.dict_manipulations import flattened_dict_to_nested
 from kostyl.utils.fs import load_config
 
 
-TModel = TypeVar("TModel", bound="KostylBaseModel")
+TConfig = TypeVar("TConfig", bound=PydanticBaseModel)
 
 
-class ClearMLConfigMixin:
-    """Pydantic mixin class providing ClearML configuration loading and syncing functionality."""
+class BaseModelWithConfigLoading(PydanticBaseModel):
+    """Pydantic class providing basic configuration loading functionality."""
+
+    @classmethod
+    def from_file(
+        cls: type[Self],  # pyright: ignore
+        path: str | Path,
+    ) -> Self:
+        """
+        Create an instance of the class from a configuration file.
+
+        Args:
+            cls_: The class type to instantiate.
+            path (str | Path): Path to the configuration file.
+
+        Returns:
+            An instance of the class created from the configuration file.
+
+        """
+        config = load_config(path)
+        instance = cls.model_validate(config)
+        return instance
+
+    @classmethod
+    def from_dict(
+        cls: type[Self],  # pyright: ignore
+        state_dict: dict,
+    ) -> Self:
+        """
+        Creates an instance from a dictionary.
+
+        Args:
+            cls_: The class type to instantiate.
+            state_dict (dict): A dictionary representing the state of the
+                class that must be validated and used for initialization.
+
+        Returns:
+            An initialized instance of the class based on the
+            provided state dictionary.
+
+        """
+        instance = cls.model_validate(state_dict)
+        return instance
+
+
+class BaseModelWithClearmlSyncing(BaseModelWithConfigLoading):
+    """Pydantic class providing ClearML configuration loading and syncing functionality."""
 
     @classmethod
     def connect_as_file(
-        cls: type[TModel],  # pyright: ignore
-        task: clearml.Task,
+        cls: type[Self],  # pyright: ignore
+        task: Task,
         path: str | Path,
         alias: str | None = None,
-    ) -> TModel:
+    ) -> Self:
         """
         Connects the configuration file to a ClearML task and creates an instance of the class from it.
 
@@ -58,11 +104,11 @@ class ClearMLConfigMixin:
 
     @classmethod
     def connect_as_dict(
-        cls: type[TModel],  # pyright: ignore
-        task: clearml.Task,
+        cls: type[Self],  # pyright: ignore
+        task: Task,
         path: str | Path,
         alias: str | None = None,
-    ) -> TModel:
+    ) -> Self:
         """
         Connects configuration from a file as a dictionary to a ClearML task and creates an instance of the class.
 
@@ -89,3 +135,9 @@ class ClearMLConfigMixin:
 
         model = cls.from_dict(state_dict=config)
         return model
+
+
+class KostylBaseModel(BaseModelWithClearmlSyncing):
+    """A Pydantic model class with basic configuration loading functionality."""
+
+    pass
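
The mixin-based design is gone: loading and ClearML syncing now live in a small class hierarchy (BaseModelWithConfigLoading → BaseModelWithClearmlSyncing → KostylBaseModel). A minimal sketch of how a config class can use it; the class and field names below are hypothetical:

from kostyl.ml.configs.base_model import KostylBaseModel


class ExperimentConfig(KostylBaseModel):  # hypothetical user config
    seed: int = 42
    run_name: str = "baseline"


# from_dict/from_file are inherited from BaseModelWithConfigLoading ...
cfg = ExperimentConfig.from_dict({"seed": 7, "run_name": "ablation"})

# ... and connect_as_file/connect_as_dict from BaseModelWithClearmlSyncing.
# "config.yaml" is a hypothetical path readable by kostyl's load_config.
cfg_from_file = ExperimentConfig.from_file("config.yaml")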
{kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/configs/hyperparams.py
@@ -2,9 +2,10 @@ from pydantic import BaseModel
 from pydantic import Field
 from pydantic import model_validator
 
-from kostyl.ml_core.clearml.config_mixin import ClearMLConfigMixin
 from kostyl.utils.logging import setup_logger
 
+from .base_model import KostylBaseModel
+
 
 logger = setup_logger(fmt="only_message")
 
@@ -74,7 +75,7 @@ class WeightDecay(BaseModel):
         return self
 
 
-class HyperparamsConfig(BaseModel, ClearMLConfigMixin):
+class HyperparamsConfig(KostylBaseModel):
     """Model training hyperparameters configuration."""
 
     grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
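
Since HyperparamsConfig now inherits KostylBaseModel instead of mixing in ClearMLConfigMixin, the loading API comes with the base class; a one-line sketch with a hypothetical file path:

from kostyl.ml.configs import HyperparamsConfig

hparams = HyperparamsConfig.from_file("hyperparams.yaml")  # hypothetical path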
{kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/configs/training_settings.py
@@ -3,7 +3,6 @@ from typing import Literal
 from pydantic import BaseModel
 from pydantic import Field
 
-from kostyl.ml_core.clearml.config_mixin import ClearMLConfigMixin
 from kostyl.utils.logging import setup_logger
 
 from .base_model import KostylBaseModel
@@ -103,13 +102,3 @@ class TrainingSettings(KostylBaseModel):
     early_stopping: EarlyStoppingConfig | None = None
     checkpoint: CheckpointConfig
     data: DataConfig
-
-
-class ClearMLTrainingSettings(
-    TrainingSettings,
-    ClearMLConfigMixin,
-):
-    """Training parameters configuration with ClearML features support (config syncing, model identifiers tracking and etc)."""
-
-    model_id: str
-    tokenizer_id: str
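
ClearMLTrainingSettings is removed without a direct replacement in 0.1.12. Since TrainingSettings already extends KostylBaseModel (and with it the ClearML syncing methods), a project that still needs the extra identifier fields could subclass it; this migration sketch is an assumption, not something the package ships:

from kostyl.ml.configs import TrainingSettings


class MyTrainingSettings(TrainingSettings):
    """Hypothetical stand-in for the removed ClearMLTrainingSettings."""

    model_id: str
    tokenizer_id: str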
{kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/callbacks/checkpoint.py
@@ -3,8 +3,8 @@ from shutil import rmtree
 
 from lightning.pytorch.callbacks import ModelCheckpoint
 
-from kostyl.ml_core.configs import CheckpointConfig
-from kostyl.ml_core.dist_utils import is_main_process
+from kostyl.ml.configs import CheckpointConfig
+from kostyl.ml.dist_utils import is_main_process
 from kostyl.utils import setup_logger
 
 
{kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/callbacks/early_stopping.py
@@ -1,6 +1,6 @@
 from lightning.pytorch.callbacks import EarlyStopping
 
-from kostyl.ml_core.configs import EarlyStoppingConfig
+from kostyl.ml.configs import EarlyStoppingConfig
 
 
 def setup_early_stopping_callback(
{kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/callbacks/registry_uploading.py
@@ -7,9 +7,9 @@ from lightning import Trainer
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.callbacks import ModelCheckpoint
 
-from kostyl.ml_core.clearml.logging_utils import find_version_in_tags
-from kostyl.ml_core.clearml.logging_utils import increment_version
-from kostyl.ml_core.lightning import KostylLightningModule
+from kostyl.ml.clearml.logging_utils import find_version_in_tags
+from kostyl.ml.clearml.logging_utils import increment_version
+from kostyl.ml.lightning import KostylLightningModule
 from kostyl.utils.logging import setup_logger
 
 
{kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/extenstions/custom_module.py
@@ -15,9 +15,8 @@ from torchmetrics import MetricCollection
 from transformers import PretrainedConfig
 from transformers import PreTrainedModel
 
-from kostyl.ml_core.configs import HyperparamsConfig
-from kostyl.ml_core.metrics_formatting import apply_suffix
-from kostyl.ml_core.schedulers.base import BaseScheduler
+from kostyl.ml.metrics_formatting import apply_suffix
+from kostyl.ml.schedulers.base import BaseScheduler
 from kostyl.utils import setup_logger
 
 
@@ -28,7 +27,6 @@ class KostylLightningModule(L.LightningModule):
     """Custom PyTorch Lightning Module with logging, checkpointing, and distributed training utilities."""
 
     model: PreTrainedModel | nn.Module | None
-    hyperparams: HyperparamsConfig
 
     def get_process_group(self) -> ProcessGroup | None:
         """
@@ -70,6 +68,11 @@ class KostylLightningModule(L.LightningModule):
             return model.config  # type: ignore
         return None
 
+    @property
+    def grad_clip_val(self) -> float | None:
+        """Returns the gradient clipping value from hyperparameters if set."""
+        raise NotImplementedError
+
     @override
     def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
         model = self.get_model()
@@ -93,20 +96,16 @@ class KostylLightningModule(L.LightningModule):
     def on_before_optimizer_step(self, optimizer) -> None:
         if self.model is None:
             raise ValueError("Model must be configured before optimizer step.")
-        if not hasattr(self, "hyperparams"):
-            logger.warning_once("cannot clip gradients, hyperparams attr missing")
-            return
-        if self.hyperparams.grad_clip_val is None:
+
+        grad_clip_val = self.grad_clip_val
+        if grad_clip_val is None:
             return
 
         if not isinstance(self.trainer.strategy, FSDPStrategy):
-            norm = torch.nn.utils.clip_grad_norm_(
-                self.parameters(), self.hyperparams.grad_clip_val
-            )
+            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), grad_clip_val)
         else:
             module: FSDP = self.trainer.strategy.model  # type: ignore
-            norm = module.clip_grad_norm_(self.hyperparams.grad_clip_val)
-
+            norm = module.clip_grad_norm_(grad_clip_val)
         self.log(
             "grad_norm",
             norm,
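
Gradient clipping no longer reads a `hyperparams` attribute; subclasses now implement the `grad_clip_val` property, which `on_before_optimizer_step` consults. A minimal sketch of a conforming subclass (the __init__ signature is an assumption):

from kostyl.ml.configs import HyperparamsConfig
from kostyl.ml.lightning import KostylLightningModule


class MyModule(KostylLightningModule):  # hypothetical subclass
    def __init__(self, hyperparams: HyperparamsConfig) -> None:
        super().__init__()
        self.hyperparams = hyperparams

    @property
    def grad_clip_val(self) -> float | None:
        # Consulted by on_before_optimizer_step(); returning None disables clipping.
        return self.hyperparams.grad_clip_val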
{kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/extenstions/pretrained_model.py
@@ -1,4 +1,5 @@
 from pathlib import Path
+from typing import Any
 from typing import cast
 
 import torch
@@ -28,6 +29,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
         config_key: str = "config",
         weights_prefix: str = "model.",
         should_log_incompatible_keys: bool = True,
+        **kwargs: Any,
     ) -> TModelInstance:
         """
         Load a model from a Lightning checkpoint file.
@@ -48,6 +50,7 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
             weights_prefix (str, optional): Prefix to strip from state dict keys. Defaults to "model.".
                 If not empty and doesn't end with ".", a "." is appended.
             should_log_incompatible_keys (bool, optional): Whether to log incompatible keys. Defaults to True.
+            **kwargs: Additional keyword arguments to pass to the model loading method.
 
         Returns:
             TModelInstance: The loaded model instance.
@@ -75,10 +78,17 @@ class LightningCheckpointLoaderMixin(PreTrainedModel):
         )
 
         config_cls = cast(PretrainedConfig, type(cls.config_class))
-        config = config_cls.from_dict(checkpoint_dict[config_key])
+        config_dict = checkpoint_dict[config_key]
+        config_dict.update(kwargs)
+        config = config_cls.from_dict(config_dict)
+
+        kwargs_for_model = {}
+        for key in kwargs:
+            if not hasattr(config, key):
+                kwargs_for_model[key] = kwargs[key]
 
         with torch.device("meta"):
-            model = cls(config)
+            model = cls(config, **kwargs_for_model)
 
         if "peft_config" in checkpoint_dict:
             if PeftConfig is None:
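
The new kwargs handling first merges everything into the checkpoint's config dict, then forwards only the keys that did not become config attributes to the model constructor. An illustrative call; the class, checkpoint path, and keyword names are hypothetical:

model = MyModel.from_lighting_checkpoint(
    "model.ckpt",                # hypothetical checkpoint path
    hidden_dropout_prob=0.2,     # becomes a config attribute -> overrides the config
    custom_head=my_head_module,  # absent from the config -> passed to MyModel(config, custom_head=...)
)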
kostyl_toolkit-0.1.12/kostyl/ml/lightning/loggers/__init__.py (new file)
@@ -0,0 +1,4 @@
+from .tb_logger import setup_tb_logger
+
+
+__all__ = ["setup_tb_logger"]
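
The new __init__.py re-exports the TensorBoard helper, shortening the import path; its call signature is not shown in this diff, so only the import is illustrated:

from kostyl.ml.lightning.loggers import setup_tb_logger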
{kostyl_toolkit-0.1.10/kostyl/ml_core → kostyl_toolkit-0.1.12/kostyl/ml}/lightning/loggers/tb_logger.py
@@ -3,7 +3,7 @@ from shutil import rmtree
 
 from lightning.pytorch.loggers import TensorBoardLogger
 
-from kostyl.ml_core.dist_utils import is_main_process
+from kostyl.ml.dist_utils import is_main_process
 from kostyl.utils.logging import setup_logger
 
 
{kostyl_toolkit-0.1.10 → kostyl_toolkit-0.1.12}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "kostyl-toolkit"
-version = "0.1.10"
+version = "0.1.12"
 description = "Kickass Orchestration System for Training, Yielding & Logging "
 readme = "README.md"
 requires-python = ">=3.12"
kostyl_toolkit-0.1.10/kostyl/ml_core/configs/base_model.py (deleted)
@@ -1,60 +0,0 @@
-from pathlib import Path
-from typing import TypeVar
-
-from pydantic import BaseModel as PydanticBaseModel
-
-from kostyl.utils.fs import load_config
-
-
-TConfig = TypeVar("TConfig", bound=PydanticBaseModel)
-
-
-class ConfigLoadingMixin:
-    """Pydantic mixin class providing basic configuration loading functionality."""
-
-    @classmethod
-    def from_file(
-        cls: type[TConfig],  # pyright: ignore
-        path: str | Path,
-    ) -> TConfig:
-        """
-        Create an instance of the class from a configuration file.
-
-        Args:
-            cls_: The class type to instantiate.
-            path (str | Path): Path to the configuration file.
-
-        Returns:
-            An instance of the class created from the configuration file.
-
-        """
-        config = load_config(path)
-        instance = cls.model_validate(config)
-        return instance
-
-    @classmethod
-    def from_dict(
-        cls: type[TConfig],  # pyright: ignore
-        state_dict: dict,
-    ) -> TConfig:
-        """
-        Creates an instance from a dictionary.
-
-        Args:
-            cls_: The class type to instantiate.
-            state_dict (dict): A dictionary representing the state of the
-                class that must be validated and used for initialization.
-
-        Returns:
-            An initialized instance of the class based on the
-            provided state dictionary.
-
-        """
-        instance = cls.model_validate(state_dict)
-        return instance
-
-
-class KostylBaseModel(PydanticBaseModel, ConfigLoadingMixin):
-    """A Pydantic model class with basic configuration loading functionality."""
-
-    pass