nshtrainer 0.18.1__tar.gz → 0.19.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/PKG-INFO +1 -1
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/pyproject.toml +1 -1
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/_hf_hub.py +17 -1
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/nn/nonlinearity.py +95 -12
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/README.md +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/loggers/_base.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/model/config.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/model/modules/callback.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/model/modules/debug.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/model/modules/logger.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/trainer/trainer.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/util/typing_utils.py +0 -0

{nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/_hf_hub.py

@@ -1,3 +1,4 @@
+import io
 import logging
 import os
 from pathlib import Path

@@ -299,7 +300,22 @@ def _save_checkpoint_files(
     # Resolve the repository name
     repo_name = _repo_name(api, root_config)

+    # Let's read all the files to memory right now,
+    # in case they get used/removed by other processes.
+    # Read all the files to memory
+    file_contents: list[bytes | None] = []
     for p in paths:
+        try:
+            with open(p, "rb") as f:
+                file_contents.append(f.read())
+        except IOError as e:
+            log.warning(f"Failed to read checkpoint file {p}: {str(e)}")
+            file_contents.append(None)
+
+    for p, contents in zip(paths, file_contents):
+        if contents is None:
+            continue
+
         try:
             relative_path = p.relative_to(checkpoint_dir)
         except ValueError:

@@ -314,7 +330,7 @@ def _save_checkpoint_files(
         # Upload the checkpoint file to the repository
         try:
             api.upload_file(
-                path_or_fileobj=
+                path_or_fileobj=io.BytesIO(contents),
                 path_in_repo=str(path_in_repo),
                 repo_id=repo_name,
                 repo_type="model",
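
The change above snapshots every checkpoint file into memory before any upload begins, so a concurrent process that rotates or deletes checkpoints cannot invalidate a path mid-upload; unreadable files are logged and skipped instead of failing the whole batch. A minimal standalone sketch of the same pattern (the `upload_snapshot` helper and its arguments are illustrative, not part of nshtrainer):

```python
import io
import logging
from pathlib import Path

from huggingface_hub import HfApi

log = logging.getLogger(__name__)


def upload_snapshot(api: HfApi, repo_id: str, paths: list[Path]) -> None:
    # Snapshot every file into memory first, so concurrent checkpoint
    # rotation cannot invalidate the paths mid-upload.
    contents: list[bytes | None] = []
    for p in paths:
        try:
            contents.append(p.read_bytes())
        except OSError as e:
            log.warning(f"Failed to read {p}: {e}")
            contents.append(None)

    for p, data in zip(paths, contents):
        if data is None:
            continue  # unreadable file: skip it rather than fail the batch
        api.upload_file(
            path_or_fileobj=io.BytesIO(data),  # upload from memory, not disk
            path_in_repo=p.name,
            repo_id=repo_id,
            repo_type="model",
        )
```

Because the upload works from an `io.BytesIO` buffer, `huggingface_hub` never touches the original path again after the initial read.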

{nshtrainer-0.18.1 → nshtrainer-0.19.0}/src/nshtrainer/nn/nonlinearity.py

@@ -4,15 +4,19 @@ from typing import Annotated, Literal
 import nshconfig as C
 import torch
 import torch.nn as nn
-
+import torch.nn.functional as F
+from typing_extensions import final, override


 class BaseNonlinearityConfig(C.Config, ABC):
     @abstractmethod
-    def create_module(self) -> nn.Module:
-
+    def create_module(self) -> nn.Module: ...
+
+    @abstractmethod
+    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...


+@final
 class ReLUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["relu"] = "relu"

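This hunk gives the base class a dual interface: `create_module()` builds an `nn.Module`, while `__call__` applies the nonlinearity functionally without allocating a module. Every concrete config in the hunks below follows the same shape; a hypothetical extra config (not in the package) would look like:

```python
from typing import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing_extensions import final, override

from nshtrainer.nn.nonlinearity import BaseNonlinearityConfig


# Hypothetical config, mirroring the pattern the hunks below apply
# to the built-in nonlinearities.
@final
class HardtanhNonlinearityConfig(BaseNonlinearityConfig):
    name: Literal["hardtanh"] = "hardtanh"

    @override
    def create_module(self) -> nn.Module:
        # Stateful path: a module for use inside nn.Sequential etc.
        return nn.Hardtanh()

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Stateless path: apply the nonlinearity directly to a tensor.
        return F.hardtanh(x)
```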
@@ -20,7 +24,11 @@ class ReLUNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.ReLU()

+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.relu(x)

+
+@final
 class SigmoidNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["sigmoid"] = "sigmoid"

@@ -28,7 +36,11 @@ class SigmoidNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.Sigmoid()

+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.sigmoid(x)
+

+@final
 class TanhNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["tanh"] = "tanh"

@@ -36,23 +48,44 @@ class TanhNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.Tanh()

+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.tanh(x)
+

+@final
 class SoftmaxNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["softmax"] = "softmax"

+    dim: int = -1
+    """The dimension to apply the softmax function."""
+
     @override
     def create_module(self) -> nn.Module:
-        return nn.Softmax(dim=
+        return nn.Softmax(dim=self.dim)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.softmax(x, dim=self.dim)


+@final
 class SoftplusNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["softplus"] = "softplus"

+    beta: float = 1.0
+    """The beta parameter in the softplus function."""
+
+    threshold: float = 20.0
+    """Values above this revert to a linear function."""
+
     @override
     def create_module(self) -> nn.Module:
-        return nn.Softplus()
+        return nn.Softplus(beta=self.beta, threshold=self.threshold)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.softplus(x, beta=self.beta, threshold=self.threshold)


+@final
 class SoftsignNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["softsign"] = "softsign"

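Because the configs now carry their hyperparameters (`dim` for softmax; `beta` and `threshold` for softplus), the module built by `create_module()` and the direct `__call__` path should produce identical outputs. A quick consistency check, assuming the configs are importable from `nshtrainer.nn.nonlinearity` and accept keyword arguments (they are nshconfig/pydantic-style models):

```python
import torch

from nshtrainer.nn.nonlinearity import SoftplusNonlinearityConfig

config = SoftplusNonlinearityConfig(beta=2.0, threshold=10.0)
x = torch.randn(8, 16)

module = config.create_module()  # nn.Softplus(beta=2.0, threshold=10.0)
y_module = module(x)             # stateful path
y_direct = config(x)             # functional path, same parameters

assert torch.allclose(y_module, y_direct)
```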
@@ -60,44 +93,78 @@ class SoftsignNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.Softsign()

+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.softsign(x)
+

+@final
 class ELUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["elu"] = "elu"

+    alpha: float = 1.0
+    """The alpha parameter in the ELU function."""
+
     @override
     def create_module(self) -> nn.Module:
         return nn.ELU()

+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.elu(x, alpha=self.alpha)
+

+@final
 class LeakyReLUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["leaky_relu"] = "leaky_relu"

-    negative_slope: float
+    negative_slope: float = 1.0e-2
+    """The negative slope of the leaky ReLU function."""

     @override
     def create_module(self) -> nn.Module:
-
-
-
-        return
+        return nn.LeakyReLU(negative_slope=self.negative_slope)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.leaky_relu(x, negative_slope=self.negative_slope)


+@final
 class PReLUConfig(BaseNonlinearityConfig):
     name: Literal["prelu"] = "prelu"

+    num_parameters: int = 1
+    """The number of :math:`a` to learn.
+    Although it takes an int as input, there is only two values are legitimate:
+    1, or the number of channels at input."""
+
+    init: float = 0.25
+    """The initial value of :math:`a`."""
+
     @override
     def create_module(self) -> nn.Module:
-        return nn.PReLU()
+        return nn.PReLU(num_parameters=self.num_parameters, init=self.init)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError(
+            "PReLU requires learnable parameters and cannot be called directly."
+        )


+@final
 class GELUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["gelu"] = "gelu"

+    approximate: Literal["tanh", "none"] = "none"
+    """The gelu approximation algorithm to use."""
+
     @override
     def create_module(self) -> nn.Module:
-        return nn.GELU()
+        return nn.GELU(approximate=self.approximate)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.gelu(x, approximate=self.approximate)


+@final
 class SwishNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["swish"] = "swish"

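PReLU is the one nonlinearity whose `__call__` deliberately raises: its slope :math:`a` is a learnable parameter that lives in the module, so there is no stateless functional form available from the config's state alone. Callers have to go through `create_module()`; a short sketch of the intended usage (import path assumed from the file list above):

```python
import torch

from nshtrainer.nn.nonlinearity import PReLUConfig

config = PReLUConfig(num_parameters=1, init=0.25)
prelu = config.create_module()  # nn.PReLU owning the learnable slope

x = torch.randn(4, 8)
y = prelu(x)  # fine: the module holds the parameter

# config(x) raises NotImplementedError, per the __call__ above.
```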
@@ -105,7 +172,11 @@ class SwishNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.SiLU()

+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.silu(x)
+

+@final
 class SiLUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["silu"] = "silu"

@@ -113,7 +184,11 @@ class SiLUNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.SiLU()

+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.silu(x)
+

+@final
 class MishNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["mish"] = "mish"

@@ -121,6 +196,9 @@ class MishNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.Mish()

+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.mish(x)
+

 class SwiGLU(nn.SiLU):
     @override

@@ -129,6 +207,7 @@ class SwiGLU(nn.SiLU):
         return input * super().forward(gate)


+@final
 class SwiGLUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["swiglu"] = "swiglu"

@@ -136,6 +215,10 @@ class SwiGLUNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return SwiGLU()

+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        input, gate = x.chunk(2, dim=-1)
+        return input * F.silu(gate)
+

 NonlinearityConfig = Annotated[
     ReLUNonlinearityConfig
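Note the convention in `SwiGLUNonlinearityConfig.__call__`: the last dimension is split into value and gate halves, so the input's last dimension must be even and the output is half as wide. A small usage sketch (import path assumed; the final `allclose` additionally assumes the `SwiGLU` module's `forward` chunks its input the same way, which its `return input * super().forward(gate)` line above suggests):

```python
import torch

from nshtrainer.nn.nonlinearity import SwiGLUNonlinearityConfig

config = SwiGLUNonlinearityConfig()
x = torch.randn(4, 64)   # last dim must be even: split into 32 + 32

y = config(x)            # value half * silu(gate half)
assert y.shape == (4, 32)

module = config.create_module()
assert torch.allclose(module(x), y)
```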