lt-tensor 0.0.1a12__py3-none-any.whl → 0.0.1a14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/datasets/audio.py +141 -46
- lt_tensor/misc_utils.py +38 -1
- lt_tensor/model_zoo/__init__.py +18 -9
- lt_tensor/model_zoo/{bsc.py → basic.py} +118 -2
- lt_tensor/model_zoo/features.py +416 -0
- lt_tensor/model_zoo/fusion.py +164 -0
- lt_tensor/model_zoo/istft/generator.py +5 -65
- lt_tensor/model_zoo/istft/sg.py +142 -0
- lt_tensor/model_zoo/istft/trainer.py +227 -59
- lt_tensor/model_zoo/residual.py +252 -0
- lt_tensor/model_zoo/{tfrms.py → transformer.py} +2 -2
- lt_tensor/processors/audio.py +207 -80
- lt_tensor/transform.py +7 -16
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a14.dist-info}/METADATA +7 -5
- lt_tensor-0.0.1a14.dist-info/RECORD +32 -0
- lt_tensor/model_zoo/fsn.py +0 -67
- lt_tensor/model_zoo/gns.py +0 -185
- lt_tensor/model_zoo/istft.py +0 -591
- lt_tensor/model_zoo/rsd.py +0 -107
- lt_tensor-0.0.1a12.dist-info/RECORD +0 -32
- /lt_tensor/model_zoo/{disc.py → discriminator.py} +0 -0
- /lt_tensor/model_zoo/{pos.py → pos_encoder.py} +0 -0
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a14.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a14.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a14.dist-info}/top_level.txt +0 -0
lt_tensor/datasets/audio.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
__all__ = ["WaveMelDataset"]
|
2
2
|
from lt_tensor.torch_commons import *
|
3
3
|
from lt_utils.common import *
|
4
|
+
from lt_utils.misc_utils import default
|
4
5
|
import random
|
5
6
|
from torch.utils.data import Dataset, DataLoader, Sampler
|
6
7
|
from lt_tensor.processors import AudioProcessor
|
@@ -9,64 +10,131 @@ from lt_tensor.misc_utils import log_tensor
|
|
9
10
|
from tqdm import tqdm
|
10
11
|
|
11
12
|
|
13
|
+
DEFAULT_DEVICE = torch.tensor([0]).device
|
14
|
+
|
15
|
+
|
12
16
|
class WaveMelDataset(Dataset):
|
13
|
-
|
17
|
+
cached_data: Union[list[dict[str, Tensor]], Tuple[Tensor, Tensor]] = []
|
18
|
+
loaded_files: Dict[str, List[Dict[str, Tensor]]] = {}
|
19
|
+
normalize_waves: bool = False
|
20
|
+
randomize_ranges: bool = False
|
21
|
+
alpha_wv: float = 1.0
|
22
|
+
limit_files: Optional[int] = None
|
23
|
+
max_frame_length: Optional[int] = None
|
14
24
|
|
15
25
|
def __init__(
|
16
26
|
self,
|
17
27
|
audio_processor: AudioProcessor,
|
18
|
-
|
28
|
+
dataset_path: PathLike,
|
19
29
|
limit_files: Optional[int] = None,
|
20
30
|
max_frame_length: Optional[int] = None,
|
21
|
-
randomize_ranges: bool =
|
31
|
+
randomize_ranges: Optional[bool] = None,
|
32
|
+
pre_load: bool = False,
|
33
|
+
normalize_waves: Optional[bool] = None,
|
34
|
+
alpha_wv: Optional[float] = None,
|
35
|
+
n_noises: int = 0, # TODO: Implement the random noises into the dataset
|
22
36
|
):
|
23
37
|
super().__init__()
|
24
38
|
assert max_frame_length is None or max_frame_length >= (
|
25
39
|
(audio_processor.n_fft // 2) + 1
|
26
40
|
)
|
41
|
+
self.ap = audio_processor
|
42
|
+
self.dataset_path = dataset_path
|
43
|
+
if limit_files:
|
44
|
+
self.limit_files = limit_files
|
45
|
+
if normalize_waves is not None:
|
46
|
+
self.normalize_waves = normalize_waves
|
47
|
+
if alpha_wv is not None:
|
48
|
+
self.alpha_wv = alpha_wv
|
49
|
+
if pre_load is not None:
|
50
|
+
self.pre_loaded = pre_load
|
51
|
+
if randomize_ranges is not None:
|
52
|
+
self.randomize_ranges = randomize_ranges
|
27
53
|
|
28
54
|
self.post_n_fft = (audio_processor.n_fft // 2) + 1
|
55
|
+
|
29
56
|
if max_frame_length is not None:
|
57
|
+
max_frame_length = max(self.post_n_fft + 1, max_frame_length)
|
30
58
|
self.r_range = max(self.post_n_fft + 1, max_frame_length // 3)
|
31
|
-
|
32
|
-
|
59
|
+
self.max_frame_length = max_frame_length
|
60
|
+
|
61
|
+
self.files = self.ap.find_audios(dataset_path, maximum=None)
|
33
62
|
if limit_files:
|
34
63
|
random.shuffle(self.files)
|
35
|
-
self.files = self.files[:
|
36
|
-
|
64
|
+
self.files = self.files[-self.limit_files :]
|
65
|
+
if pre_load:
|
66
|
+
for file in tqdm(self.files, "Loading files"):
|
67
|
+
results = self.load_data(file)
|
68
|
+
if not results:
|
69
|
+
continue
|
70
|
+
self.cached_data.extend(results)
|
37
71
|
|
72
|
+
def renew_dataset(self, new_path: Optional[PathLike] = None):
|
73
|
+
new_path = default(new_path, self.dataset_path)
|
74
|
+
self.files = self.ap.find_audios(new_path, maximum=None)
|
75
|
+
random.shuffle(self.files)
|
38
76
|
for file in tqdm(self.files, "Loading files"):
|
39
|
-
results = self.load_data(file
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
return {"mel": audio_mel, "raw": audio_raw, "file": file}
|
77
|
+
results = self.load_data(file)
|
78
|
+
if not results:
|
79
|
+
continue
|
80
|
+
self.cached_data.extend(results)
|
44
81
|
|
45
|
-
def
|
82
|
+
def _add_dict(
|
46
83
|
self,
|
84
|
+
audio_wave: Tensor,
|
85
|
+
audio_mel: Tensor,
|
86
|
+
pitch: Tensor,
|
87
|
+
rms: Tensor,
|
47
88
|
file: PathLike,
|
48
|
-
audio_frames_limit: Optional[int] = None,
|
49
|
-
randomize_ranges: bool = False,
|
50
89
|
):
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
90
|
+
return {
|
91
|
+
"wave": audio_wave,
|
92
|
+
"pitch": pitch,
|
93
|
+
"rms": rms,
|
94
|
+
"mel": audio_mel,
|
95
|
+
"file": file,
|
96
|
+
}
|
97
|
+
|
98
|
+
def load_data(self, file: PathLike):
|
99
|
+
initial_audio = self.ap.normalize_audio(
|
100
|
+
self.ap.load_audio(
|
101
|
+
file, normalize=self.normalize_waves, alpha=self.alpha_wv
|
102
|
+
)
|
103
|
+
)
|
104
|
+
if initial_audio.shape[-1] < self.post_n_fft:
|
105
|
+
return None
|
106
|
+
|
107
|
+
if (
|
108
|
+
not self.max_frame_length
|
109
|
+
or initial_audio.shape[-1] <= self.max_frame_length
|
110
|
+
):
|
111
|
+
audio_rms = self.ap.compute_rms(initial_audio)
|
112
|
+
audio_pitch = self.ap.compute_pitch(initial_audio)
|
55
113
|
audio_mel = self.ap.compute_mel(initial_audio, add_base=True)
|
56
|
-
return [
|
114
|
+
return [
|
115
|
+
self._add_dict(initial_audio, audio_mel, audio_pitch, audio_rms, file)
|
116
|
+
]
|
57
117
|
results = []
|
58
|
-
|
59
|
-
|
118
|
+
|
119
|
+
if self.randomize_ranges:
|
120
|
+
frame_limit = random.randint(self.r_range, self.max_frame_length)
|
60
121
|
else:
|
61
|
-
frame_limit =
|
62
|
-
|
63
|
-
|
64
|
-
|
122
|
+
frame_limit = self.max_frame_length
|
123
|
+
|
124
|
+
fragments = list(
|
125
|
+
torch.split(initial_audio, split_size_or_sections=frame_limit, dim=-1)
|
126
|
+
)
|
127
|
+
random.shuffle(fragments)
|
128
|
+
for fragment in fragments:
|
65
129
|
if fragment.shape[-1] < self.post_n_fft:
|
66
|
-
#
|
130
|
+
# Too small
|
67
131
|
continue
|
132
|
+
audio_rms = self.ap.compute_rms(fragment)
|
133
|
+
audio_pitch = self.ap.compute_pitch(fragment)
|
68
134
|
audio_mel = self.ap.compute_mel(fragment, add_base=True)
|
69
|
-
results.append(
|
135
|
+
results.append(
|
136
|
+
self._add_dict(fragment, audio_mel, audio_pitch, audio_rms, file)
|
137
|
+
)
|
70
138
|
return results
|
71
139
|
|
72
140
|
def get_data_loader(
|
@@ -93,31 +161,58 @@ class WaveMelDataset(Dataset):
|
|
93
161
|
collate_fn=self.collate_fn,
|
94
162
|
)
|
95
163
|
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
164
|
+
def collate_fn(self, batch: Sequence[Dict[str, Tensor]]):
|
165
|
+
mel = []
|
166
|
+
wave = []
|
167
|
+
file = []
|
168
|
+
rms = []
|
169
|
+
pitch = []
|
101
170
|
for x in batch:
|
102
|
-
|
103
|
-
|
104
|
-
|
171
|
+
mel.append(x["mel"])
|
172
|
+
wave.append(x["wave"])
|
173
|
+
file.append(x["file"])
|
174
|
+
rms.append(x["rms"])
|
175
|
+
pitch.append(x["pitch"])
|
105
176
|
# Find max time in mel (dim -1), and max audio length
|
106
|
-
max_mel_len = max([m.shape[-1] for m in
|
107
|
-
max_audio_len = max([a.shape[-1] for a in
|
177
|
+
max_mel_len = max([m.shape[-1] for m in mel])
|
178
|
+
max_audio_len = max([a.shape[-1] for a in wave])
|
179
|
+
max_pitch_len = max([a.shape[-1] for a in pitch])
|
180
|
+
max_rms_len = max([a.shape[-1] for a in rms])
|
108
181
|
|
109
|
-
|
110
|
-
[FT.pad(m, (0, max_mel_len - m.shape[-1])) for m in
|
182
|
+
padded_mel = torch.stack(
|
183
|
+
[FT.pad(m, (0, max_mel_len - m.shape[-1])) for m in mel]
|
111
184
|
) # shape: [B, 80, T_max]
|
112
185
|
|
113
|
-
|
114
|
-
[FT.pad(a, (0, max_audio_len - a.shape[-1])) for a in
|
186
|
+
padded_wave = torch.stack(
|
187
|
+
[FT.pad(a, (0, max_audio_len - a.shape[-1])) for a in wave]
|
188
|
+
) # shape: [B, L_max]
|
189
|
+
|
190
|
+
padded_pitch = torch.stack(
|
191
|
+
[FT.pad(a, (0, max_pitch_len - a.shape[-1])) for a in pitch]
|
115
192
|
) # shape: [B, L_max]
|
193
|
+
padded_rms = torch.stack(
|
194
|
+
[FT.pad(a, (0, max_rms_len - a.shape[-1])) for a in rms]
|
195
|
+
) # shape: [B, L_max]
|
196
|
+
return dict(
|
197
|
+
mel=padded_mel,
|
198
|
+
wave=padded_wave,
|
199
|
+
pitch=padded_pitch,
|
200
|
+
rms=padded_rms,
|
201
|
+
file=file,
|
202
|
+
)
|
116
203
|
|
117
|
-
|
204
|
+
def get_item(self, idx: int):
|
205
|
+
if self.pre_loaded:
|
206
|
+
return self.cached_data[idx]
|
207
|
+
file = self.files[idx]
|
208
|
+
if file not in self.loaded_files:
|
209
|
+
self.loaded_files[file] = self.load_data(file)
|
210
|
+
return random.choice(self.loaded_files[file])
|
118
211
|
|
119
212
|
def __len__(self):
|
120
|
-
|
213
|
+
if self.pre_loaded:
|
214
|
+
return len(self.cached_data)
|
215
|
+
return len(self.files)
|
121
216
|
|
122
|
-
def __getitem__(self, index):
|
123
|
-
return self.
|
217
|
+
def __getitem__(self, index: int):
|
218
|
+
return self.get_item(index)
|
lt_tensor/misc_utils.py
CHANGED
@@ -23,6 +23,7 @@ __all__ = [
|
|
23
23
|
"get_losses",
|
24
24
|
"plot_view",
|
25
25
|
"get_weights",
|
26
|
+
"get_activated_conv",
|
26
27
|
]
|
27
28
|
|
28
29
|
import re
|
@@ -39,6 +40,42 @@ from lt_utils.common import *
|
|
39
40
|
from lt_utils.misc_utils import ff_list
|
40
41
|
import torch.nn.functional as F
|
41
42
|
|
43
|
+
CONV_MAP = {
|
44
|
+
"conv1d": nn.Conv1d,
|
45
|
+
"conv1d": nn.Conv2d,
|
46
|
+
"conv1d": nn.Conv3d,
|
47
|
+
"convtranspose1d": nn.ConvTranspose1d,
|
48
|
+
"convtranspose2d": nn.ConvTranspose2d,
|
49
|
+
"convtranspose3d": nn.ConvTranspose3d,
|
50
|
+
}
|
51
|
+
|
52
|
+
|
53
|
+
def get_activated_conv(
|
54
|
+
in_channels: int,
|
55
|
+
out_channels: int,
|
56
|
+
kernel_size: int = 1,
|
57
|
+
stride: int = 1,
|
58
|
+
padding: Union[int, str] = 0,
|
59
|
+
groups: int = 1,
|
60
|
+
conv_type: Literal[
|
61
|
+
"Conv1d",
|
62
|
+
"Conv2d",
|
63
|
+
"Conv3d",
|
64
|
+
"ConvTranspose1d",
|
65
|
+
"ConvTranspose2d",
|
66
|
+
"ConvTranspose3d",
|
67
|
+
] = "Conv1d",
|
68
|
+
activation: nn.Module = nn.LeakyReLU(0.2),
|
69
|
+
norm_fn: Callable[[nn.Module], nn.Module] = lambda x: x,
|
70
|
+
):
|
71
|
+
assert conv_type.lower() in CONV_MAP, f"Invalid conv type: {conv_type}."
|
72
|
+
return nn.Sequential(
|
73
|
+
activation,
|
74
|
+
norm_fn(
|
75
|
+
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups)
|
76
|
+
),
|
77
|
+
)
|
78
|
+
|
42
79
|
|
43
80
|
def plot_view(
|
44
81
|
data: Dict[str, List[Any]],
|
@@ -203,7 +240,7 @@ class LogTensor:
|
|
203
240
|
stored_items: List[
|
204
241
|
Dict[str, Union[str, Number, Tensor, List[Union[Tensor, Number, str]]]]
|
205
242
|
] = []
|
206
|
-
max_stored_items: int =
|
243
|
+
max_stored_items: int = 8
|
207
244
|
|
208
245
|
def _setup_message(self, title: str, t: Union[Tensor, str, int]):
|
209
246
|
try:
|
lt_tensor/model_zoo/__init__.py
CHANGED
@@ -1,11 +1,20 @@
|
|
1
1
|
__all__ = [
|
2
|
-
"
|
3
|
-
"
|
4
|
-
"
|
5
|
-
"
|
6
|
-
"
|
7
|
-
"
|
8
|
-
"
|
9
|
-
"istft",
|
2
|
+
"basic", # basic
|
3
|
+
"residual", # residual
|
4
|
+
"transformer", # transformer
|
5
|
+
"pos_encoder",
|
6
|
+
"fusion",
|
7
|
+
"features",
|
8
|
+
"discriminator",
|
9
|
+
"istft",
|
10
10
|
]
|
11
|
-
from . import
|
11
|
+
from . import (
|
12
|
+
basic,
|
13
|
+
discriminator,
|
14
|
+
features,
|
15
|
+
fusion,
|
16
|
+
istft,
|
17
|
+
pos_encoder,
|
18
|
+
residual,
|
19
|
+
transformer,
|
20
|
+
)
|
@@ -8,11 +8,14 @@ __all__ = [
|
|
8
8
|
"StyleEncoder",
|
9
9
|
"PatchEmbed1D",
|
10
10
|
"MultiScaleEncoder1D",
|
11
|
+
"UpSampleConv1D",
|
11
12
|
]
|
12
|
-
|
13
|
+
import torch.nn.functional as F
|
13
14
|
from lt_tensor.torch_commons import *
|
14
15
|
from lt_tensor.model_base import Model
|
15
16
|
from lt_tensor.transform import get_sinusoidal_embedding
|
17
|
+
from lt_utils.common import *
|
18
|
+
import math
|
16
19
|
|
17
20
|
|
18
21
|
class FeedForward(Model):
|
@@ -30,6 +33,9 @@ class FeedForward(Model):
|
|
30
33
|
nn.Linear(d_model, ff_dim),
|
31
34
|
activation,
|
32
35
|
nn.Dropout(dropout),
|
36
|
+
nn.Linear(ff_dim, ff_dim),
|
37
|
+
activation,
|
38
|
+
nn.Dropout(dropout),
|
33
39
|
nn.Linear(ff_dim, d_model),
|
34
40
|
normalizer,
|
35
41
|
)
|
@@ -38,6 +44,24 @@ class FeedForward(Model):
|
|
38
44
|
return self.net(x)
|
39
45
|
|
40
46
|
|
47
|
+
class UpSampleConv1D(nn.Module):
|
48
|
+
def __init__(self, upsample: bool = False, dim_in: int = 0, dim_out: int = 0):
|
49
|
+
super().__init__()
|
50
|
+
if upsample:
|
51
|
+
self.upsample = lambda x: F.interpolate(x, scale_factor=2, mode="nearest")
|
52
|
+
else:
|
53
|
+
self.upsample = nn.Identity()
|
54
|
+
|
55
|
+
if dim_in == dim_out:
|
56
|
+
self.learned = nn.Identity()
|
57
|
+
else:
|
58
|
+
self.learned = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
59
|
+
|
60
|
+
def forward(self, x):
|
61
|
+
x = self.upsample(x)
|
62
|
+
return self.learned(x)
|
63
|
+
|
64
|
+
|
41
65
|
class MLP(Model):
|
42
66
|
def __init__(
|
43
67
|
self,
|
@@ -161,7 +185,6 @@ class StyleEncoder(Model):
|
|
161
185
|
self.linear = nn.Linear(hidden, out_dim)
|
162
186
|
|
163
187
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
164
|
-
# x: [B, Mels, T]
|
165
188
|
x = self.net(x).squeeze(-1) # [B, hidden]
|
166
189
|
return self.linear(x) # [B, out_dim]
|
167
190
|
|
@@ -230,3 +253,96 @@ class AudioClassifier(Model):
|
|
230
253
|
|
231
254
|
def forward(self, x):
|
232
255
|
return self.model(x)
|
256
|
+
|
257
|
+
|
258
|
+
class LoRALinearLayer(nn.Module):
|
259
|
+
|
260
|
+
def __init__(
|
261
|
+
self,
|
262
|
+
in_features: int,
|
263
|
+
out_features: int,
|
264
|
+
rank: int = 4,
|
265
|
+
alpha: float = 1.0,
|
266
|
+
):
|
267
|
+
super().__init__()
|
268
|
+
self.down = nn.Linear(in_features, rank, bias=False)
|
269
|
+
self.up = nn.Linear(rank, out_features, bias=False)
|
270
|
+
self.alpha = alpha
|
271
|
+
self.rank = rank
|
272
|
+
self.out_features = out_features
|
273
|
+
self.in_features = in_features
|
274
|
+
self.ah = self.alpha / self.rank
|
275
|
+
self._down_dt = self.down.weight.dtype
|
276
|
+
|
277
|
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
278
|
+
nn.init.zeros_(self.up.weight)
|
279
|
+
|
280
|
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
281
|
+
orig_dtype = hidden_states.dtype
|
282
|
+
down_hidden_states = self.down(hidden_states.to(self._down_dt))
|
283
|
+
up_hidden_states = self.up(down_hidden_states) * self.ah
|
284
|
+
return up_hidden_states.to(orig_dtype)
|
285
|
+
|
286
|
+
|
287
|
+
class LoRAConv1DLayer(nn.Module):
|
288
|
+
def __init__(
|
289
|
+
self,
|
290
|
+
in_features: int,
|
291
|
+
out_features: int,
|
292
|
+
kernel_size: Union[int, Tuple[int, ...]] = 1,
|
293
|
+
rank: int = 4,
|
294
|
+
alpha: float = 1.0,
|
295
|
+
):
|
296
|
+
super().__init__()
|
297
|
+
self.down = nn.Conv1d(
|
298
|
+
in_features, rank, kernel_size, padding=kernel_size // 2, bias=False
|
299
|
+
)
|
300
|
+
self.up = nn.Conv1d(
|
301
|
+
rank, out_features, kernel_size, padding=kernel_size // 2, bias=False
|
302
|
+
)
|
303
|
+
self.ah = alpha / rank
|
304
|
+
self._down_dt = self.down.weight.dtype
|
305
|
+
nn.init.kaiming_uniform_(self.down.weight, a=math.sqrt(5))
|
306
|
+
nn.init.zeros_(self.up.weight)
|
307
|
+
|
308
|
+
def forward(self, inputs: Tensor) -> Tensor:
|
309
|
+
orig_dtype = inputs.dtype
|
310
|
+
down_hidden_states = self.down(inputs.to(self._down_dt))
|
311
|
+
up_hidden_states = self.up(down_hidden_states) * self.ah
|
312
|
+
return up_hidden_states.to(orig_dtype)
|
313
|
+
|
314
|
+
|
315
|
+
class LoRAConv2DLayer(nn.Module):
|
316
|
+
def __init__(
|
317
|
+
self,
|
318
|
+
in_features: int,
|
319
|
+
out_features: int,
|
320
|
+
kernel_size: Union[int, Tuple[int, ...]] = (1, 1),
|
321
|
+
rank: int = 4,
|
322
|
+
alpha: float = 1.0,
|
323
|
+
):
|
324
|
+
super().__init__()
|
325
|
+
self.down = nn.Conv2d(
|
326
|
+
in_features,
|
327
|
+
rank,
|
328
|
+
kernel_size,
|
329
|
+
padding="same",
|
330
|
+
bias=False,
|
331
|
+
)
|
332
|
+
self.up = nn.Conv2d(
|
333
|
+
rank,
|
334
|
+
out_features,
|
335
|
+
kernel_size,
|
336
|
+
padding="same",
|
337
|
+
bias=False,
|
338
|
+
)
|
339
|
+
self.ah = alpha / rank
|
340
|
+
|
341
|
+
nn.init.kaiming_normal_(self.down.weight, a=0.2)
|
342
|
+
nn.init.zeros_(self.up.weight)
|
343
|
+
|
344
|
+
def forward(self, inputs: Tensor) -> Tensor:
|
345
|
+
orig_dtype = inputs.dtype
|
346
|
+
down_hidden_states = self.down(inputs.to(self._down_dt))
|
347
|
+
up_hidden_states = self.up(down_hidden_states) * self.ah
|
348
|
+
return up_hidden_states.to(orig_dtype)
|