kostyl-toolkit 0.1.38__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
+ from .hyperparams import OPTIMIZER
1
2
  from .hyperparams import HyperparamsConfig
2
3
  from .hyperparams import Lr
3
- from .hyperparams import Optimizer
4
4
  from .hyperparams import WeightDecay
5
5
  from .mixins import ConfigLoadingMixin
6
6
  from .training_settings import CheckpointConfig
@@ -14,6 +14,7 @@ from .training_settings import TrainingSettings
14
14
 
15
15
 
16
16
  __all__ = [
17
+ "OPTIMIZER",
17
18
  "CheckpointConfig",
18
19
  "ConfigLoadingMixin",
19
20
  "DDPStrategyConfig",
@@ -23,7 +24,6 @@ __all__ = [
23
24
  "HyperparamsConfig",
24
25
  "LightningTrainerParameters",
25
26
  "Lr",
26
- "Optimizer",
27
27
  "SingleDeviceStrategyConfig",
28
28
  "TrainingSettings",
29
29
  "WeightDecay",
@@ -11,11 +11,19 @@ logger = setup_logger(fmt="only_message")
11
11
 
12
12
 
13
13
  class AdamConfig(BaseModel):
14
- """AdamW optimizer hyperparameters configuration."""
14
+ """Adam optimizer hyperparameters configuration."""
15
15
 
16
- type: Literal["AdamW"] = "AdamW"
16
+ type: Literal["AdamW", "Adam"] = "AdamW"
17
17
  betas: tuple[float, float] = (0.9, 0.999)
18
- is_adamw: bool = True
18
+
19
+
20
+ class MuonConfig(BaseModel):
21
+ """Muon optimizer hyperparameters configuration."""
22
+
23
+ type: Literal["Muon"]
24
+ nesterov: bool = True
25
+ ns_coefficients: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
26
+ ns_steps: int = 5
19
27
 
20
28
 
21
29
  class AdamWithPrecisionConfig(BaseModel):
@@ -28,67 +36,85 @@ class AdamWithPrecisionConfig(BaseModel):
28
36
  is_adamw: bool = True
29
37
 
30
38
 
31
- Optimizer = AdamConfig | AdamWithPrecisionConfig
39
+ OPTIMIZER = AdamConfig | AdamWithPrecisionConfig | MuonConfig
40
+ SCHEDULER = Literal[
41
+ "linear",
42
+ "cosine",
43
+ "plateau-with-cosine-annealing",
44
+ "plateau-with-linear-annealing",
45
+ ]
32
46
 
33
47
 
34
48
  class Lr(BaseModel):
35
49
  """Learning rate hyperparameters configuration."""
36
50
 
37
- use_scheduler: bool = False
38
- warmup_iters_ratio: float | None = Field(
39
- default=None, gt=0, lt=1, validate_default=False
40
- )
51
+ scheduler_type: SCHEDULER | None = None
52
+
53
+ freeze_ratio: float | None = Field(default=None, ge=0, le=1)
54
+ warmup_ratio: float | None = Field(default=None, gt=0, lt=1, validate_default=False)
41
55
  warmup_value: float | None = Field(default=None, gt=0, validate_default=False)
42
56
  base_value: float
43
57
  final_value: float | None = Field(default=None, gt=0, validate_default=False)
58
+ plateau_ratio: float | None = Field(
59
+ default=None, gt=0, lt=1, validate_default=False
60
+ )
44
61
 
45
62
  @model_validator(mode="after")
46
- def validate_warmup(self) -> "Lr":
47
- """Validates the warmup parameters based on use_scheduler."""
48
- if (self.warmup_value is None) != (self.warmup_iters_ratio is None): # fmt: skip
49
- raise ValueError(
50
- "Both warmup_value and warmup_iters_ratio must be provided or neither"
51
- )
52
- if ((self.warmup_value is not None) or (self.warmup_iters_ratio is not None)) and not self.use_scheduler: # fmt: skip
63
+ def _validate_freeze_ratio(self) -> "Lr":
64
+ if self.scheduler_type is None and self.freeze_ratio is not None:
65
+ logger.warning("use_scheduler is False, freeze_ratio will be ignored.")
66
+ self.freeze_ratio = None
67
+ return self
68
+
69
+ @model_validator(mode="after")
70
+ def _validate_warmup(self) -> "Lr":
71
+ if ((self.warmup_value is not None) or (self.warmup_ratio is not None)) and self.scheduler_type is None: # fmt: skip
53
72
  logger.warning(
54
- "use_scheduler is False, warmup_value and warmup_iters_ratio will be ignored."
73
+ "scheduler_type is None, warmup_value and warmup_ratio will be ignored."
55
74
  )
56
75
  self.warmup_value = None
57
- self.warmup_iters_ratio = None
76
+ self.warmup_ratio = None
77
+ if (self.warmup_value is None) != (self.warmup_ratio is None): # fmt: skip
78
+ raise ValueError(
79
+ "Both warmup_value and warmup_ratio must be provided or neither"
80
+ )
58
81
  return self
59
82
 
60
83
  @model_validator(mode="after")
61
- def validate_final_value(self) -> "Lr":
62
- """Validates the final_value based on use_scheduler."""
63
- if self.use_scheduler and (self.final_value is None):
64
- raise ValueError("If use_scheduler is True, final_value must be provided.")
65
- if (not self.use_scheduler) and (self.final_value is not None):
84
+ def _validate_final_value(self) -> "Lr":
85
+ if (self.scheduler_type in {"linear"}) and (self.final_value is not None):
86
+ raise ValueError("If scheduler_type is 'linear', final_value must be None.")
87
+ if (self.scheduler_type is None) and (self.final_value is not None):
66
88
  logger.warning("use_scheduler is False, final_value will be ignored.")
67
89
  self.final_value = None
68
90
  return self
69
91
 
70
-
71
- class WeightDecay(BaseModel):
72
- """Weight decay hyperparameters configuration."""
73
-
74
- use_scheduler: bool = False
75
- base_value: float
76
- final_value: float | None = None
77
-
78
92
  @model_validator(mode="after")
79
- def validate_final_value(self) -> "WeightDecay":
80
- """Validates the final_value based on use_scheduler."""
81
- if self.use_scheduler and self.final_value is None:
82
- raise ValueError("If use_scheduler is True, final_value must be provided.")
83
- if not self.use_scheduler and self.final_value is not None:
84
- logger.warning("use_scheduler is False, final_value will be ignored.")
93
+ def _validate_plateau_ratio(self) -> "Lr":
94
+ if self.scheduler_type is not None:
95
+ if self.scheduler_type.startswith("plateau") and self.plateau_ratio is None:
96
+ raise ValueError(
97
+ "If scheduler_type is 'plateau-with-*', plateau_ratio must be provided."
98
+ )
99
+ if (
100
+ not self.scheduler_type.startswith("plateau")
101
+ and self.plateau_ratio is not None
102
+ ):
103
+ logger.warning(
104
+ "scheduler_type is not 'plateau-with-*', plateau_ratio will be ignored."
105
+ )
106
+ self.plateau_ratio = None
85
107
  return self
86
108
 
87
109
 
110
+ class WeightDecay(Lr):
111
+ """Weight decay hyperparameters configuration."""
112
+
113
+
88
114
  class HyperparamsConfig(BaseModel):
89
115
  """Model training hyperparameters configuration."""
90
116
 
91
117
  grad_clip_val: float | None = Field(default=None, gt=0, validate_default=False)
92
- optimizer: Optimizer
118
+ optimizer: OPTIMIZER
93
119
  lr: Lr
94
120
  weight_decay: WeightDecay
@@ -286,7 +286,7 @@ def setup_checkpoint_callback(
286
286
  ckpt_cfg: CheckpointConfig,
287
287
  checkpoint_uploader: ModelCheckpointUploader | None = None,
288
288
  upload_strategy: Literal["only-best", "every-checkpoint"] | None = None,
289
- remove_folder_if_exists: bool = True,
289
+ remove_folder_if_exists: bool = False,
290
290
  ) -> ModelCheckpointWithCheckpointUploader | ModelCheckpoint:
291
291
  """
292
292
  Create and configure a checkpoint callback for model saving.
@@ -11,14 +11,17 @@ logger = setup_logger()
11
11
 
12
12
 
13
13
  def setup_tb_logger(
14
- runs_dir: Path,
14
+ runs_dir: Path, remove_folder_if_exists: bool = False
15
15
  ) -> TensorBoardLogger:
16
16
  """Sets up a TensorBoardLogger for PyTorch Lightning."""
17
17
  if runs_dir.exists():
18
18
  if is_local_zero_rank():
19
19
  logger.warning(f"TensorBoard log directory {runs_dir} already exists.")
20
- rmtree(runs_dir)
21
- logger.warning(f"Removed existing TensorBoard log directory {runs_dir}.")
20
+ if remove_folder_if_exists:
21
+ rmtree(runs_dir)
22
+ logger.warning(
23
+ f"Removed existing TensorBoard log directory {runs_dir}."
24
+ )
22
25
  else:
23
26
  logger.info(f"Creating TensorBoard log directory {runs_dir}.")
24
27
  runs_dir.mkdir(parents=True, exist_ok=True)
@@ -4,16 +4,9 @@ import lightning as L
4
4
  import torch.distributed as dist
5
5
  from torch.distributed import ProcessGroup
6
6
 
7
- from kostyl.ml.configs import DDPStrategyConfig
8
- from kostyl.ml.configs import FSDP1StrategyConfig
9
- from kostyl.ml.configs import SingleDeviceStrategyConfig
10
7
  from kostyl.utils.logging import setup_logger
11
8
 
12
9
 
13
- TRAINING_STRATEGIES = (
14
- FSDP1StrategyConfig | DDPStrategyConfig | SingleDeviceStrategyConfig
15
- )
16
-
17
10
  logger = setup_logger()
18
11
 
19
12
 
@@ -13,21 +13,21 @@ class _LinearScheduleBase(BaseScheduler):
13
13
  self,
14
14
  param_name: str,
15
15
  num_iters: int,
16
- start_value: float,
16
+ base_value: float,
17
17
  final_value: float,
18
18
  ) -> None:
19
19
  self.param_name = param_name
20
20
  self.num_iters = num_iters
21
- self.start_value = start_value
21
+ self.base_value = base_value
22
22
  self.final_value = final_value
23
23
 
24
24
  self.scheduled_values: npt.NDArray[np.float64] = np.array([], dtype=np.float64)
25
- self.current_value_ = self.start_value
25
+ self.current_value_ = self.base_value
26
26
  return
27
27
 
28
28
  def _create_scheduler(self) -> None:
29
29
  self.scheduled_values = np.linspace(
30
- self.start_value, self.final_value, num=self.num_iters, dtype=np.float64
30
+ self.base_value, self.final_value, num=self.num_iters, dtype=np.float64
31
31
  )
32
32
  self._verify()
33
33
  return
@@ -68,7 +68,7 @@ class LinearScheduler(_LinearScheduleBase):
68
68
  optimizer: torch.optim.Optimizer,
69
69
  param_group_field: str,
70
70
  num_iters: int,
71
- start_value: float,
71
+ base_value: float,
72
72
  final_value: float,
73
73
  multiplier_field: str | None = None,
74
74
  skip_if_zero: bool = False,
@@ -82,7 +82,7 @@ class LinearScheduler(_LinearScheduleBase):
82
82
  optimizer: Optimizer whose param groups are updated in-place.
83
83
  param_group_field: Name of the field that receives the scheduled value.
84
84
  num_iters: Number of scheduler iterations before clamping at ``final_value``.
85
- start_value: Value used on the first iteration.
85
+ base_value: Value used on the first iteration.
86
86
  final_value: Value used once ``num_iters`` iterations are consumed.
87
87
  multiplier_field: Optional per-group multiplier applied to the scheduled value.
88
88
  skip_if_zero: Leave groups untouched when their target field equals zero.
@@ -98,7 +98,7 @@ class LinearScheduler(_LinearScheduleBase):
98
98
  super().__init__(
99
99
  param_name=param_group_field,
100
100
  num_iters=num_iters,
101
- start_value=start_value,
101
+ base_value=base_value,
102
102
  final_value=final_value,
103
103
  )
104
104
  self.param_group_field = param_group_field
@@ -1,17 +1,15 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: kostyl-toolkit
3
- Version: 0.1.38
3
+ Version: 0.1.39
4
4
  Summary: Kickass Orchestration System for Training, Yielding & Logging
5
5
  Requires-Dist: case-converter>=1.2.0
6
6
  Requires-Dist: loguru>=0.7.3
7
- Requires-Dist: case-converter>=1.2.0 ; extra == 'ml-core'
8
- Requires-Dist: clearml[s3]>=2.0.2 ; extra == 'ml-core'
9
- Requires-Dist: lightning>=2.5.6 ; extra == 'ml-core'
10
- Requires-Dist: pydantic>=2.12.4 ; extra == 'ml-core'
11
- Requires-Dist: torch>=2.9.1 ; extra == 'ml-core'
12
- Requires-Dist: transformers>=4.57.1 ; extra == 'ml-core'
7
+ Requires-Dist: case-converter>=1.2.0 ; extra == 'ml'
8
+ Requires-Dist: pydantic>=2.12.4 ; extra == 'ml'
9
+ Requires-Dist: torch>=2.9.1 ; extra == 'ml'
10
+ Requires-Dist: transformers>=4.57.1 ; extra == 'ml'
13
11
  Requires-Python: >=3.12
14
- Provides-Extra: ml-core
12
+ Provides-Extra: ml
15
13
  Description-Content-Type: text/markdown
16
14
 
17
15
  # Kostyl Toolkit
@@ -1,8 +1,8 @@
1
1
  kostyl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  kostyl/ml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  kostyl/ml/base_uploader.py,sha256=KxHuohCcNK18kTVFBBqDu_IOQefluhSXOzwC56O66wc,484
4
- kostyl/ml/configs/__init__.py,sha256=djYjLxA7riFcSibAKfWHns-BCESEPrqSz_ZY2rJO-cc,913
5
- kostyl/ml/configs/hyperparams.py,sha256=lvtbvOFEoTBAJug7FR35xMQdPLgDQjRoP2fyDP-jD7E,3305
4
+ kostyl/ml/configs/__init__.py,sha256=9AbM3ugsvc0bRKHA99DTRuUc2N7iSV9BH5FTDOp0cpw,913
5
+ kostyl/ml/configs/hyperparams.py,sha256=wU6NoZoTG9WlnU4z3g4ViUfjjp8fb8YWUS_2gAd5wsY,4131
6
6
  kostyl/ml/configs/mixins.py,sha256=xHHAoRoPbzP9ECFP9duzg6SzegHcoLI8Pr9NrLoWNHs,1411
7
7
  kostyl/ml/configs/training_settings.py,sha256=wT9CHuLaKrLwonsc87Ee421EyFis_c9fqOgn9bSClm8,2747
8
8
  kostyl/ml/data_collator.py,sha256=kxiaMDKwSKXGBtrF8yXxHcypf7t_6syU-NwO1LcX50k,4062
@@ -16,25 +16,25 @@ kostyl/ml/integrations/clearml/loading_utils.py,sha256=NAMmB9NTGCXCHh-bR_nrQZyqI
16
16
  kostyl/ml/integrations/clearml/version_utils.py,sha256=GBjIIZbH_itd5sj7XpvxjkyZwxxGOpEcQ3BiWaJTyq8,1210
17
17
  kostyl/ml/integrations/lightning/__init__.py,sha256=r96os8kTuKIAymx3k9Td1JBrO2PH7nQAWUC54NsY5yY,392
18
18
  kostyl/ml/integrations/lightning/callbacks/__init__.py,sha256=EnKkNwwNDZnEqKRlpY4FVrqP88ECPF6nlT2bSLUIKRk,194
19
- kostyl/ml/integrations/lightning/callbacks/checkpoint.py,sha256=SfcaQRkXviMUej0UgrfXcqMDlRKYaAN3rgYCMKI97Os,18433
19
+ kostyl/ml/integrations/lightning/callbacks/checkpoint.py,sha256=iJQKaQUzCMNza7SUKtCVDol_Cy3rZy9KfBKQ6kT6Swg,18434
20
20
  kostyl/ml/integrations/lightning/callbacks/early_stopping.py,sha256=D5nyjktCJ9XYAf28-kgXG8jORvXLl1N3nbDQnvValPM,615
21
21
  kostyl/ml/integrations/lightning/loggers/__init__.py,sha256=e51dszaoJbuzwBkbdugmuDsPldoSO4yaRgmZUg1Bdy0,71
22
- kostyl/ml/integrations/lightning/loggers/tb_logger.py,sha256=CpjlcEIT187cJXJgRYafqfzvcnwPgPaVZ0vLUflIr7k,899
22
+ kostyl/ml/integrations/lightning/loggers/tb_logger.py,sha256=jVzEArHr1bpO-HRdKjlLJ4BJnxomHPAuBKCdYQRHRn4,1023
23
23
  kostyl/ml/integrations/lightning/metrics_formatting.py,sha256=U6vdNENZLvp2dT1L3HqFKtXrHwGKoDXN93hvamPGHjM,1341
24
24
  kostyl/ml/integrations/lightning/mixins.py,sha256=hVIsIUu6Iryrz6S7GQTqog9vNq8LQyjJd2aoJ5Ws6KU,5253
25
25
  kostyl/ml/integrations/lightning/module.py,sha256=39hcVNZSGyj5tLpXyX8IoqMGWt5vf6-Bx5JnNJ2-Wag,5218
26
- kostyl/ml/integrations/lightning/utils.py,sha256=DhLy_3JA5VyMQkB1v6xxRxDNHfisjXFYVjuIKPpO81M,1967
26
+ kostyl/ml/integrations/lightning/utils.py,sha256=QhbK5iTv07xkRKuyXiK05aoY_ObSGorzGu1WcFvvFtI,1712
27
27
  kostyl/ml/params_groups.py,sha256=nUyw5d06Pvy9QPiYtZzLYR87xwXqJLxbHthgQH8oSCM,3583
28
28
  kostyl/ml/schedulers/__init__.py,sha256=VIo8MOP4w5Ll24XqFb3QGi2rKvys6c0dEFYPIdDoPlw,526
29
29
  kostyl/ml/schedulers/base.py,sha256=bjmwgdZpnSqpCnHPnKC6MEiRO79cwxMJpZq-eQVNs2M,1353
30
30
  kostyl/ml/schedulers/composite.py,sha256=ee4xlMDMMtjKPkbTF2ue9GTr9DuGCGjZWf11mHbi6aE,2387
31
31
  kostyl/ml/schedulers/cosine.py,sha256=y8ylrgVOkVcr2-ExoqqNW--tdDX88TBYPQCOppIf2_M,8685
32
- kostyl/ml/schedulers/linear.py,sha256=RnnnblRuRXP3LT03QVIHUaK2kNsiMP1AedrMoeyh3qk,5843
32
+ kostyl/ml/schedulers/linear.py,sha256=a-Dq5VYoM1Z7vBb4_2Np0MAOfRmO1QVZnOzEoY8nM6k,5834
33
33
  kostyl/ml/schedulers/plateau.py,sha256=N-hiostPtTR0W4xnEJYB_1dv0DRx39iufLkGUrSIoWE,11235
34
34
  kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
35
35
  kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
36
36
  kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
37
37
  kostyl/utils/logging.py,sha256=CgNFNogcK0hoZmygvBWlTcq5A3m2Pfv9eOAP_gwx0pM,6633
38
- kostyl_toolkit-0.1.38.dist-info/WHEEL,sha256=e_m4S054HL0hyR3CpOk-b7Q7fDX6BuFkgL5OjAExXas,80
39
- kostyl_toolkit-0.1.38.dist-info/METADATA,sha256=nz5AzlWjKBqh7OZCklk-efWZ1jVDihw3YrrpLyoII3k,4269
40
- kostyl_toolkit-0.1.38.dist-info/RECORD,,
38
+ kostyl_toolkit-0.1.39.dist-info/WHEEL,sha256=e_m4S054HL0hyR3CpOk-b7Q7fDX6BuFkgL5OjAExXas,80
39
+ kostyl_toolkit-0.1.39.dist-info/METADATA,sha256=9AA-uH9-jSciHx3xh4WGxHkzTyjEESi0AFRitTkIM5w,4136
40
+ kostyl_toolkit-0.1.39.dist-info/RECORD,,