lt-tensor 0.0.1a12__py3-none-any.whl → 0.0.1a13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/datasets/audio.py +141 -46
- lt_tensor/misc_utils.py +37 -0
- lt_tensor/model_zoo/__init__.py +18 -9
- lt_tensor/model_zoo/{bsc.py → basic.py} +118 -2
- lt_tensor/model_zoo/features.py +416 -0
- lt_tensor/model_zoo/fusion.py +164 -0
- lt_tensor/model_zoo/istft/generator.py +2 -2
- lt_tensor/model_zoo/istft/sg.py +142 -0
- lt_tensor/model_zoo/istft/trainer.py +37 -12
- lt_tensor/model_zoo/residual.py +217 -0
- lt_tensor/model_zoo/{tfrms.py → transformer.py} +2 -2
- lt_tensor/processors/audio.py +218 -80
- lt_tensor/transform.py +7 -16
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a13.dist-info}/METADATA +6 -4
- lt_tensor-0.0.1a13.dist-info/RECORD +32 -0
- lt_tensor/model_zoo/fsn.py +0 -67
- lt_tensor/model_zoo/gns.py +0 -185
- lt_tensor/model_zoo/istft.py +0 -591
- lt_tensor/model_zoo/rsd.py +0 -107
- lt_tensor-0.0.1a12.dist-info/RECORD +0 -32
- /lt_tensor/model_zoo/{disc.py → discriminator.py} +0 -0
- /lt_tensor/model_zoo/{pos.py → pos_encoder.py} +0 -0
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a13.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a13.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,142 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import math
|
4
|
+
from einops import repeat
|
5
|
+
|
6
|
+
|
7
|
+
class SineGen(nn.Module):
    """Harmonic sine source driven by a frame-level F0 contour.

    Each input frame is repeated ``upsample_scale`` times along the time axis
    (nearest-neighbour upsampling) and a bank of sinusoids at
    ``1*f0, 2*f0, ..., dim*f0`` is synthesised by integrating the
    instantaneous frequency.

    Args:
        samp_rate: Output sampling rate in Hz.
        upsample_scale: Number of output samples generated per input frame.
        harmonic_num: Number of harmonics above the fundamental; outputs have
            ``harmonic_num + 1`` channels.
        sine_amp: Amplitude applied to the sine waves.
        noise_std: Standard deviation of the Gaussian noise.
        voiced_threshold: Frames with ``f0 > voiced_threshold`` count as voiced.
        flag_for_pulse: When set, ``_f02sine`` starts every harmonic at zero
            phase instead of a random phase. ``forward`` does not use it.
    """

    def __init__(
        self,
        samp_rate,
        upsample_scale,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
        flag_for_pulse=False,
    ):
        super().__init__()
        self.sampling_rate = samp_rate
        self.upsample_scale = upsample_scale
        self.harmonic_num = harmonic_num
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.dim = self.harmonic_num + 1  # fundamental + harmonics

    def _f02uv_b(self, f0):
        """Return a float voiced/unvoiced mask with the same shape as ``f0``."""
        return (f0 > self.voiced_threshold).float()

    def _f02uv(self, f0):
        """Return the voiced/unvoiced mask with a trailing axis: (B, T) -> (B, T, 1)."""
        return (f0 > self.voiced_threshold).float().unsqueeze(-1)

    def _upsample(self, x):
        """Repeat every frame ``upsample_scale`` times along dim 1 (time)."""
        return x.repeat_interleave(self.upsample_scale, dim=1)

    def _harmonic_multipliers(self, device):
        """Return ``[1, 2, ..., dim]`` shaped (1, 1, dim) for broadcasting."""
        return torch.arange(1, self.dim + 1, device=device).float().view(1, 1, -1)

    def _integrate_phase(self, freq_hz):
        """Integrate an instantaneous frequency (Hz) along dim 1 into a phase.

        The running sum is taken in float64 (float32 on MPS, which lacks
        float64) and wrapped to ``[0, 1)`` cycles before converting to
        radians. Without the wrap, the phase of a long signal grows to tens
        of thousands of radians. At that magnitude float32 resolution is
        several milliradians, which visibly distorts ``torch.sin``.

        Returns:
            Phase in radians (approximately ``[0, 2*pi)``), same dtype as
            ``freq_hz``.
        """
        acc_dtype = torch.float32 if freq_hz.device.type == "mps" else torch.float64
        cycles = torch.cumsum(freq_hz.to(acc_dtype) / self.sampling_rate, dim=1)
        return ((cycles % 1.0) * (2 * math.pi)).to(freq_hz.dtype)

    @torch.no_grad()
    def _f02sine(self, f0_values):
        """
        f0_values: (B, T, 1)
        Output: sine waves (B, T * upsample, dim)
        """
        batch, _, _ = f0_values.shape  # also enforces a 3-D input
        device = f0_values.device

        f0_harm = self._upsample(f0_values) * self._harmonic_multipliers(device)

        # Initial phase per batch item and harmonic: zero in pulse mode, else random.
        if self.flag_for_pulse:
            init_phase = torch.zeros((batch, 1, self.dim), device=device)
        else:
            init_phase = torch.rand((batch, 1, self.dim), device=device)
        init_phase = init_phase * 2 * math.pi

        return torch.sin(self._integrate_phase(f0_harm) + init_phase)

    def _forward(self, f0):
        """
        f0: (B, T, 1)
        returns: (B, T_up, dim) sine signal on voiced samples, noise on unvoiced ones
        """
        sine_waves = self._f02sine(f0)  # (B, T_up, dim)
        uv = self._upsample(self._f02uv_b(f0))  # (B, T_up, 1)

        sine_signal = self.sine_amp * sine_waves * uv
        noise = torch.randn_like(sine_signal) * self.noise_std
        return sine_signal + noise * (1.0 - uv)  # noise only where unvoiced

    def forward(self, f0):
        """
        Args:
            f0: (B, T) in Hz (before upsampling)
        Returns:
            sine_waves: (B, T_up, dim); channel 0 is replaced by noise on
                unvoiced samples
            uv: (B, T_up, 1)
            noise: (B, T_up, 1)
        Raises:
            ValueError: if ``f0`` is not 2-D.
        """
        if f0.dim() != 2:
            raise ValueError(f"f0 must be (B, T), got shape {tuple(f0.shape)}")
        device = f0.device

        uv_up = self._upsample(self._f02uv(f0))  # (B, T_up, 1)
        f0_harm_up = self._upsample(
            f0.unsqueeze(-1) * self._harmonic_multipliers(device)
        )  # (B, T_up, dim)
        batch, t_up, dim = f0_harm_up.shape

        # One random starting phase per batch item and harmonic.
        rand_phase = torch.rand(batch, dim, device=device) * 2 * math.pi
        phase = self._integrate_phase(f0_harm_up) + rand_phase.unsqueeze(1)
        sine_waves = torch.sin(phase) * self.sine_amp

        # flag_for_pulse (phase alignment at voiced onsets) is not implemented here.
        noise = torch.randn(batch, t_up, 1, device=device) * self.noise_std

        # Swap the fundamental for noise on unvoiced samples. Rebuild the tensor
        # rather than assigning into a slice of an autograd output.
        fundamental = sine_waves[:, :, :1] * uv_up + noise * (1 - uv_up)
        sine_waves = torch.cat([fundamental, sine_waves[:, :, 1:]], dim=-1)

        return sine_waves, uv_up, noise
