kostyl-toolkit 0.1.41__py3-none-any.whl → 0.1.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -38,18 +38,18 @@ class FSDP1StrategyConfig(BaseModel):
38
38
  """Fully Sharded Data Parallel (FSDP) strategy configuration."""
39
39
 
40
40
  type: Literal["fsdp1"]
41
- param_dtype: DTYPE | None
42
- reduce_dtype: DTYPE | None
43
- buffer_dtype: DTYPE | None
41
+ param_dtype: DTYPE | None = None
42
+ reduce_dtype: DTYPE | None = None
43
+ buffer_dtype: DTYPE | None = None
44
44
 
45
45
 
46
46
  class FSDP2StrategyConfig(BaseModel):
47
47
  """Fully Sharded Data Parallel (FSDP) strategy configuration."""
48
48
 
49
49
  type: Literal["fsdp2"]
50
- param_dtype: DTYPE | None
51
- reduce_dtype: DTYPE | None
52
- buffer_dtype: DTYPE | None
50
+ param_dtype: DTYPE | None = None
51
+ reduce_dtype: DTYPE | None = None
52
+ buffer_dtype: DTYPE | None = None
53
53
 
54
54
 
55
55
  class DDPStrategyConfig(BaseModel):
@@ -59,12 +59,20 @@ class DDPStrategyConfig(BaseModel):
59
59
  find_unused_parameters: bool = False
60
60
 
61
61
 
62
+ SUPPORTED_STRATEGIES = (
63
+ FSDP1StrategyConfig
64
+ | FSDP2StrategyConfig
65
+ | SingleDeviceStrategyConfig
66
+ | DDPStrategyConfig
67
+ )
68
+
69
+
62
70
  class LightningTrainerParameters(BaseModel):
63
71
  """Lightning Trainer parameters configuration."""
64
72
 
65
73
  accelerator: str
66
74
  max_epochs: int
67
- strategy: FSDP1StrategyConfig | SingleDeviceStrategyConfig | DDPStrategyConfig
75
+ strategy: SUPPORTED_STRATEGIES
68
76
  val_check_interval: int | float
69
77
  devices: list[int] | int
70
78
  precision: PRECISION
@@ -93,10 +93,10 @@ class BatchCollatorWithKeyAlignment:
93
93
  if new_key is None:
94
94
  continue
95
95
  value = item[k]
96
- if self.max_length is not None and new_key in (
96
+ if self.max_length is not None and new_key in {
97
97
  "input_ids",
98
98
  "attention_mask",
99
- ):
99
+ }:
100
100
  value = self._truncate_data(new_key, value)
101
101
  new_item[new_key] = value
102
102
  aligned_batch.append(new_item)
@@ -1,8 +1,11 @@
1
- from typing import Any
1
+ from typing import TypedDict
2
+ from typing import Unpack
2
3
 
3
4
  from torch.optim import Optimizer
5
+ from torch.optim.optimizer import ParamsT
4
6
 
5
7
  from kostyl.ml.configs import OPTIMIZER_CONFIG
8
+ from kostyl.ml.configs import SCHEDULER
6
9
  from kostyl.ml.configs import AdamConfig
7
10
  from kostyl.ml.configs import AdamWithPrecisionConfig
8
11
  from kostyl.ml.configs import MuonConfig
@@ -18,6 +21,17 @@ from .schedulers import PlateauWithAnnealingScheduler
18
21
  logger = setup_logger(fmt="only_message")
19
22
 
20
23
 
