xinference 0.13.1__py3-none-any.whl → 0.13.3__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of xinference might be problematic.
- xinference/__init__.py +0 -1
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +99 -5
- xinference/client/restful/restful_client.py +98 -1
- xinference/core/chat_interface.py +2 -2
- xinference/core/model.py +85 -26
- xinference/core/scheduler.py +4 -4
- xinference/model/audio/chattts.py +40 -8
- xinference/model/audio/core.py +5 -2
- xinference/model/audio/cosyvoice.py +136 -0
- xinference/model/audio/model_spec.json +24 -0
- xinference/model/audio/model_spec_modelscope.json +27 -0
- xinference/model/flexible/launchers/__init__.py +1 -0
- xinference/model/flexible/launchers/image_process_launcher.py +70 -0
- xinference/model/image/core.py +3 -0
- xinference/model/image/model_spec.json +21 -0
- xinference/model/image/stable_diffusion/core.py +49 -7
- xinference/model/llm/llm_family.json +1065 -106
- xinference/model/llm/llm_family.py +26 -6
- xinference/model/llm/llm_family_csghub.json +39 -0
- xinference/model/llm/llm_family_modelscope.json +460 -47
- xinference/model/llm/pytorch/chatglm.py +243 -5
- xinference/model/llm/pytorch/cogvlm2.py +1 -1
- xinference/model/llm/sglang/core.py +7 -2
- xinference/model/llm/utils.py +78 -1
- xinference/model/llm/vllm/core.py +11 -0
- xinference/thirdparty/cosyvoice/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
- xinference/thirdparty/cosyvoice/bin/train.py +136 -0
- xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
- xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
- xinference/thirdparty/cosyvoice/cli/model.py +60 -0
- xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
- xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
- xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
- xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
- xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
- xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
- xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
- xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
- xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
- xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
- xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
- xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
- xinference/thirdparty/cosyvoice/utils/common.py +103 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
- xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/METADATA +18 -8
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/RECORD +80 -36
- xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
- /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/transformer/attention.py (new file)
@@ -0,0 +1,326 @@
# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""

import math
from typing import Tuple

import torch
from torch import nn


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
        #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
        #           1st chunk to ease the onnx export.]
        #   2. pytorch training
        if mask.size(2) > 0:  # time2 > 0
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(mask, -float('inf'))
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0)  # (batch, head, time1, time2)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
        #   1. onnx(16/-1, -1/-1, 16/0)
        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
                                                 self.h * self.d_k)
             )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
                1.When applying cross attention between decoder and encoder,
                the batch padding mask for input is in (#batch, 1, T) shape.
                2.When applying self attention of encoder,
                the mask is in (#batch, T, T) shape.
                3.When applying self attention of decoder,
                the mask is in (#batch, L, L) shape.
                4.If the different position in decoder see different block
                of the encoder, such as Mocha, the passed in mask could be
                in (#batch, L, T) shape. But there is no such case in current
                CosyVoice.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        """
        q, k, v = self.forward_qkv(query, key, value)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
        #       and we will always do splitting and
        #       concatnation(this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors(see code below).
        #   when export jit model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)  # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable bias are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
            time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor.

        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2
        return x

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
        #       and we will always do splitting and
        #       concatnation(this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors(see code below).
        #   when export jit model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)  # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
        if matrix_ac.shape != matrix_bd.shape:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache
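For orientation, a minimal usage sketch of the vendored MultiHeadedAttention layer above (illustrative only, not taken from the package): the import path mirrors the file location in the diff, and the batch size, sequence length, head count, and feature size are assumptions chosen for the example.

import torch

from xinference.thirdparty.cosyvoice.transformer.attention import MultiHeadedAttention

# Illustrative sizes: 4 heads over a 256-dim feature space.
mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
mha.eval()

x = torch.randn(2, 50, 256)                    # (batch, time, n_feat)
mask = torch.ones(2, 1, 50, dtype=torch.bool)  # (batch, 1, time2) padding mask, all valid

# Self-attention: query, key and value are the same tensor.
out, new_cache = mha(x, x, x, mask=mask)
print(out.shape)        # torch.Size([2, 50, 256])
print(new_cache.shape)  # torch.Size([2, 4, 50, 128]); k and v concatenated along the last dim

RelPositionMultiHeadedAttention exposes the same interface but additionally consumes pos_emb, scoring queries as (q + u)k^T + (q + v)p^T scaled by sqrt(d_k), per the Transformer-XL formulation referenced in the comments above.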
xinference/thirdparty/cosyvoice/transformer/convolution.py (new file)
@@ -0,0 +1,145 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""ConvolutionModule definition."""

from typing import Tuple

import torch
from torch import nn


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model."""

    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 norm: str = "batch_norm",
                 causal: bool = False,
                 bias: bool = True):
        """Construct an ConvolutionModule object.
        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of conv layers.
            causal (int): Whether use causal convolution or not
        """
        super().__init__()

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder is used to distinguish if it's a causal convolution,
        # if self.lorder > 0: it's a causal convolution, the input will be
        #    padded with self.lorder frames on the left in forward.
        # else: it's a symmetrical convolution
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for none causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )

        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) meas fake cache.
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            new_cache = x[:, :, -self.lorder:]
        else:
            # It's better we just return None if no cache is required,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
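Likewise, a small illustrative sketch (not taken from the package) of driving the ConvolutionModule above in its causal configuration, where the returned new_cache holds the last kernel_size - 1 frames as left context for the next streaming chunk; channel count, batch size, and chunk length are assumptions chosen for the example.

import torch

from xinference.thirdparty.cosyvoice.transformer.convolution import ConvolutionModule

# Causal configuration so a left-context cache is produced for streaming.
conv = ConvolutionModule(channels=256, kernel_size=15, causal=True)
conv.eval()

x = torch.randn(2, 50, 256)                        # (batch, time, channels)
mask_pad = torch.ones(2, 1, 50, dtype=torch.bool)  # (batch, 1, time) padding mask, all valid

out, new_cache = conv(x, mask_pad=mask_pad)
print(out.shape)        # torch.Size([2, 50, 256])
print(new_cache.shape)  # torch.Size([2, 256, 14]); kernel_size - 1 = 14 cached frames

# The cache is fed back in on the next chunk to preserve left context.
out2, new_cache2 = conv(torch.randn(2, 50, 256), mask_pad=mask_pad, cache=new_cache)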