|
@@ -13,18 +13,45 @@ from lt_tensor.misc_utils import set_seed, clear_cache
|
|
13
13
|
from lt_utils.type_utils import is_dir, is_pathlike, is_file
|
14
14
|
from lt_tensor.config_templates import updateDict, ModelConfig
|
15
15
|
from lt_tensor.model_zoo.istft.generator import iSTFTGenerator
|
16
|
-
from lt_tensor.model_zoo.
|
17
|
-
from lt_tensor.model_zoo.
|
16
|
+
from lt_tensor.model_zoo.residual import ResBlock1D, ConvNets, get_weight_norm
|
17
|
+
from lt_tensor.model_zoo.discriminator import MultiPeriodDiscriminator, MultiScaleDiscriminator
|
18
18
|
|
19
19
|
|
20
|
-
def feature_loss(
|
21
|
-
loss = 0
|
22
|
-
for
|
23
|
-
for
|
24
|
-
loss +=
|
20
|
+
def feature_loss(fmap_r, fmap_g):
    """L1 feature-matching loss between real and generated feature maps.

    ``fmap_r`` and ``fmap_g`` are parallel sequences (one entry per
    discriminator) of sequences of feature tensors; pairs are formed with
    ``zip``. Returns twice the sum of the per-layer mean absolute
    differences, or ``0`` when there is nothing to compare.
    """
    total = sum(
        torch.mean(torch.abs(real_feat - fake_feat))
        for real_maps, fake_maps in zip(fmap_r, fmap_g)
        for real_feat, fake_feat in zip(real_maps, fake_maps)
    )
    return total * 2
|
26
|
+
|
27
|
+
|
28
|
+
def generator_adv_loss(disc_outputs):
    """Least-squares adversarial loss for the generator.

    Sums ``mean((1 - D(G(z)))**2)`` over every discriminator output in
    ``disc_outputs``; returns ``0`` for an empty sequence.
    """
    return sum(torch.mean((1 - fake_out) ** 2) for fake_out in disc_outputs)
|
35
|
+
|
36
|
+
|
37
|
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss.

    For each (real, generated) output pair, adds
    ``mean((1 - D(real))**2) + mean(D(fake)**2)``; pairs are formed with
    ``zip``. Returns ``0`` when there are no pairs.
    """
    return sum(
        torch.mean((1 - real_out) ** 2) + torch.mean(fake_out**2)
        for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs)
    )
|
26
45
|
|
27
46
|
|
47
|
+
"""def feature_loss(fmap_r, fmap_g):
|
48
|
+
loss = 0
|
49
|
+
for dr, dg in zip(fmap_r, fmap_g):
|
50
|
+
for rl, gl in zip(dr, dg):
|
51
|
+
loss += torch.mean(torch.abs(rl - gl))
|
52
|
+
return loss * 2
|
53
|
+
|
54
|
+
|
28
55
|
def generator_adv_loss(fake_preds):
|
29
56
|
loss = 0.0
|
30
57
|
for f in fake_preds:
|
@@ -37,6 +64,7 @@ def discriminator_loss(real_preds, fake_preds):
|
|
37
64
|
for r, f in zip(real_preds, fake_preds):
|
38
65
|
loss += torch.mean((r - 1.0) ** 2) + torch.mean(f**2)
|
39
66
|
return loss
|
67
|
+
"""
|
40
68
|
|
41
69
|
|
42
70
|
class AudioSettings(ModelConfig):
|
@@ -284,9 +312,6 @@ class AudioDecoder(Model):
|
|
284
312
|
win_length=self.settings.n_fft,
|
285
313
|
# length=real_audio.shape[-1]
|
286
314
|
)[:, : real_audio.shape[-1]]
|
287
|
-
# smallest = min(real_audio.shape[-1], fake_audio.shape[-1])
|
288
|
-
# real_audio = real_audio[:, :, :smallest].squeeze(1)
|
289
|
-
# fake_audio = fake_audio[:, :smallest]
|
290
315
|
|
291
316
|
disc_kwargs = dict(
|
292
317
|
real_audio=real_audio,
|
@@ -372,13 +397,13 @@ class AudioDecoder(Model):
|
|
372
397
|
|
373
398
|
loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
|
374
399
|
loss_mel = (
|
375
|
-
F.
|
400
|
+
F.huber_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
|
376
401
|
)
|
377
402
|
loss_fm = ((loss_fm_mpd + loss_fm_msd) * fm_scale) + fm_add
|
378
403
|
|
379
404
|
loss_adv = (loss_adv_mpd + loss_adv_msd) * adv_scale
|
380
405
|
|
381
|
-
loss_g = loss_adv + loss_fm + loss_stft + loss_mel
|
406
|
+
loss_g = loss_adv + loss_fm + loss_stft # + loss_mel
|
382
407
|
if not am_i_frozen:
|
383
408
|
self.g_optim.zero_grad()
|
384
409
|
loss_g.backward()
|
@@ -0,0 +1,217 @@
|
|
1
|
+
# Public API of this module. ConvNets is included because other modules
# import it directly as the shared base class for conv blocks.
__all__ = [
    "spectral_norm_select",
    "get_weight_norm",
    "ConvNets",
    "ResBlock1D",
    "ResBlock2D",
    "ResBlock1DShuffled",
    "AdaResBlock1D",
]
|
9
|
+
import math
|
10
|
+
from lt_utils.common import *
|
11
|
+
from lt_tensor.torch_commons import *
|
12
|
+
from lt_tensor.model_base import Model
|
13
|
+
from lt_tensor.misc_utils import log_tensor
|
14
|
+
import torch.nn.functional as F
|
15
|
+
from lt_tensor.model_zoo.fusion import AdaFusion1D, AdaIN1D
|
16
|
+
|
17
|
+
|
18
|
+
def spectral_norm_select(module: nn.Module, enabled: bool):
    """Optionally wrap ``module`` with spectral normalization.

    Args:
        module: Layer to (maybe) normalize.
        enabled: When truthy, return ``spectral_norm(module)``; otherwise
            return ``module`` unchanged.
    """
    return spectral_norm(module) if enabled else module
|
22
|
+
|
23
|
+
|
24
|
+
def get_weight_norm(norm_type: Optional[Literal["weight", "spectral"]] = None):
    """Return a callable that applies the requested normalization to a module.

    Args:
        norm_type: ``"weight"`` -> ``weight_norm``; any other truthy value ->
            ``spectral_norm``; ``None``/empty -> identity.

    Returns:
        A one-argument callable taking and returning a module.
    """
    if norm_type == "weight":
        return lambda module: weight_norm(module)
    if norm_type:
        return lambda module: spectral_norm(module)
    return lambda module: module
|
30
|
+
|
31
|
+
|
32
|
+
class ConvNets(Model):
    """Base class for convolutional blocks, with weight-norm/init helpers."""

    def remove_weight_norm(self):
        """Strip weight normalization from every submodule that has it.

        ``remove_weight_norm`` raises ``ValueError`` for modules without a
        weight-norm hook; those are skipped on purpose.
        """
        for module in self.modules():
            try:
                remove_weight_norm(module)
            except ValueError:
                pass  # module was not weight-normalized

    @staticmethod
    def init_weights(m, mean=0.0, std=0.01):
        """Fill conv weights in place with samples from N(mean, std).

        Intended for ``nn.Module.apply``. Only modules whose class name
        contains ``"Conv"`` *and* that own a ``weight`` tensor are touched.
        Other modules (e.g. containers whose names contain "Conv") are
        ignored instead of raising ``AttributeError``.

        NOTE(review): on layers already wrapped by ``weight_norm``, ``weight``
        is recomputed from ``weight_g``/``weight_v`` before each forward.
        This assumes torch's hook- or parametrization-based weight_norm, so
        the in-place init would not persist; confirm intent.
        """
        if "Conv" not in type(m).__name__:
            return
        weight = getattr(m, "weight", None)
        if isinstance(weight, Tensor):
            weight.data.normal_(mean, std)
|
45
|
+
|
46
|
+
|
47
|
+
class ResBlock1D(ConvNets):
    """1D residual block made of three weight-normalized conv sub-layers.

    Each sub-layer is ``act -> dilated conv -> act -> conv`` and is added
    back to its input. With stride 1 and an odd ``kernel_size``, the padding
    keeps the time length unchanged.

    Args:
        channels: Input and output channel count.
        kernel_size: Convolution kernel size.
        dilation: Dilations for the first conv of each sub-layer; the first
            three entries are used.
        activation: Activation module reused at every activation position.
            Defaults to a new ``nn.LeakyReLU(0.1)`` per block. The old
            default was one instance shared by every block ever created.
    """

    def __init__(
        self,
        channels,
        kernel_size=3,
        dilation=(1, 3, 5),
        activation: Optional[nn.Module] = None,
    ):
        super().__init__()
        if activation is None:
            activation = nn.LeakyReLU(0.1)

        self.conv_nets = nn.ModuleList(
            [
                self._get_conv_layer(i, channels, kernel_size, 1, dilation, activation)
                for i in range(3)
            ]
        )
        self.conv_nets.apply(self.init_weights)
        self.last_index = len(self.conv_nets) - 1

    def _get_conv_layer(self, id, ch, k, stride, d, actv):
        """Build sub-layer ``id``: act -> conv(dilation=d[id]) -> act -> conv."""

        def same_padding(dil):
            # (k - 1) * dil / 2 keeps the length for odd k at stride 1
            return (k * dil - dil) // 2

        return nn.Sequential(
            actv,
            weight_norm(
                nn.Conv1d(
                    ch, ch, k, stride, dilation=d[id], padding=same_padding(d[id])
                )
            ),
            actv,
            weight_norm(nn.Conv1d(ch, ch, k, stride, dilation=1, padding=same_padding(1))),
        )

    def forward(self, x: Tensor):
        for layer in self.conv_nets:
            x = layer(x) + x
        return x
|
85
|
+
|
86
|
+
|
87
|
+
class ResBlock1DShuffled(ConvNets):
    """ResBlock1D variant with optional channel shuffling and a fixed skip.

    Each of the three sub-layers (``act -> dilated conv -> act -> conv``,
    weight-normalized) receives the optionally channel-shuffled running
    output. It then adds ``0.5 * block_input``, i.e. the *original* input,
    not the running output.

    Args:
        channels: Input and output channel count.
        kernel_size: Convolution kernel size.
        dilation: Dilations for the first conv of each sub-layer; the first
            three entries are used.
        activation: Activation module reused at every activation position.
            Defaults to a new ``nn.LeakyReLU(0.1)`` per block (previously one
            instance shared by every block).
        add_channel_shuffle: Insert ``nn.ChannelShuffle`` before each sub-layer
            (needs a PyTorch build that provides ``nn.ChannelShuffle``).
        channel_shuffle_groups: Group count for the channel shuffle.
    """

    def __init__(
        self,
        channels,
        kernel_size=3,
        dilation=(1, 3, 5),
        activation: Optional[nn.Module] = None,
        add_channel_shuffle: bool = False,
        channel_shuffle_groups=1,
    ):
        super().__init__()
        if activation is None:
            activation = nn.LeakyReLU(0.1)

        self.channel_shuffle = (
            nn.ChannelShuffle(channel_shuffle_groups)
            if add_channel_shuffle
            else nn.Identity()
        )

        self.conv_nets = nn.ModuleList(
            [
                self._get_conv_layer(i, channels, kernel_size, 1, dilation, activation)
                for i in range(3)
            ]
        )
        self.conv_nets.apply(self.init_weights)
        self.last_index = len(self.conv_nets) - 1

    def _get_conv_layer(self, id, ch, k, stride, d, actv):
        """Build sub-layer ``id``: act -> conv(dilation=d[id]) -> act -> conv."""

        def same_padding(dil):
            # (k - 1) * dil / 2 keeps the length for odd k at stride 1
            return (k * dil - dil) // 2

        return nn.Sequential(
            actv,
            weight_norm(
                nn.Conv1d(
                    ch, ch, k, stride, dilation=d[id], padding=same_padding(d[id])
                )
            ),
            actv,
            weight_norm(nn.Conv1d(ch, ch, k, stride, dilation=1, padding=same_padding(1))),
        )

    def forward(self, x: Tensor):
        # NOTE(review): every sub-layer adds the scaled *block input*, not the
        # running output as ResBlock1D does; confirm this is intended.
        skip = x * 0.5
        for layer in self.conv_nets:
            x = layer(self.channel_shuffle(x)) + skip
        return x
|
134
|
+
|
135
|
+
|
136
|
+
class ResBlock2D(Model):
    """2D residual block: ``conv3x3(stride) -> LeakyReLU(0.2) -> conv3x3``.

    The skip path is the identity, or a 1x1 (strided) projection when the
    resolution or channel count changes. The sum is scaled by ``1/sqrt(2)``.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        downsample: Use stride 2 in the first conv and the skip projection.
        use_spectral_norm: Wrap every convolution with spectral normalization.
            Previously the skip projection called ``spectral_norm_select``
            without its required ``enabled`` argument. That raised TypeError
            whenever ``downsample`` was set or the channel count changed.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        downsample=False,
        use_spectral_norm: bool = False,
    ):
        super().__init__()
        stride = 2 if downsample else 1

        self.block = nn.Sequential(
            spectral_norm_select(
                nn.Conv2d(in_channels, out_channels, 3, stride, 1), use_spectral_norm
            ),
            nn.LeakyReLU(0.2),
            spectral_norm_select(
                nn.Conv2d(out_channels, out_channels, 3, 1, 1), use_spectral_norm
            ),
        )

        self.skip = nn.Identity()
        if downsample or in_channels != out_channels:
            self.skip = spectral_norm_select(
                nn.Conv2d(in_channels, out_channels, 1, stride), use_spectral_norm
            )
        # precomputed once so forward does one less op per call
        self.sqrt_2 = math.sqrt(2)

    def forward(self, x: Tensor):
        return (self.block(x) + self.skip(x)) / self.sqrt_2