24
+ class OVERRIDABLE_CONFIG_KWARGS(TypedDict, total=False): # noqa: D101, N801
25
+ scheduler_type: SCHEDULER | None
26
+
27
+ freeze_ratio: float | None
28
+ warmup_ratio: float | None
29
+ warmup_value: float | None
30
+ base_value: float | None
31
+ final_value: float | None
32
+ plateau_ratio: float | None
33
+
34
+
21
35
  def create_scheduler(
22
36
  config: ScheduledParamConfig,
23
37
  param_group_field: str,
@@ -27,6 +41,7 @@ def create_scheduler(
27
41
  skip_if_zero: bool = False,
28
42
  apply_if_field: str | None = None,
29
43
  ignore_if_field: str | None = None,
44
+ **kwargs: Unpack[OVERRIDABLE_CONFIG_KWARGS],
30
45
  ) -> LinearScheduler | CosineScheduler | PlateauWithAnnealingScheduler:
31
46
  """
32
47
  Converts a ScheduledParamConfig to a scheduler instance.
@@ -41,37 +56,48 @@ def create_scheduler(
41
56
  Default is False.
42
57
  apply_if_field: Require this key to be present in a param group before updating.
43
58
  ignore_if_field: Skip groups that declare this key in their dictionaries.
59
+ **kwargs: Optional overrides for scheduler configuration parameters (e.g., base_value,
60
+ final_value, warmup_ratio, etc.). These overrides take precedence over the values
61
+ provided in the `config` object.
44
62
 
45
63
  Returns:
46
64
  A scheduler instance based on the configuration.
47
65
 
48
66
  """
49
- if config.scheduler_type is None:
67
+ scheduler_type = kwargs.get("scheduler_type", config.scheduler_type)
68
+ base_value = kwargs.get("base_value", config.base_value)
69
+ final_value = kwargs.get("final_value", config.final_value)
70
+ warmup_ratio = kwargs.get("warmup_ratio", config.warmup_ratio)
71
+ warmup_value = kwargs.get("warmup_value", config.warmup_value)
72
+ freeze_ratio = kwargs.get("freeze_ratio", config.freeze_ratio)
73
+ plateau_ratio = kwargs.get("plateau_ratio", config.plateau_ratio)
74
+
75
+ if scheduler_type is None:
50
76
  raise ValueError("scheduler_type must be specified in the config.")
51
77
 
52
- if "plateau" in config.scheduler_type:
53
- scheduler_type = "plateau"
78
+ if "plateau" in scheduler_type:
79
+ lookup_scheduler_type = "plateau"
54
80
  else:
55
- scheduler_type = config.scheduler_type
56
- scheduler_cls = SCHEDULER_MAPPING[scheduler_type] # type: ignore
81
+ lookup_scheduler_type = scheduler_type
82
+ scheduler_cls = SCHEDULER_MAPPING[lookup_scheduler_type] # type: ignore
57
83
 
58
84
  if issubclass(scheduler_cls, PlateauWithAnnealingScheduler):
59
- if "cosine" in config.scheduler_type:
85
+ if "cosine" in scheduler_type:
60
86
  annealing_type = "cosine"
61
- elif "linear" in config.scheduler_type:
87
+ elif "linear" in scheduler_type:
62
88
  annealing_type = "linear"
63
89
  else:
64
- raise ValueError(f"Unknown annealing_type: {config.scheduler_type}")
90
+ raise ValueError(f"Unknown annealing_type: {scheduler_type}")
65
91
  scheduler = scheduler_cls(
66
92
  optimizer=optim,
67
93
  param_group_field=param_group_field,
68
94
  num_iters=num_iters,
69
- plateau_value=config.base_value,
70
- final_value=config.final_value, # type: ignore
71
- warmup_ratio=config.warmup_ratio,
72
- warmup_value=config.warmup_value,
73
- freeze_ratio=config.freeze_ratio,
74
- plateau_ratio=config.plateau_ratio, # type: ignore
95
+ plateau_value=base_value,
96
+ final_value=final_value, # type: ignore
97
+ warmup_ratio=warmup_ratio,
98
+ warmup_value=warmup_value,
99
+ freeze_ratio=freeze_ratio,
100
+ plateau_ratio=plateau_ratio, # type: ignore
75
101
  annealing_type=annealing_type,
76
102
  multiplier_field=multiplier_field,
77
103
  skip_if_zero=skip_if_zero,
@@ -83,8 +109,8 @@ def create_scheduler(
83
109
  optimizer=optim,
84
110
  param_group_field=param_group_field,
85
111
  num_iters=num_iters,
86
- initial_value=config.base_value,
87
- final_value=config.final_value, # type: ignore
112
+ initial_value=base_value,
113
+ final_value=final_value, # type: ignore
88
114
  multiplier_field=multiplier_field,
89
115
  skip_if_zero=skip_if_zero,
90
116
  apply_if_field=apply_if_field,
@@ -95,23 +121,23 @@ def create_scheduler(
95
121
  optimizer=optim,
96
122
  param_group_field=param_group_field,
97
123
  num_iters=num_iters,
98
- base_value=config.base_value,
99
- final_value=config.final_value, # type: ignore
100
- warmup_ratio=config.warmup_ratio,
101
- warmup_value=config.warmup_value,
102
- freeze_ratio=config.freeze_ratio,
124
+ base_value=base_value,
125
+ final_value=final_value, # type: ignore
126
+ warmup_ratio=warmup_ratio,
127
+ warmup_value=warmup_value,
128
+ freeze_ratio=freeze_ratio,
103
129
  multiplier_field=multiplier_field,
104
130
  skip_if_zero=skip_if_zero,
105
131
  apply_if_field=apply_if_field,
106
132
  ignore_if_field=ignore_if_field,
107
133
  )
108
134
  else:
109
- raise ValueError(f"Unsupported scheduler type: {config.scheduler_type}")
135
+ raise ValueError(f"Unsupported scheduler type: {scheduler_type}")
110
136
  return scheduler
111
137
 
112
138
 
113
139
  def create_optimizer( # noqa: C901
114
- parameters_groups: dict[str, Any],
140
+ parameters_groups: ParamsT,
115
141
  optimizer_config: OPTIMIZER_CONFIG,
116
142
  lr: float,
117
143
  weight_decay: float,
@@ -120,8 +146,7 @@ def create_optimizer( # noqa: C901
120
146
  Creates an optimizer based on the configuration.
121
147
 
122
148
  Args:
123
- parameters_groups: Dictionary containing model parameters
124
- (key "params" and per-group options, i.e. "lr", "weight_decay" and etc.).
149
+ parameters_groups: Parameter groups for the optimizer.
125
150
  optimizer_config: Configuration for the optimizer.
126
151
  lr: Learning rate.
127
152
  weight_decay: Weight decay.
@@ -28,8 +28,8 @@ def create_params_groups(
28
28
  Defaults to None, which uses an empty set.
29
29
  no_decay_keywords (set[str] | None, optional): A set of string keywords. If a parameter's
30
30
  name contains any of these keywords, its weight decay is set to 0.0.
31
- If additional keywords are provided, they will be added to the default set.
32
- Defaults to None, which uses a standard set of exclusion keywords:
31
+ If keywords are provided, they will be added to the default set, otherwise the default set is used.
32
+ Default set of keywords:
33
33
  {"norm", "bias", "embedding", "tokenizer", "ln", "scale"}.
34
34
 
35
35
  Returns:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: kostyl-toolkit
3
- Version: 0.1.41
3
+ Version: 0.1.43
4
4
  Summary: Kickass Orchestration System for Training, Yielding & Logging
5
5
  Requires-Dist: case-converter>=1.2.0
6
6
  Requires-Dist: loguru>=0.7.3
@@ -4,8 +4,8 @@ kostyl/ml/base_uploader.py,sha256=KxHuohCcNK18kTVFBBqDu_IOQefluhSXOzwC56O66wc,48
4
4
  kostyl/ml/configs/__init__.py,sha256=ytr2HFtsqMIs1pqQ-Ma0jv8c6Ni1-UxPnDTLu16cL68,1189
5
5
  kostyl/ml/configs/hyperparams.py,sha256=poOsR6xbE6jHO02GzxXfbz-6PA0HqULgcTgiuT_4oOM,4401
6
6
  kostyl/ml/configs/mixins.py,sha256=xHHAoRoPbzP9ECFP9duzg6SzegHcoLI8Pr9NrLoWNHs,1411
7
- kostyl/ml/configs/training_settings.py,sha256=wT9CHuLaKrLwonsc87Ee421EyFis_c9fqOgn9bSClm8,2747
8
- kostyl/ml/data_collator.py,sha256=kxiaMDKwSKXGBtrF8yXxHcypf7t_6syU-NwO1LcX50k,4062
7
+ kostyl/ml/configs/training_settings.py,sha256=U3K45eLVrHTOfZ7p9TqTKjGWd4sdjKmJOe3a6U9kUWw,2877
8
+ kostyl/ml/data_collator.py,sha256=NHfV5HcebHzg6iSfXj_cSUcBx9l8qP1HTbIPvqCDlTs,4062
9
9
  kostyl/ml/dist_utils.py,sha256=UFNMLEHc0A5F6KvTRG8GQPpRDwG4m5dvM__UvXNc2aQ,4526
10
10
  kostyl/ml/integrations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  kostyl/ml/integrations/clearml/__init__.py,sha256=cIJMmiTgRrBNlCwejs6x4DrJydv7zYRPqUCWpVgQJYo,970
@@ -25,18 +25,18 @@ kostyl/ml/integrations/lightning/mixins.py,sha256=hVIsIUu6Iryrz6S7GQTqog9vNq8LQy
25
25
  kostyl/ml/integrations/lightning/module.py,sha256=APodFpspSoBYmrtTnCpUFq3WFEvpmHp2hlWUxwwAChM,5219
26
26
  kostyl/ml/integrations/lightning/utils.py,sha256=QhbK5iTv07xkRKuyXiK05aoY_ObSGorzGu1WcFvvFtI,1712
27
27
  kostyl/ml/optim/__init__.py,sha256=GgCr94WbiM-7MVDf6fPlbeUSxS33ufn_UAmqfKerElc,140
28
- kostyl/ml/optim/factory.py,sha256=1WYbAOOGepVL55gSVCJs5WBsiiE4k4ayrS5svFh9kbE,9760
28
+ kostyl/ml/optim/factory.py,sha256=sPFIQkzIA4DKpny1hrcxHTKjEOXcbwLVy2CI9ZGZ938,10713
29
29
  kostyl/ml/optim/schedulers/__init__.py,sha256=lVwWj_mtUDo5OLtr8NNljLQgWGJUKPJWV7bm6-997ag,1486
30
30
  kostyl/ml/optim/schedulers/base.py,sha256=bjmwgdZpnSqpCnHPnKC6MEiRO79cwxMJpZq-eQVNs2M,1353
31
31
  kostyl/ml/optim/schedulers/composite.py,sha256=ee4xlMDMMtjKPkbTF2ue9GTr9DuGCGjZWf11mHbi6aE,2387
32
32
  kostyl/ml/optim/schedulers/cosine.py,sha256=LnFujAKaM8v5soPwiWG899nUMeaeAe8oTiaJqvoQxg8,8705
33
33
  kostyl/ml/optim/schedulers/linear.py,sha256=hlHfU88WLwfke8XGmRUzRKmEpqyyaLb4RRqq6EIzIxU,5881
34
34
  kostyl/ml/optim/schedulers/plateau.py,sha256=kjHivSp2FIMXExgS0q_Wbxk9pEHnuzhrIu4TQ59kYNw,11255
35
- kostyl/ml/params_groups.py,sha256=nUyw5d06Pvy9QPiYtZzLYR87xwXqJLxbHthgQH8oSCM,3583
35
+ kostyl/ml/params_groups.py,sha256=4jMoNKvWTfqyQkfFCvU75VRk_dJaVTgShqbuTOnhOYU,3565
36
36
  kostyl/utils/__init__.py,sha256=hkpmB6c5pr4Ti5BshOROebb7cvjDZfNCw83qZ_FFKMM,240
37
37
  kostyl/utils/dict_manipulations.py,sha256=e3vBicID74nYP8lHkVTQc4-IQwoJimrbFELy5uSF6Gk,1073
38
38
  kostyl/utils/fs.py,sha256=gAQNIU4R_2DhwjgzOS8BOMe0gZymtY1eZwmdgOdDgqo,510
39
39
  kostyl/utils/logging.py,sha256=CgNFNogcK0hoZmygvBWlTcq5A3m2Pfv9eOAP_gwx0pM,6633
40
- kostyl_toolkit-0.1.41.dist-info/WHEEL,sha256=e_m4S054HL0hyR3CpOk-b7Q7fDX6BuFkgL5OjAExXas,80
41
- kostyl_toolkit-0.1.41.dist-info/METADATA,sha256=RoZ8dXiU-CwUldPkhjY2achhvcQ2EE4RzhsJKTXe-AM,4121
42
- kostyl_toolkit-0.1.41.dist-info/RECORD,,
40
+ kostyl_toolkit-0.1.43.dist-info/WHEEL,sha256=fAguSjoiATBe7TNBkJwOjyL1Tt4wwiaQGtNtjRPNMQA,80
41
+ kostyl_toolkit-0.1.43.dist-info/METADATA,sha256=0Qt13h-ngr5NzIJd9m0NfELBQug8M5-YzzW0YBqjYXs,4121
42
+ kostyl_toolkit-0.1.43.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: uv 0.9.27
2
+ Generator: uv 0.9.28
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any