xinference 0.13.1__py3-none-any.whl → 0.13.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (82)
  1. xinference/__init__.py +0 -1
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +99 -5
  4. xinference/client/restful/restful_client.py +98 -1
  5. xinference/core/chat_interface.py +2 -2
  6. xinference/core/model.py +85 -26
  7. xinference/core/scheduler.py +4 -4
  8. xinference/model/audio/chattts.py +40 -8
  9. xinference/model/audio/core.py +5 -2
  10. xinference/model/audio/cosyvoice.py +136 -0
  11. xinference/model/audio/model_spec.json +24 -0
  12. xinference/model/audio/model_spec_modelscope.json +27 -0
  13. xinference/model/flexible/launchers/__init__.py +1 -0
  14. xinference/model/flexible/launchers/image_process_launcher.py +70 -0
  15. xinference/model/image/core.py +3 -0
  16. xinference/model/image/model_spec.json +21 -0
  17. xinference/model/image/stable_diffusion/core.py +49 -7
  18. xinference/model/llm/llm_family.json +1065 -106
  19. xinference/model/llm/llm_family.py +26 -6
  20. xinference/model/llm/llm_family_csghub.json +39 -0
  21. xinference/model/llm/llm_family_modelscope.json +460 -47
  22. xinference/model/llm/pytorch/chatglm.py +243 -5
  23. xinference/model/llm/pytorch/cogvlm2.py +1 -1
  24. xinference/model/llm/sglang/core.py +7 -2
  25. xinference/model/llm/utils.py +78 -1
  26. xinference/model/llm/vllm/core.py +11 -0
  27. xinference/thirdparty/cosyvoice/__init__.py +0 -0
  28. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  29. xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
  30. xinference/thirdparty/cosyvoice/bin/train.py +136 -0
  31. xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
  32. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
  33. xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
  34. xinference/thirdparty/cosyvoice/cli/model.py +60 -0
  35. xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
  36. xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
  37. xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
  38. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  39. xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
  40. xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
  41. xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
  42. xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
  43. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  44. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
  45. xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
  46. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  47. xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
  48. xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
  49. xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
  50. xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
  51. xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
  52. xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
  53. xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
  54. xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
  55. xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
  56. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
  57. xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
  58. xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  59. xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
  60. xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
  61. xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
  62. xinference/thirdparty/cosyvoice/utils/common.py +103 -0
  63. xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
  64. xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
  65. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
  66. xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
  67. xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
  68. xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
  69. xinference/web/ui/build/asset-manifest.json +3 -3
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
  72. xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
  74. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/METADATA +18 -8
  75. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/RECORD +80 -36
  76. xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
  78. /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
  79. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
  80. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
  81. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
  82. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
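
The headline addition in this release is CosyVoice text-to-speech support: a new xinference/model/audio/cosyvoice.py model class, new audio model-spec entries, and the vendored xinference/thirdparty/cosyvoice package, wired through the RESTful API and client (files 3, 4, 8-12 above). As orientation only, here is a minimal sketch of driving the new model through the Python client; the model name, the voice string, and the audio handle's speech() method are assumptions inferred from the file list, not verified against this exact release.

from xinference.client import Client

# Hypothetical usage sketch (model name, voice, and speech() signature are
# assumptions): launch the newly added CosyVoice audio model and synthesize
# speech through the RESTful client.
client = Client("http://localhost:9997")
model_uid = client.launch_model(
    model_name="CosyVoice-300M-SFT",  # illustrative; see model_spec.json above
    model_type="audio",
)
model = client.get_model(model_uid)
audio_bytes = model.speech("Hello from CosyVoice.", voice="中文女")
with open("output.mp3", "wb") as f:
    f.write(audio_bytes)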
xinference/thirdparty/cosyvoice/transformer/attention.py (added)
@@ -0,0 +1,326 @@
# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""

import math
from typing import Tuple

import torch
from torch import nn


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct a MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
        #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
        #      1st chunk to ease the onnx export.]
        #   2. pytorch training
        if mask.size(2) > 0:  # time2 > 0
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(mask, -float('inf'))
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0)  # (batch, head, time1, time2)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
        #   1. onnx(16/-1, -1/-1, 16/0)
        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
                                                 self.h * self.d_k)
             )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
                1. When applying cross attention between decoder and encoder,
                   the batch padding mask for input is in (#batch, 1, T) shape.
                2. When applying self attention of encoder,
                   the mask is in (#batch, T, T) shape.
                3. When applying self attention of decoder,
                   the mask is in (#batch, L, L) shape.
                4. If the different position in decoder see different block
                   of the encoder, such as Mocha, the passed in mask could be
                   in (#batch, L, T) shape. But there is no such case in current
                   CosyVoice.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        """
        q, k, v = self.forward_qkv(query, key, value)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will always be `True`
        #       and we will always do splitting and
        #       concatenation (this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors (see code below).
        #   when export jit model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct a RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor.

        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2
        return x

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will always be `True`
        #       and we will always do splitting and
        #       concatenation (this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors (see code below).
        #   when export jit model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
        if matrix_ac.shape != matrix_bd.shape:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache
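
The cache protocol documented in forward() is easiest to see in a minimal smoke test. This is a sketch under two assumptions: the class is importable from the vendored path shipped in this wheel, and dropout is zero so the calls are deterministic.

import torch
from xinference.thirdparty.cosyvoice.transformer.attention import (
    MultiHeadedAttention,
)

# 4 heads x d_k=64 == n_feat=256; default fake mask/cache on the first call.
mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0).eval()

x1 = torch.randn(1, 10, 256)        # first chunk: (batch, time1, size)
out1, cache = mha(x1, x1, x1)
print(cache.shape)                  # torch.Size([1, 4, 10, 128]): keys and
                                    # values packed along the last dim (d_k * 2)

x2 = torch.randn(1, 5, 256)         # next chunk attends over the cached k/v too
out2, cache = mha(x2, x2, x2, cache=cache)
print(out2.shape, cache.shape)      # (1, 5, 256) and (1, 4, 15, 128)

Because keys and values are concatenated along dim=2 and re-packed along dim=-1, cache_t grows by time1 on every call while the last dimension stays d_k * 2; trimming the cache back to chunk_size * num_decoding_left_chunks is deliberately left to encoder.forward_chunk, as the NOTE in forward() says. RelPositionMultiHeadedAttention follows the same cache protocol but additionally consumes pos_emb, an ESPnet-style relative embedding of length 2*time-1 that rel_shift folds down to (batch, head, time1, time2) when the shapes differ.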
xinference/thirdparty/cosyvoice/transformer/convolution.py (added)
@@ -0,0 +1,145 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""ConvolutionModule definition."""

from typing import Tuple

import torch
from torch import nn


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model."""

    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 norm: str = "batch_norm",
                 causal: bool = False,
                 bias: bool = True):
        """Construct a ConvolutionModule object.
        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of conv layers.
            causal (bool): Whether to use causal convolution or not.
        """
        super().__init__()

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder is used to distinguish if it's a causal convolution,
        # if self.lorder > 0: it's a causal convolution, the input will be
        #    padded with self.lorder frames on the left in forward.
        # else: it's a symmetrical convolution
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for non-causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )

        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) means fake cache.
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
            torch.Tensor: New cache tensor (#batch, channels, cache_t).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            new_cache = x[:, :, -self.lorder:]
        else:
            # It's better we just return None if no cache is required,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
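
The left-context cache makes the causal branch chunk-equivalent: running an utterance in one pass or in chunks, carrying kernel_size - 1 cached frames between calls, produces the same output. A minimal sketch, assuming the vendored import path and using eval() so that BatchNorm1d applies its running statistics position-wise:

import torch
from xinference.thirdparty.cosyvoice.transformer.convolution import (
    ConvolutionModule,
)

conv = ConvolutionModule(channels=64, kernel_size=15, causal=True).eval()

x = torch.randn(1, 20, 64)        # (batch, time, channels)
y_full, _ = conv(x)               # whole utterance in one pass

# Streaming: the returned cache holds the last lorder == 14 frames of the
# (padded) input, which the next call prepends in place of the zero padding.
y1, cache = conv(x[:, :10])       # cache: (1, 64, 14)
y2, _ = conv(x[:, 10:], cache=cache)
print(torch.allclose(y_full, torch.cat([y1, y2], dim=1), atol=1e-6))  # True

Note that this equivalence holds in eval mode (or with norm="layer_norm"): in training mode BatchNorm1d computes statistics over whatever time span it is given, so full-utterance and chunked passes would normalize differently.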