|
162
|
+
|
163
|
+
|
164
|
+
class AdaResBlock1D(ConvNets):
    """Conditioned 1D residual block built on ``AdaFusion1D``.

    Each of the three sub-layers computes
    ``act(norm1(x, y, alpha1)) -> conv1 (dilated) -> act(norm2(., y, alpha2)) -> conv2``
    and adds the result back to ``x``. Convolutions are weight-normalized.

    Args:
        res_block_channels: Channel count of ``x``.
        ada_channel_in: Channel count passed to ``AdaFusion1D`` as its first
            argument (presumably the size of the conditioning ``y``; verify
            against ``AdaFusion1D``).
        kernel_size: Convolution kernel size.
        dilation: Dilations for ``conv1`` of each sub-layer; the first three
            entries are used.
        activation: Activation module; defaults to a new ``nn.LeakyReLU(0.1)``
            per block (previously one instance shared by every block).
    """

    def __init__(
        self,
        res_block_channels: int,
        ada_channel_in: int,
        kernel_size=3,
        dilation=(1, 3, 5),
        activation: Optional[nn.Module] = None,
    ):
        super().__init__()

        self.conv_nets = nn.ModuleList(
            [
                self._get_conv_layer(
                    i,
                    res_block_channels,
                    ada_channel_in,
                    kernel_size,
                    1,
                    dilation,
                )
                for i in range(3)
            ]
        )
        self.conv_nets.apply(self.init_weights)
        self.last_index = len(self.conv_nets) - 1
        self.activation = nn.LeakyReLU(0.1) if activation is None else activation

    def _get_conv_layer(self, id, ch, ada_ch, k, stride, d):
        """Build sub-layer ``id``: two AdaFusion1D norms, two convs, two alphas.

        ``nn.ModuleDict`` only accepts modules; putting ``nn.Parameter``
        values in it raised TypeError, so the class could not be built. The
        alphas are registered as parameters on the dict module instead and
        are read as ``layer.alpha1`` / ``layer.alpha2``.
        """

        def same_padding(dil):
            # (k - 1) * dil / 2 keeps the length for odd k at stride 1
            return (k * dil - dil) // 2

        layer = nn.ModuleDict(
            dict(
                norm1=AdaFusion1D(ada_ch, ch),
                norm2=AdaFusion1D(ada_ch, ch),
                conv1=weight_norm(
                    nn.Conv1d(
                        ch, ch, k, stride, dilation=d[id], padding=same_padding(d[id])
                    )
                ),
                conv2=weight_norm(
                    nn.Conv1d(ch, ch, k, stride, dilation=1, padding=same_padding(1))
                ),
            )
        )
        # NOTE(review): alpha is sized by ada_ch, not ch; confirm this matches
        # what AdaFusion1D expects for its third argument.
        layer.register_parameter("alpha1", nn.Parameter(torch.ones(1, ada_ch, 1)))
        layer.register_parameter("alpha2", nn.Parameter(torch.ones(1, ada_ch, 1)))
        return layer

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        for layer in self.conv_nets:
            xt = self.activation(layer["norm1"](x, y, layer.alpha1))
            xt = layer["conv1"](xt)
            xt = self.activation(layer["norm2"](xt, y, layer.alpha2))
            x = layer["conv2"](xt) + x
        return x
|
@@ -11,8 +11,8 @@ from lt_tensor.torch_commons import *
|
|
11
11
|
from lt_tensor.model_base import Model
|
12
12
|
from lt_utils.misc_utils import default
|
13
13
|
from typing import Optional
|
14
|
-
from lt_tensor.model_zoo.
|
15
|
-
from lt_tensor.model_zoo.
|
14
|
+
from lt_tensor.model_zoo.pos_encoder import *
|
15
|
+
from lt_tensor.model_zoo.basic import FeedForward
|
16
16
|
|
17
17
|
|
18
18
|
def init_weights(module):
|