xinference 0.13.1__py3-none-any.whl → 0.13.3__py3-none-any.whl

This diff shows the content of publicly available package versions as released to their public registries. It is provided for informational purposes only and reflects the changes between the two versions as published.

Potentially problematic release: this version of xinference might be problematic.

Files changed (82)
  1. xinference/__init__.py +0 -1
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +99 -5
  4. xinference/client/restful/restful_client.py +98 -1
  5. xinference/core/chat_interface.py +2 -2
  6. xinference/core/model.py +85 -26
  7. xinference/core/scheduler.py +4 -4
  8. xinference/model/audio/chattts.py +40 -8
  9. xinference/model/audio/core.py +5 -2
  10. xinference/model/audio/cosyvoice.py +136 -0
  11. xinference/model/audio/model_spec.json +24 -0
  12. xinference/model/audio/model_spec_modelscope.json +27 -0
  13. xinference/model/flexible/launchers/__init__.py +1 -0
  14. xinference/model/flexible/launchers/image_process_launcher.py +70 -0
  15. xinference/model/image/core.py +3 -0
  16. xinference/model/image/model_spec.json +21 -0
  17. xinference/model/image/stable_diffusion/core.py +49 -7
  18. xinference/model/llm/llm_family.json +1065 -106
  19. xinference/model/llm/llm_family.py +26 -6
  20. xinference/model/llm/llm_family_csghub.json +39 -0
  21. xinference/model/llm/llm_family_modelscope.json +460 -47
  22. xinference/model/llm/pytorch/chatglm.py +243 -5
  23. xinference/model/llm/pytorch/cogvlm2.py +1 -1
  24. xinference/model/llm/sglang/core.py +7 -2
  25. xinference/model/llm/utils.py +78 -1
  26. xinference/model/llm/vllm/core.py +11 -0
  27. xinference/thirdparty/cosyvoice/__init__.py +0 -0
  28. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  29. xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
  30. xinference/thirdparty/cosyvoice/bin/train.py +136 -0
  31. xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
  32. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
  33. xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
  34. xinference/thirdparty/cosyvoice/cli/model.py +60 -0
  35. xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
  36. xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
  37. xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
  38. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  39. xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
  40. xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
  41. xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
  42. xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
  43. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  44. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
  45. xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
  46. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  47. xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
  48. xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
  49. xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
  50. xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
  51. xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
  52. xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
  53. xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
  54. xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
  55. xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
  56. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
  57. xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
  58. xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  59. xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
  60. xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
  61. xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
  62. xinference/thirdparty/cosyvoice/utils/common.py +103 -0
  63. xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
  64. xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
  65. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
  66. xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
  67. xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
  68. xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
  69. xinference/web/ui/build/asset-manifest.json +3 -3
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
  72. xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
  74. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/METADATA +18 -8
  75. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/RECORD +80 -36
  76. xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
  78. /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
  79. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
  80. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
  81. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
  82. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
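
The headline change is the new CosyVoice text-to-speech stack: a vendored xinference/thirdparty/cosyvoice package plus new audio model specs, alongside RESTful API and client additions. As a rough sketch of how the new audio family might be driven through the existing client (the model name is an assumption taken from the new audio specs; exact names and defaults may differ):

# hedged sketch, not taken from this diff; uses the existing client API
from xinference.client import Client

client = Client("http://localhost:9997")
# "CosyVoice-300M-SFT" is an assumed name from the new audio model specs
uid = client.launch_model(model_name="CosyVoice-300M-SFT", model_type="audio")
model = client.get_model(uid)
audio = model.speech("Hello from CosyVoice.")  # raw audio bytes
with open("output.wav", "wb") as f:
    f.write(audio)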
xinference/thirdparty/cosyvoice/hifigan/generator.py
@@ -0,0 +1,391 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""HIFI-GAN"""
+
+import typing as tp
+import numpy as np
+from scipy.signal import get_window
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d
+from torch.nn import ConvTranspose1d
+from torch.nn.utils import remove_weight_norm
+from torch.nn.utils import weight_norm
+from torch.distributions.uniform import Uniform
+
+from cosyvoice.transformer.activation import Snake
+from cosyvoice.utils.common import get_padding
+from cosyvoice.utils.common import init_weights
+
+
+"""hifigan based generator implementation.
+
+This code is modified from https://github.com/jik876/hifi-gan,
+https://github.com/kan-bayashi/ParallelWaveGAN and
+https://github.com/NVIDIA/BigVGAN
+"""
+
+
+class ResBlock(torch.nn.Module):
+    """Residual block module in HiFiGAN/BigVGAN."""
+
+    def __init__(
+        self,
+        channels: int = 512,
+        kernel_size: int = 3,
+        dilations: tp.List[int] = [1, 3, 5],
+    ):
+        super(ResBlock, self).__init__()
+        self.convs1 = nn.ModuleList()
+        self.convs2 = nn.ModuleList()
+
+        for dilation in dilations:
+            self.convs1.append(
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation,
+                        padding=get_padding(kernel_size, dilation)
+                    )
+                )
+            )
+            self.convs2.append(
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1)
+                    )
+                )
+            )
+        self.convs1.apply(init_weights)
+        self.convs2.apply(init_weights)
+        self.activations1 = nn.ModuleList([
+            Snake(channels, alpha_logscale=False)
+            for _ in range(len(self.convs1))
+        ])
+        self.activations2 = nn.ModuleList([
+            Snake(channels, alpha_logscale=False)
+            for _ in range(len(self.convs2))
+        ])
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for idx in range(len(self.convs1)):
+            xt = self.activations1[idx](x)
+            xt = self.convs1[idx](xt)
+            xt = self.activations2[idx](xt)
+            xt = self.convs2[idx](xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for idx in range(len(self.convs1)):
+            remove_weight_norm(self.convs1[idx])
+            remove_weight_norm(self.convs2[idx])
+
+
+class SineGen(torch.nn.Module):
+    """ Definition of sine generator
+    SineGen(samp_rate, harmonic_num = 0,
+            sine_amp = 0.1, noise_std = 0.003,
+            voiced_threshold = 0,
+            flag_for_pulse=False)
+    samp_rate: sampling rate in Hz
+    harmonic_num: number of harmonic overtones (default 0)
+    sine_amp: amplitude of sine waveform (default 0.1)
+    noise_std: std of Gaussian noise (default 0.003)
+    voiced_threshold: F0 threshold for U/V classification (default 0)
+    flag_for_pulse: this SineGen is used inside PulseGen (default False)
+    Note: when flag_for_pulse is True, the first time step of a voiced
+    segment is always sin(np.pi) or cos(0)
+    """
+
+    def __init__(self, samp_rate, harmonic_num=0,
+                 sine_amp=0.1, noise_std=0.003,
+                 voiced_threshold=0):
+        super(SineGen, self).__init__()
+        self.sine_amp = sine_amp
+        self.noise_std = noise_std
+        self.harmonic_num = harmonic_num
+        self.sampling_rate = samp_rate
+        self.voiced_threshold = voiced_threshold
+
+    def _f02uv(self, f0):
+        # generate uv signal
+        uv = (f0 > self.voiced_threshold).type(torch.float32)
+        return uv
+
+    @torch.no_grad()
+    def forward(self, f0):
+        """
+        :param f0: [B, 1, sample_len], Hz
+        :return: [B, 1, sample_len]
+        """
+
+        F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
+        for i in range(self.harmonic_num + 1):
+            F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
+
+        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
+        u_dist = Uniform(low=-np.pi, high=np.pi)
+        phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
+        phase_vec[:, 0, :] = 0
+
+        # generate sine waveforms
+        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
+
+        # generate uv signal
+        uv = self._f02uv(f0)
+
+        # noise: for unvoiced should be similar to sine_amp
+        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
+        #        for voiced regions is self.noise_std
+        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+        noise = noise_amp * torch.randn_like(sine_waves)
+
+        # first: set the unvoiced part to 0 by uv
+        # then: additive noise
+        sine_waves = sine_waves * uv + noise
+        return sine_waves, uv, noise
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+    """ SourceModule for hn-nsf
+    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+                 add_noise_std=0.003, voiced_threshod=0)
+    sampling_rate: sampling rate in Hz
+    harmonic_num: number of harmonics above F0 (default: 0)
+    sine_amp: amplitude of sine source signal (default: 0.1)
+    add_noise_std: std of additive Gaussian noise (default: 0.003)
+        note that amplitude of noise in unvoiced is decided
+        by sine_amp
+    voiced_threshold: threshold to set U/V given F0 (default: 0)
+    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+    F0_sampled (batchsize, length, 1)
+    Sine_source (batchsize, length, 1)
+    noise_source (batchsize, length, 1)
+    uv (batchsize, length, 1)
+    """
+
+    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+                 add_noise_std=0.003, voiced_threshod=0):
+        super(SourceModuleHnNSF, self).__init__()
+
+        self.sine_amp = sine_amp
+        self.noise_std = add_noise_std
+
+        # to produce sine waveforms
+        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
+                                 sine_amp, add_noise_std, voiced_threshod)
+
+        # to merge source harmonics into a single excitation
+        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+        self.l_tanh = torch.nn.Tanh()
+
+    def forward(self, x):
+        """
+        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+        F0_sampled (batchsize, length, 1)
+        Sine_source (batchsize, length, 1)
+        noise_source (batchsize, length, 1)
+        """
+        # source for harmonic branch
+        with torch.no_grad():
+            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
+            sine_wavs = sine_wavs.transpose(1, 2)
+            uv = uv.transpose(1, 2)
+        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+        # source for noise branch, in the same shape as uv
+        noise = torch.randn_like(uv) * self.sine_amp / 3
+        return sine_merge, noise, uv
+
+
+class HiFTGenerator(nn.Module):
+    """
+    HiFTNet Generator: Neural Source Filter + ISTFTNet
+    https://arxiv.org/abs/2309.09493
+    """
+    def __init__(
+        self,
+        in_channels: int = 80,
+        base_channels: int = 512,
+        nb_harmonics: int = 8,
+        sampling_rate: int = 22050,
+        nsf_alpha: float = 0.1,
+        nsf_sigma: float = 0.003,
+        nsf_voiced_threshold: float = 10,
+        upsample_rates: tp.List[int] = [8, 8],
+        upsample_kernel_sizes: tp.List[int] = [16, 16],
+        istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
+        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
+        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        source_resblock_kernel_sizes: tp.List[int] = [7, 11],
+        source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
+        lrelu_slope: float = 0.1,
+        audio_limit: float = 0.99,
+        f0_predictor: torch.nn.Module = None,
+    ):
+        super(HiFTGenerator, self).__init__()
+
+        self.out_channels = 1
+        self.nb_harmonics = nb_harmonics
+        self.sampling_rate = sampling_rate
+        self.istft_params = istft_params
+        self.lrelu_slope = lrelu_slope
+        self.audio_limit = audio_limit
+
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.m_source = SourceModuleHnNSF(
+            sampling_rate=sampling_rate,
+            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
+            harmonic_num=nb_harmonics,
+            sine_amp=nsf_alpha,
+            add_noise_std=nsf_sigma,
+            voiced_threshod=nsf_voiced_threshold)
+        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
+
+        self.conv_pre = weight_norm(
+            Conv1d(in_channels, base_channels, 7, 1, padding=3)
+        )
+
+        # Up
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        base_channels // (2**i),
+                        base_channels // (2**(i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+
+        # Down
+        self.source_downs = nn.ModuleList()
+        self.source_resblocks = nn.ModuleList()
+        downsample_rates = [1] + upsample_rates[::-1][:-1]
+        downsample_cum_rates = np.cumprod(downsample_rates)
+        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
+                                          source_resblock_dilation_sizes)):
+            if u == 1:
+                self.source_downs.append(
+                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
+                )
+            else:
+                self.source_downs.append(
+                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
+                )
+
+            self.source_resblocks.append(
+                ResBlock(base_channels // (2 ** (i + 1)), k, d)
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = base_channels // (2**(i + 1))
+            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+                self.resblocks.append(ResBlock(ch, k, d))
+
+        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+        self.reflection_pad = nn.ReflectionPad1d((1, 0))
+        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
+        self.f0_predictor = f0_predictor
+
+    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
+        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
+
+        har_source, _, _ = self.m_source(f0)
+        return har_source.transpose(1, 2)
+
+    def _stft(self, x):
+        spec = torch.stft(
+            x,
+            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
+            return_complex=True)
+        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
+        return spec[..., 0], spec[..., 1]
+
+    def _istft(self, magnitude, phase):
+        magnitude = torch.clip(magnitude, max=1e2)
+        real = magnitude * torch.cos(phase)
+        img = magnitude * torch.sin(phase)
+        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
+        return inverse_transform
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        f0 = self.f0_predictor(x)
+        s = self._f02source(f0)
+
+        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
+        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
+
+        x = self.conv_pre(x)
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, self.lrelu_slope)
+            x = self.ups[i](x)
+
+            if i == self.num_upsamples - 1:
+                x = self.reflection_pad(x)
+
+            # fusion
+            si = self.source_downs[i](s_stft)
+            si = self.source_resblocks[i](si)
+            x = x + si
+
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
+        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundant
+
+        x = self._istft(magnitude, phase)
+        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
+        return x
+
+    def remove_weight_norm(self):
+        print('Removing weight norm...')
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+        self.source_module.remove_weight_norm()
+        for l in self.source_downs:
+            remove_weight_norm(l)
+        for l in self.source_resblocks:
+            l.remove_weight_norm()
+
+    @torch.inference_mode()
+    def inference(self, mel: torch.Tensor) -> torch.Tensor:
+        return self.forward(x=mel)
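
For orientation, SineGen above is self-contained: given a sample-rate F0 track it emits one sine per harmonic plus noise, with unvoiced samples replaced by noise. A minimal smoke test; the import path is illustrative and assumes the vendored package resolves on sys.path:

import torch
from xinference.thirdparty.cosyvoice.hifigan.generator import SineGen  # assumed path

gen = SineGen(samp_rate=22050, harmonic_num=8)
f0 = torch.full((1, 1, 22050), 220.0)  # one second of a steady 220 Hz tone
sine, uv, noise = gen(f0)
print(sine.shape)   # torch.Size([1, 9, 22050]): fundamental + 8 harmonics
print(uv.unique())  # tensor([1.]): everything voiced above the 0 Hz threshold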
xinference/thirdparty/cosyvoice/llm/llm.py
@@ -0,0 +1,206 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Optional, Union
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence, unpad_sequence
+from cosyvoice.utils.common import IGNORE_ID
+from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
+from cosyvoice.utils.common import th_accuracy
+
+
+class TransformerLM(torch.nn.Module):
+    def __init__(
+        self,
+        text_encoder_input_size: int,
+        llm_input_size: int,
+        llm_output_size: int,
+        text_token_size: int,
+        speech_token_size: int,
+        text_encoder: torch.nn.Module,
+        llm: torch.nn.Module,
+        length_normalized_loss: bool = True,
+        lsm_weight: float = 0.0,
+        spk_embed_dim: int = 192,
+    ):
+        super().__init__()
+        self.llm_input_size = llm_input_size
+        self.speech_token_size = speech_token_size
+        # 1. build text token inputs related modules
+        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
+        self.text_encoder = text_encoder
+        self.text_encoder_affine_layer = nn.Linear(
+            self.text_encoder.output_size(),
+            llm_input_size
+        )
+
+        # 2. build speech token language model related modules
+        self.sos_eos = 0
+        self.task_id = 1
+        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
+        self.llm = llm
+        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
+        self.criterion_ce = LabelSmoothingLoss(
+            size=speech_token_size + 1,
+            padding_idx=IGNORE_ID,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        # 3. [Optional] build speech token related modules
+        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
+        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
+
+    def encode(
+        self,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+    ):
+        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
+        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+        encoder_out = self.text_encoder_affine_layer(encoder_out)
+        return encoder_out, encoder_out_lens
+
+    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
+        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
+        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
+        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) for i in range(len(text_token))]
+        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
+        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
+        return lm_input, lm_input_len
+
+    def forward(
+        self,
+        batch: dict,
+        device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        """
+        Args:
+            text: (B, L, D)
+            text_lengths: (B,)
+            audio: (B, T, N) or (B, T)
+            audio_lengths: (B,)
+        """
+        text_token = batch['text_token'].to(device)
+        text_token_len = batch['text_token_len'].to(device)
+        speech_token = batch['speech_token'].to(device)
+        speech_token_len = batch['speech_token_len'].to(device)
+        embedding = batch['embedding'].to(device)
+
+        # 1. prepare llm_target
+        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
+        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
+
+        # 2. encode text_token
+        text_token = self.text_embedding(text_token)
+        text_token, text_token_len = self.encode(text_token, text_token_len)
+
+        # 3. embedding projection
+        embedding = F.normalize(embedding, dim=1)
+        embedding = self.spk_embed_affine_layer(embedding)
+        embedding = embedding.unsqueeze(1)
+
+        # 4. sos_eos and task_id
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+
+        # 5. encode speech_token
+        speech_token = self.speech_embedding(speech_token)
+
+        # 6. unpad and pad
+        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len)
+
+        # 7. run lm forward
+        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+        logits = self.llm_decoder(lm_output)
+        loss = self.criterion_ce(logits, lm_target)
+        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
+        return {'loss': loss, 'acc': acc}
+
+    def sampling_ids(
+        self,
+        weighted_scores: torch.Tensor,
+        sampling: Union[bool, int, float] = True,
+        beam_size: int = 1,
+        ignore_eos: bool = True,
+    ):
+        while True:
+            prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
+            top_ids = prob.multinomial(beam_size, replacement=True)
+            top_ids = indices[top_ids]
+            if (not ignore_eos) or (self.speech_token_size not in top_ids):
+                break
+        return top_ids
+
+    @torch.inference_mode()
+    def inference(
+        self,
+        text: torch.Tensor,
+        text_len: torch.Tensor,
+        prompt_text: torch.Tensor,
+        prompt_text_len: torch.Tensor,
+        prompt_speech_token: torch.Tensor,
+        prompt_speech_token_len: torch.Tensor,
+        embedding: torch.Tensor,
+        beam_size: int = 1,
+        sampling: int = 25,
+        max_token_text_ratio: float = 20,
+        min_token_text_ratio: float = 2,
+    ) -> torch.Tensor:
+        device = text.device
+        text = torch.concat([prompt_text, text], dim=1)
+        text_len += prompt_text_len
+        text = self.text_embedding(text)
+
+        # 1. encode text
+        text, text_len = self.encode(text, text_len)
+
+        # 2. encode embedding
+        if embedding.shape[0] != 0:
+            embedding = F.normalize(embedding, dim=1)
+            embedding = self.spk_embed_affine_layer(embedding)
+            embedding = embedding.unsqueeze(dim=1)
+        else:
+            embedding = torch.zeros(1, 0, self.llm_input_size).to(device)
+
+        # 3. concat llm_input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device)
+        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+
+        # 4. calculate min/max_length
+        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
+        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
+
+        # 5. step by step decode
+        out_tokens = []
+        offset = 0
+        att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
+        for i in range(max_len):
+            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
+                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
+            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
+            if top_ids == self.speech_token_size:
+                break
+            out_tokens.append(top_ids)
+            offset += lm_input.size(1)
+            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+        return torch.tensor([out_tokens], dtype=torch.int64, device=device)
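
The decode loop grows the attention/conv caches chunk by chunk and feeds each sampled speech token back as the next one-step input. sampling_ids is plain top-k sampling that resamples whenever the EOS id (speech_token_size) is drawn before min_len is reached; a standalone sketch of that logic, with hypothetical names:

import torch

def sample_top_k(scores: torch.Tensor, k: int = 25,
                 eos_id: int = 4096, ignore_eos: bool = True) -> int:
    # scores: 1-D logits over speech tokens plus a final EOS id
    while True:
        prob, indices = scores.softmax(dim=-1).topk(k)
        top_id = indices[prob.multinomial(1, replacement=True)].item()
        if not ignore_eos or top_id != eos_id:
            return top_id

print(sample_top_k(torch.randn(4097)))  # some token id in [0, 4096]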
xinference/thirdparty/cosyvoice/transformer/activation.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
+#               2020 Northwestern Polytechnical University (Pengcheng Guo)
+#               2020 Mobvoi Inc (Binbin Zhang)
+#               2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Swish() activation function for Conformer."""
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Swish(torch.nn.Module):
+    """Construct a Swish object."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Return Swish activation function."""
+        return x * torch.sigmoid(x)
+
+
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+class Snake(nn.Module):
+    '''
+    Implementation of a sine-based periodic activation function
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter
+    References:
+        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+          https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = Snake(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    '''
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+        '''
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha: trainable parameter
+            alpha is initialized to 1 by default; higher values = higher frequency.
+            alpha will be trained along with the rest of your model.
+        '''
+        super(Snake, self).__init__()
+        self.in_features = in_features
+
+        # initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # log scale alphas initialized to zeros
+            self.alpha = Parameter(torch.zeros(in_features) * alpha)
+        else:  # linear scale alphas initialized to ones
+            self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        '''
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        Snake := x + 1/a * sin^2(xa)
+        '''
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
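
Snake keeps the identity path and adds a bounded periodic term, so it passes zero through unchanged and never kills gradients the way a hard nonlinearity can. A quick shape check of the class above; the import path is illustrative and assumes the vendored location from this diff:

import torch
from xinference.thirdparty.cosyvoice.transformer.activation import Snake  # assumed path

act = Snake(in_features=4)               # one alpha per channel
x = torch.zeros(2, 4, 10)                # (B, C, T)
print(torch.allclose(act(x), x))         # True: Snake(0) == 0
print(act(torch.ones(2, 4, 10)).shape)   # torch.Size([2, 4, 10])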