nshtrainer 0.18.2__py3-none-any.whl → 0.19.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_hf_hub.py CHANGED
@@ -1,6 +1,7 @@
 import io
 import logging
 import os
+import re
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
 
@@ -150,7 +151,32 @@ def _repo_name(api: "HfApi", root_config: "BaseConfig"):
     elif (username := api.whoami().get("name", None)) is None:
         raise ValueError("Could not get username from Hugging Face Hub.")
 
-    return f"{username}/{root_config.project}-{root_config.run_name}-{root_config.id}"
+    # Sanitize the project (if it exists), run_name, and id
+    parts = []
+    if root_config.project:
+        parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.project))
+    parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.run_name))
+    parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.id))
+
+    # Combine parts and ensure it starts and ends with alphanumeric characters
+    repo_name = "-".join(parts)
+    repo_name = repo_name.strip("-")
+    repo_name = re.sub(
+        r"-+", "-", repo_name
+    )  # Replace multiple dashes with a single dash
+
+    # Ensure the name is not longer than 96 characters (excluding username)
+    if len(repo_name) > 96:
+        repo_name = repo_name[:96].rstrip("-")
+
+    # Ensure the repo name starts with an alphanumeric character
+    repo_name = re.sub(r"^[^a-zA-Z0-9]+", "", repo_name)
+
+    # If the repo_name is empty after all sanitization, use a default name
+    if not repo_name:
+        repo_name = "default-repo-name"
+
+    return f"{username}/{repo_name}"
 
 
 def _init(*, trainer: "Trainer", root_config: "BaseConfig"):
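For context, the new `_repo_name` body reduces to a pure string transform over the config fields. A minimal standalone sketch of that transform (the `sanitize_repo_name` helper and the sample values below are illustrative, not part of the package API):

import re

def sanitize_repo_name(project, run_name, id):
    # Illustrative helper, not nshtrainer's API: replace every character
    # outside [a-zA-Z0-9-] with a dash, per part, skipping empty parts.
    parts = [re.sub(r"[^a-zA-Z0-9-]", "-", p) for p in (project, run_name, id) if p]
    name = "-".join(parts).strip("-")
    name = re.sub(r"-+", "-", name)  # collapse runs of dashes
    if len(name) > 96:  # length budget, excluding the username
        name = name[:96].rstrip("-")
    name = re.sub(r"^[^a-zA-Z0-9]+", "", name)  # must start alphanumeric
    return name or "default-repo-name"

print(sanitize_repo_name("my proj", "run/alpha 1", "abc_123"))
# -> my-proj-run-alpha-1-abc-123

The point of the change: the old f-string could emit spaces, slashes, or a literal "None" for a missing project, none of which are valid in a Hugging Face repo name.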
nshtrainer/model/modules/callback.py CHANGED
@@ -1,7 +1,6 @@
 import logging
-from collections import abc
-from collections.abc import Callable, Iterable
-from typing import Any, TypeAlias, cast, final
+from collections.abc import Callable, Iterable, Sequence
+from typing import Any, TypeAlias, cast, final, overload
 
 from lightning.pytorch import Callback, LightningModule
 from lightning.pytorch.callbacks import LambdaCallback
@@ -19,11 +18,61 @@ class CallbackRegistrarModuleMixin:
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-        self._ll_callbacks: list[CallbackFn] = []
+        self._nshtrainer_callbacks: list[CallbackFn] = []
+
+    @overload
+    def register_callback(
+        self, callback: Callback | Iterable[Callback] | CallbackFn | None = None, /
+    ): ...
+
+    @overload
+    def register_callback(
+        self,
+        /,
+        *,
+        setup: Callable | None = None,
+        teardown: Callable | None = None,
+        on_fit_start: Callable | None = None,
+        on_fit_end: Callable | None = None,
+        on_sanity_check_start: Callable | None = None,
+        on_sanity_check_end: Callable | None = None,
+        on_train_batch_start: Callable | None = None,
+        on_train_batch_end: Callable | None = None,
+        on_train_epoch_start: Callable | None = None,
+        on_train_epoch_end: Callable | None = None,
+        on_validation_epoch_start: Callable | None = None,
+        on_validation_epoch_end: Callable | None = None,
+        on_test_epoch_start: Callable | None = None,
+        on_test_epoch_end: Callable | None = None,
+        on_validation_batch_start: Callable | None = None,
+        on_validation_batch_end: Callable | None = None,
+        on_test_batch_start: Callable | None = None,
+        on_test_batch_end: Callable | None = None,
+        on_train_start: Callable | None = None,
+        on_train_end: Callable | None = None,
+        on_validation_start: Callable | None = None,
+        on_validation_end: Callable | None = None,
+        on_test_start: Callable | None = None,
+        on_test_end: Callable | None = None,
+        on_exception: Callable | None = None,
+        on_save_checkpoint: Callable | None = None,
+        on_load_checkpoint: Callable | None = None,
+        on_before_backward: Callable | None = None,
+        on_after_backward: Callable | None = None,
+        on_before_optimizer_step: Callable | None = None,
+        on_before_zero_grad: Callable | None = None,
+        on_predict_start: Callable | None = None,
+        on_predict_end: Callable | None = None,
+        on_predict_batch_start: Callable | None = None,
+        on_predict_batch_end: Callable | None = None,
+        on_predict_epoch_start: Callable | None = None,
+        on_predict_epoch_end: Callable | None = None,
+    ): ...
 
     def register_callback(
         self,
         callback: Callback | Iterable[Callback] | CallbackFn | None = None,
+        /,
         *,
         setup: Callable | None = None,
         teardown: Callable | None = None,
@@ -109,7 +158,7 @@ class CallbackRegistrarModuleMixin:
         else:
             callback_ = callback
 
-        self._ll_callbacks.append(callback_)
+        self._nshtrainer_callbacks.append(callback_)
 
 
 class CallbackModuleMixin(
@@ -136,7 +185,7 @@ class CallbackModuleMixin(
     @override
     def configure_callbacks(self):
        callbacks = super().configure_callbacks()
-        if not isinstance(callbacks, abc.Sequence):
+        if not isinstance(callbacks, Sequence):
            callbacks = [callbacks]
 
        callbacks = list(callbacks)
@@ -145,7 +194,7 @@ class CallbackModuleMixin(
            if callback_result is None:
                continue
 
-            if not isinstance(callback_result, abc.Iterable):
+            if not isinstance(callback_result, Iterable):
                callback_result = [callback_result]
 
            for callback in callback_result:
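With the positional-only `/` and the two `@overload`s, a call now supplies either a ready-made `Callback` (or an iterable or factory of callbacks) positionally, or individual hook functions by keyword, never both. A minimal sketch of the two call shapes, assuming the mixin can be instantiated standalone and imported from the module path shown in the RECORD below (keyword hooks appear to be wrapped via the imported `LambdaCallback`):

from lightning.pytorch.callbacks import LambdaCallback

from nshtrainer.model.modules.callback import CallbackRegistrarModuleMixin

class Demo(CallbackRegistrarModuleMixin):
    pass

demo = Demo()

# Positional-only form: pass a Callback instance directly.
demo.register_callback(LambdaCallback(on_train_start=lambda trainer, module: None))

# Keyword-only form: pass hook functions; the mixin wraps them itself.
demo.register_callback(on_train_epoch_end=lambda trainer, module: print("epoch done"))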
nshtrainer/nn/nonlinearity.py CHANGED
@@ -4,15 +4,19 @@ from typing import Annotated, Literal
 import nshconfig as C
 import torch
 import torch.nn as nn
-from typing_extensions import override
+import torch.nn.functional as F
+from typing_extensions import final, override
 
 
 class BaseNonlinearityConfig(C.Config, ABC):
     @abstractmethod
-    def create_module(self) -> nn.Module:
-        pass
+    def create_module(self) -> nn.Module: ...
+
+    @abstractmethod
+    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
 
 
+@final
 class ReLUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["relu"] = "relu"
 
@@ -20,7 +24,11 @@ class ReLUNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.ReLU()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.relu(x)
 
+
+@final
 class SigmoidNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["sigmoid"] = "sigmoid"
 
@@ -28,7 +36,11 @@ class SigmoidNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.Sigmoid()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.sigmoid(x)
+
 
+@final
 class TanhNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["tanh"] = "tanh"
 
@@ -36,23 +48,44 @@ class TanhNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.Tanh()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.tanh(x)
+
 
+@final
 class SoftmaxNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["softmax"] = "softmax"
 
+    dim: int = -1
+    """The dimension to apply the softmax function."""
+
     @override
     def create_module(self) -> nn.Module:
-        return nn.Softmax(dim=1)
+        return nn.Softmax(dim=self.dim)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.softmax(x, dim=self.dim)
 
 
+@final
 class SoftplusNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["softplus"] = "softplus"
 
+    beta: float = 1.0
+    """The beta parameter in the softplus function."""
+
+    threshold: float = 20.0
+    """Values above this revert to a linear function."""
+
     @override
     def create_module(self) -> nn.Module:
-        return nn.Softplus()
+        return nn.Softplus(beta=self.beta, threshold=self.threshold)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.softplus(x, beta=self.beta, threshold=self.threshold)
 
 
+@final
 class SoftsignNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["softsign"] = "softsign"
 
@@ -60,44 +93,78 @@ class SoftsignNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.Softsign()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.softsign(x)
+
 
+@final
 class ELUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["elu"] = "elu"
 
+    alpha: float = 1.0
+    """The alpha parameter in the ELU function."""
+
     @override
     def create_module(self) -> nn.Module:
         return nn.ELU()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.elu(x, alpha=self.alpha)
+
 
+@final
 class LeakyReLUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["leaky_relu"] = "leaky_relu"
 
-    negative_slope: float | None = None
+    negative_slope: float = 1.0e-2
+    """The negative slope of the leaky ReLU function."""
 
     @override
     def create_module(self) -> nn.Module:
-        kwargs = {}
-        if self.negative_slope is not None:
-            kwargs["negative_slope"] = self.negative_slope
-        return nn.LeakyReLU(**kwargs)
+        return nn.LeakyReLU(negative_slope=self.negative_slope)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.leaky_relu(x, negative_slope=self.negative_slope)
 
 
+@final
 class PReLUConfig(BaseNonlinearityConfig):
     name: Literal["prelu"] = "prelu"
 
+    num_parameters: int = 1
+    """The number of :math:`a` to learn.
+    Although it takes an int as input, there is only two values are legitimate:
+    1, or the number of channels at input."""
+
+    init: float = 0.25
+    """The initial value of :math:`a`."""
+
     @override
     def create_module(self) -> nn.Module:
-        return nn.PReLU()
+        return nn.PReLU(num_parameters=self.num_parameters, init=self.init)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError(
+            "PReLU requires learnable parameters and cannot be called directly."
+        )
 
 
+@final
 class GELUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["gelu"] = "gelu"
 
+    approximate: Literal["tanh", "none"] = "none"
+    """The gelu approximation algorithm to use."""
+
     @override
     def create_module(self) -> nn.Module:
-        return nn.GELU()
+        return nn.GELU(approximate=self.approximate)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.gelu(x, approximate=self.approximate)
 
 
+@final
 class SwishNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["swish"] = "swish"
 
@@ -105,7 +172,11 @@ class SwishNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.SiLU()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.silu(x)
+
 
+@final
 class SiLUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["silu"] = "silu"
 
@@ -113,7 +184,11 @@ class SiLUNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.SiLU()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.silu(x)
 
+
+@final
 class MishNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["mish"] = "mish"
 
@@ -121,6 +196,9 @@ class MishNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.Mish()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.mish(x)
+
 
 class SwiGLU(nn.SiLU):
     @override
@@ -129,6 +207,7 @@ class SwiGLU(nn.SiLU):
         return input * super().forward(gate)
 
 
+@final
 class SwiGLUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["swiglu"] = "swiglu"
 
@@ -136,6 +215,10 @@ class SwiGLUNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return SwiGLU()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        input, gate = x.chunk(2, dim=-1)
+        return input * F.silu(gate)
+
 
 NonlinearityConfig = Annotated[
     ReLUNonlinearityConfig
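Each config class is now both a module factory and a plain function: `create_module()` still returns the stateful `nn.Module`, while the new `__call__` applies the matching `torch.nn.functional` op with the same hyperparameters. A sketch of the dual use (the import path follows the RECORD listing below; keyword construction assumes nshconfig configs behave like pydantic models):

import torch

from nshtrainer.nn.nonlinearity import GELUNonlinearityConfig

config = GELUNonlinearityConfig(approximate="tanh")
x = torch.randn(4, 8)

module_out = config.create_module()(x)  # module path: nn.GELU(approximate="tanh")
direct_out = config(x)                  # functional path: F.gelu(x, approximate="tanh")
assert torch.allclose(module_out, direct_out)

The one exception is `PReLUConfig`: its `__call__` raises `NotImplementedError` because the activation carries learnable parameters, so only `create_module()` is meaningful there.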
{nshtrainer-0.18.2.dist-info → nshtrainer-0.19.1.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.18.2
+Version: 0.19.1
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-0.18.2.dist-info → nshtrainer-0.19.1.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ nshtrainer/_checkpoint/loader.py,sha256=myFObRsPdb8jBncMK73vjr5FDJIfKhF86Ec_kSjX
 nshtrainer/_checkpoint/metadata.py,sha256=p5e7dhVPpOGrXeuesq_7Y_RHi5lguzDAR_UXtMJXzWU,5175
 nshtrainer/_checkpoint/saver.py,sha256=DkbCH0YeOJ71m32vAARiQdGBf0hvwwdoAV8LOFGy-0Y,1428
 nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
-nshtrainer/_hf_hub.py,sha256=Py9_8ADvMCFPaJzeE7bxm8Mgs3mEMkyWJ4pDEccTGt8,11230
+nshtrainer/_hf_hub.py,sha256=To3BnnGWbMNNMBdzVtgrNOcNU2fi1dQpwwuclusFAbI,12169
 nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
@@ -58,7 +58,7 @@ nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ
 nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
 nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
 nshtrainer/model/config.py,sha256=147uV7IukvuYE4G_ZuQNxVjnlog1BdCrAVbcj_sx9Vs,43104
-nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
+nshtrainer/model/modules/callback.py,sha256=thhlJaqLRw2gwvb3Z6DJ8Kk8XUxKhinU_8ad30vne34,8541
 nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
 nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
 nshtrainer/model/modules/logger.py,sha256=CJWSmNT8SV5GLtfml-qGYenqRPXcNOMsJRGEavAd8Hw,5464
@@ -69,7 +69,7 @@ nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,142
 nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
 nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
 nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
-nshtrainer/nn/nonlinearity.py,sha256=owtU4kh4G98psD0axOJWVfBhm-OtJVgFM-TXSHmbNPU,3625
+nshtrainer/nn/nonlinearity.py,sha256=4sYE4MN5zojc-go1k0PYtqssVRuXrM7D4tbpIXp5K-E,6078
 nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
 nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
 nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
@@ -85,6 +85,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.18.2.dist-info/METADATA,sha256=vev96DaxCnqJOAvvGrGOJ37OpWNFLrCdtGPN-kpnvO4,935
-nshtrainer-0.18.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.18.2.dist-info/RECORD,,
+nshtrainer-0.19.1.dist-info/METADATA,sha256=NMPSdeNqcMnyB9UiQ-4f-MdhBZ_RmCAPCYcYCCvjyYI,935
+nshtrainer-0.19.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.19.1.dist-info/RECORD,,