xinference-0.13.2-py3-none-any.whl → xinference-0.13.3-py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Note: this release of xinference has been flagged as potentially problematic.
- xinference/__init__.py +0 -1
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +26 -4
- xinference/client/restful/restful_client.py +16 -1
- xinference/core/chat_interface.py +2 -2
- xinference/core/model.py +8 -3
- xinference/core/scheduler.py +4 -4
- xinference/model/audio/core.py +5 -2
- xinference/model/audio/cosyvoice.py +136 -0
- xinference/model/audio/model_spec.json +24 -0
- xinference/model/audio/model_spec_modelscope.json +27 -0
- xinference/model/flexible/launchers/__init__.py +1 -0
- xinference/model/flexible/launchers/image_process_launcher.py +70 -0
- xinference/model/image/model_spec.json +7 -0
- xinference/model/image/stable_diffusion/core.py +6 -1
- xinference/model/llm/llm_family.json +802 -82
- xinference/model/llm/llm_family_csghub.json +39 -0
- xinference/model/llm/llm_family_modelscope.json +295 -47
- xinference/model/llm/pytorch/chatglm.py +243 -5
- xinference/model/llm/pytorch/cogvlm2.py +1 -1
- xinference/model/llm/utils.py +78 -1
- xinference/model/llm/vllm/core.py +8 -0
- xinference/thirdparty/cosyvoice/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
- xinference/thirdparty/cosyvoice/bin/train.py +136 -0
- xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
- xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
- xinference/thirdparty/cosyvoice/cli/model.py +60 -0
- xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
- xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
- xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
- xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
- xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
- xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
- xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
- xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
- xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
- xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
- xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
- xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
- xinference/thirdparty/cosyvoice/utils/common.py +103 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
- xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/METADATA +16 -8
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/RECORD +76 -32
- xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
- /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
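The headline change in this release is the vendored CosyVoice text-to-speech stack (everything under xinference/thirdparty/cosyvoice/) together with new entries in the audio model specs. As a rough sketch of how the new model type might be exercised once 0.13.3 is installed (the model name and the speech() call follow xinference's general client API, but are assumptions not shown directly in this diff):

```python
# Hedged usage sketch, not part of the diff: assumes a local xinference
# server on the default port, and that the new audio model_spec.json
# entries register CosyVoice under a name like "CosyVoice-300M-SFT".
from xinference.client import Client

client = Client("http://localhost:9997")
model_uid = client.launch_model(model_name="CosyVoice-300M-SFT",
                                model_type="audio")
model = client.get_model(model_uid)
# speech() mirrors the OpenAI audio endpoint and returns raw audio bytes
audio = model.speech("Hello from the newly vendored CosyVoice model.")
with open("hello.mp3", "wb") as f:
    f.write(audio)
```

The three largest vendored files are reproduced below. Each hunk creates a brand-new file, so every line is an addition.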
xinference/thirdparty/cosyvoice/transformer/encoder_layer.py (new file; the +236 entry in the list above)

@@ -0,0 +1,236 @@

```python
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder self-attention layer definition."""

from typing import Optional, Tuple

import torch
from torch import nn


class TransformerEncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: to use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: torch.nn.Module,
        dropout_rate: float,
        normalize_before: bool = True,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, eps=1e-5)
        self.norm2 = nn.LayerNorm(size, eps=1e-5)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): just for interface compatibility
                to ConformerEncoderLayer
            mask_pad (torch.Tensor): does not used in transformer layer,
                just for unified api with conformer.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2), not used here, it's for interface
                compatibility to ConformerEncoderLayer.
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2).

        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        return x, mask, new_att_cache, fake_cnn_cache


class ConformerEncoderLayer(nn.Module):
    """Encoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
            instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvlutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, eps=1e-5)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module.
                (#batch, 1, time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
        """

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
                                              att_cache)
        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)

            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)

        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache
```
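Both layer classes follow the wenet/ESPnet calling convention: forward(x, mask, pos_emb, mask_pad, att_cache, cnn_cache) returns (output, mask, new_att_cache, new_cnn_cache), so transformer and conformer layers are interchangeable inside an encoder stack, with the transformer variant returning a fake CNN cache. A minimal smoke test of the plain transformer layer, assuming the sibling modules added by this same release keep their wenet-style constructors (the MultiHeadedAttention(n_head, n_feat, dropout_rate) signature is an assumption, not shown in this hunk):

```python
# Hedged smoke test; shapes follow the docstrings above, while the
# MultiHeadedAttention signature is assumed from this code's wenet lineage.
import torch
from xinference.thirdparty.cosyvoice.transformer.attention import MultiHeadedAttention
from xinference.thirdparty.cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
from xinference.thirdparty.cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward

size, heads = 256, 4
layer = TransformerEncoderLayer(
    size,
    MultiHeadedAttention(heads, size, 0.1),    # assumed (n_head, n_feat, dropout)
    PositionwiseFeedForward(size, 1024, 0.1),
    dropout_rate=0.1,
)
x = torch.randn(2, 50, size)                    # (#batch, time, size)
mask = torch.ones(2, 50, 50, dtype=torch.bool)  # full self-attention mask
pos_emb = torch.zeros(0)                        # unused by the transformer layer
out, out_mask, att_cache, cnn_cache = layer(x, mask, pos_emb)
assert out.shape == (2, 50, size)
```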
xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py (new file; the +96 entry in the list above)

@@ -0,0 +1,96 @@

```python
# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Label smoothing module."""

import torch
from torch import nn


class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.

    In a standard CE loss, the label's data distribution is:
    [0,1,2] ->
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
    ]

    In the smoothing version CE Loss,some probabilities
    are taken from the true label prob (1.0) and are divided
    among other labels.

    e.g.
    smoothing=0.1
    [0,1,2] ->
    [
        [0.9, 0.05, 0.05],
        [0.05, 0.9, 0.05],
        [0.05, 0.05, 0.9],
    ]

    Args:
        size (int): the number of class
        padding_idx (int): padding class id which will be ignored for loss
        smoothing (float): smoothing rate (0.0 means the conventional CE)
        normalize_length (bool):
            normalize loss by sequence length if True
            normalize loss by batch size if False
    """

    def __init__(self,
                 size: int,
                 padding_idx: int,
                 smoothing: float,
                 normalize_length: bool = False):
        """Construct an LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="none")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.normalize_length = normalize_length

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute loss between x and target.

        The model outputs and data labels tensors are flatten to
        (batch*seqlen, class) shape and a mask is applied to the
        padding part which should not be calculated for loss.

        Args:
            x (torch.Tensor): prediction (batch, seqlen, class)
            target (torch.Tensor):
                target signal masked with self.padding_id (batch, seqlen)
        Returns:
            loss (torch.Tensor) : The KL loss, scalar float value
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        # use zeros_like instead of torch.no_grad() for true_dist,
        # since no_grad() can not be exported by JIT
        true_dist = torch.zeros_like(x)
        true_dist.fill_(self.smoothing / (self.size - 1))
        ignore = target == self.padding_idx  # (B,)
        total = len(target) - ignore.sum().item()
        target = target.masked_fill(ignore, 0)  # avoid -1 index
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
```
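The smoothing mass is spread uniformly over the size - 1 non-target classes (fill value self.smoothing / (self.size - 1), which gives 0.1 / 2 = 0.05 per class in the three-class docstring example), and positions equal to padding_idx are zeroed out of the KL sum. A quick self-contained check against the vendored module (vocabulary size and padding id below are made-up values for illustration):

```python
# Hedged example; only the import path and the constructor shown above
# are taken from the diff, the numbers are illustrative.
import torch
from xinference.thirdparty.cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss

vocab, pad_id = 10, -1
criterion = LabelSmoothingLoss(size=vocab, padding_idx=pad_id, smoothing=0.1)
logits = torch.randn(2, 5, vocab)         # (batch, seqlen, class)
target = torch.randint(0, vocab, (2, 5))  # (batch, seqlen)
target[1, 3:] = pad_id                    # padded tail is ignored by the loss
loss = criterion(logits, target)          # scalar, normalized by batch size
print(float(loss))
```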
xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py (new file; the +115 entry in the list above)

@@ -0,0 +1,115 @@

```python
# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positionwise feed forward layer definition."""

import torch


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    FeedForward are appied on each position of the sequence.
    The output dim is same with the input dim.

    Args:
        idim (int): Input dimenstion.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))


class MoEFFNLayer(torch.nn.Module):
    """
    Mixture of expert with Positionwise feed forward layer
    See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
    The output dim is same with the input dim.

    Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
                  https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
    Args:
        n_expert: number of expert.
        n_expert_per_token: The actual number of experts used for each frame
        idim (int): Input dimenstion.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function
    """

    def __init__(
        self,
        n_expert: int,
        n_expert_per_token: int,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        super(MoEFFNLayer, self).__init__()
        self.gate = torch.nn.Linear(idim, n_expert, bias=False)
        self.experts = torch.nn.ModuleList(
            PositionwiseFeedForward(idim, hidden_units, dropout_rate,
                                    activation) for _ in range(n_expert))
        self.n_expert_per_token = n_expert_per_token

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Foward function.
        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)

        """
        B, L, D = xs.size(
        )  # batch size, sequence length, embedding dimension (idim)
        xs = xs.view(-1, D)  # (B*L, D)
        router = self.gate(xs)  # (B*L, n_expert)
        logits, indices = torch.topk(
            router, self.n_expert_per_token
        )  # probs:(B*L, n_expert), indices: (B*L, n_expert)
        weights = torch.nn.functional.softmax(
            logits, dim=1,
            dtype=torch.float).to(dtype=xs.dtype)  # (B*L, n_expert_per_token)
        output = torch.zeros_like(xs)  # (B*L, D)
        for i, expert in enumerate(self.experts):
            mask = indices == i
            batch_idx, ith_expert = torch.where(mask)
            output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
                xs[batch_idx])
        return output.view(B, L, D)
```
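The routing logic in MoEFFNLayer: the linear gate scores all n_expert experts per frame, torch.topk keeps the n_expert_per_token best, softmax renormalizes only those kept logits, and each expert then processes just the frames routed to it (note the inline comment overstates the topk shapes; both logits and indices are actually (B*L, n_expert_per_token)). A small illustrative run with made-up dimensions:

```python
# Hedged example; expert counts and dimensions are illustrative only,
# not values CosyVoice itself uses.
import torch
from xinference.thirdparty.cosyvoice.transformer.positionwise_feed_forward import MoEFFNLayer

moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2,
                  idim=64, hidden_units=256, dropout_rate=0.0)
xs = torch.randn(3, 7, 64)  # (B, L, D)
ys = moe(xs)                # each frame mixes the outputs of its top-2 experts
assert ys.shape == xs.shape
```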