xinference 1.5.1__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of xinference might be problematic.

Files changed (96)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +97 -8
  3. xinference/client/restful/restful_client.py +51 -11
  4. xinference/core/media_interface.py +758 -0
  5. xinference/core/model.py +49 -9
  6. xinference/core/worker.py +31 -37
  7. xinference/deploy/utils.py +0 -3
  8. xinference/model/audio/__init__.py +16 -27
  9. xinference/model/audio/core.py +1 -0
  10. xinference/model/audio/cosyvoice.py +4 -2
  11. xinference/model/audio/model_spec.json +20 -3
  12. xinference/model/audio/model_spec_modelscope.json +18 -1
  13. xinference/model/embedding/__init__.py +16 -24
  14. xinference/model/image/__init__.py +15 -25
  15. xinference/model/llm/__init__.py +37 -110
  16. xinference/model/llm/core.py +15 -6
  17. xinference/model/llm/llama_cpp/core.py +25 -353
  18. xinference/model/llm/llm_family.json +613 -89
  19. xinference/model/llm/llm_family.py +9 -1
  20. xinference/model/llm/llm_family_modelscope.json +540 -90
  21. xinference/model/llm/mlx/core.py +6 -3
  22. xinference/model/llm/reasoning_parser.py +281 -5
  23. xinference/model/llm/sglang/core.py +16 -3
  24. xinference/model/llm/transformers/chatglm.py +2 -2
  25. xinference/model/llm/transformers/cogagent.py +1 -1
  26. xinference/model/llm/transformers/cogvlm2.py +1 -1
  27. xinference/model/llm/transformers/core.py +9 -3
  28. xinference/model/llm/transformers/glm4v.py +1 -1
  29. xinference/model/llm/transformers/minicpmv26.py +1 -1
  30. xinference/model/llm/transformers/qwen-omni.py +6 -0
  31. xinference/model/llm/transformers/qwen_vl.py +1 -1
  32. xinference/model/llm/utils.py +68 -45
  33. xinference/model/llm/vllm/core.py +38 -18
  34. xinference/model/llm/vllm/xavier/test/test_xavier.py +1 -10
  35. xinference/model/rerank/__init__.py +13 -24
  36. xinference/model/video/__init__.py +15 -25
  37. xinference/model/video/core.py +3 -3
  38. xinference/model/video/diffusers.py +133 -16
  39. xinference/model/video/model_spec.json +54 -0
  40. xinference/model/video/model_spec_modelscope.json +56 -0
  41. xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
  42. xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
  43. xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
  44. xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
  45. xinference/thirdparty/cosyvoice/bin/train.py +7 -2
  46. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
  47. xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
  48. xinference/thirdparty/cosyvoice/cli/model.py +140 -155
  49. xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
  50. xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
  51. xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
  52. xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
  53. xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
  54. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
  55. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
  56. xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
  57. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
  58. xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
  59. xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
  60. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
  61. xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
  62. xinference/thirdparty/cosyvoice/utils/common.py +1 -1
  63. xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
  64. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
  65. xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
  66. xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
  67. xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
  68. xinference/types.py +0 -71
  69. xinference/web/ui/build/asset-manifest.json +3 -3
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
  72. xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
  75. xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
  76. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
  77. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
  78. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
  79. xinference/web/ui/src/locales/en.json +6 -4
  80. xinference/web/ui/src/locales/zh.json +6 -4
  81. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/METADATA +56 -36
  82. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/RECORD +87 -87
  83. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/WHEEL +1 -1
  84. xinference/core/image_interface.py +0 -377
  85. xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
  86. xinference/web/ui/build/static/js/main.91e77b5c.js +0 -3
  87. xinference/web/ui/build/static/js/main.91e77b5c.js.map +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/5e6edb0fb87e3798f142e9abf8dd2dc46bab33a60d31dff525797c0c99887097.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/6087820be1bd5c02c42dff797e7df365448ef35ab26dd5d6bd33e967e05cbfd4.json +0 -1
  91. xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
  92. xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
  93. /xinference/web/ui/build/static/js/{main.91e77b5c.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
  94. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/entry_points.txt +0 -0
  95. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/licenses/LICENSE +0 -0
  96. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/flow/flow.py
@@ -91,6 +91,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
  conds = conds.transpose(1, 2)

  mask = (~make_pad_mask(feat_len)).to(h)
+ # NOTE this is unnecessary, feat/h already same shape
  feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
  loss, _ = self.decoder.compute_loss(
      feat.transpose(1, 2).contiguous(),
@@ -116,7 +117,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
  embedding = F.normalize(embedding, dim=1)
  embedding = self.spk_embed_affine_layer(embedding)

- # concat text and prompt_text
+ # concat speech token and prompt speech token
  token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
  token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
  mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
@@ -129,7 +130,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
  h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)

  # get conditions
- conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
  conds[:, :mel_len1] = prompt_feat
  conds = conds.transpose(1, 2)

@@ -141,11 +142,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
              cond=conds,
              n_timesteps=10,
              prompt_len=mel_len1,
-             flow_cache=flow_cache
+             cache=flow_cache
          )
          feat = feat[:, :, mel_len1:]
          assert feat.shape[2] == mel_len2
-         return feat, flow_cache
+         return feat.float(), flow_cache


  class CausalMaskedDiffWithXvec(torch.nn.Module):
@@ -186,6 +187,53 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
      self.token_mel_ratio = token_mel_ratio
      self.pre_lookahead_len = pre_lookahead_len

+ def forward(
+         self,
+         batch: dict,
+         device: torch.device,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+     token = batch['speech_token'].to(device)
+     token_len = batch['speech_token_len'].to(device)
+     feat = batch['speech_feat'].to(device)
+     feat_len = batch['speech_feat_len'].to(device)
+     embedding = batch['embedding'].to(device)
+
+     # NOTE unified training, static_chunk_size > 0 or = 0
+     streaming = True if random.random() < 0.5 else False
+
+     # xvec projection
+     embedding = F.normalize(embedding, dim=1)
+     embedding = self.spk_embed_affine_layer(embedding)
+
+     # concat text and prompt_text
+     mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
+     token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+     # text encode
+     h, h_lengths = self.encoder(token, token_len, streaming=streaming)
+     h = self.encoder_proj(h)
+
+     # get conditions
+     feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
+     conds = torch.zeros(feat.shape, device=token.device)
+     for i, j in enumerate(feat_len):
+         if random.random() < 0.5:
+             continue
+         index = random.randint(0, int(0.3 * j))
+         conds[i, :index] = feat[i, :index]
+     conds = conds.transpose(1, 2)
+
+     mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
+     loss, _ = self.decoder.compute_loss(
+         feat.transpose(1, 2).contiguous(),
+         mask.unsqueeze(1),
+         h.transpose(1, 2).contiguous(),
+         embedding,
+         cond=conds,
+         streaming=streaming,
+     )
+     return {'loss': loss}
+
  @torch.inference_mode()
  def inference(self,
                token,
@@ -195,6 +243,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                prompt_feat,
                prompt_feat_len,
                embedding,
+               cache,
                finalize):
      assert token.shape[0] == 1
      # xvec projection
@@ -207,25 +256,34 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
  token = self.input_embedding(torch.clamp(token, min=0)) * mask

  # text encode
- h, h_lengths = self.encoder(token, token_len)
- if finalize is False:
-     h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
+ if finalize is True:
+     h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache'])
+ else:
+     token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
+     h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache'])
+ cache['encoder_cache']['offset'] = encoder_cache[0]
+ cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1]
+ cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2]
+ cache['encoder_cache']['upsample_offset'] = encoder_cache[3]
+ cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4]
+ cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5]
  mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
  h = self.encoder_proj(h)

  # get conditions
- conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
  conds[:, :mel_len1] = prompt_feat
  conds = conds.transpose(1, 2)

  mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
- feat, _ = self.decoder(
+ feat, cache['decoder_cache'] = self.decoder(
      mu=h.transpose(1, 2).contiguous(),
      mask=mask.unsqueeze(1),
      spks=embedding,
      cond=conds,
-     n_timesteps=10
+     n_timesteps=10,
+     cache=cache['decoder_cache']
  )
  feat = feat[:, :, mel_len1:]
  assert feat.shape[2] == mel_len2
- return feat, None
+ return feat.float(), cache
xinference/thirdparty/cosyvoice/flow/flow_matching.py
@@ -11,7 +11,7 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- import onnxruntime
+ import threading
  import torch
  import torch.nn.functional as F
  from matcha.models.components.flow_matching import BASECFM
@@ -31,9 +31,10 @@ class ConditionalCFM(BASECFM):
      in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
      # Just change the architecture of the estimator here
      self.estimator = estimator
+     self.lock = threading.Lock()

  @torch.inference_mode()
- def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
      """Forward diffusion

      Args:
@@ -52,20 +53,20 @@ class ConditionalCFM(BASECFM):
              shape: (batch_size, n_feats, mel_timesteps)
      """

-     z = torch.randn_like(mu) * temperature
-     cache_size = flow_cache.shape[2]
+     z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
+     cache_size = cache.shape[2]
      # fix prompt and overlap part mu and z
      if cache_size != 0:
-         z[:, :, :cache_size] = flow_cache[:, :, :, 0]
-         mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
+         z[:, :, :cache_size] = cache[:, :, :, 0]
+         mu[:, :, :cache_size] = cache[:, :, :, 1]
      z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
      mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
-     flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
+     cache = torch.stack([z_cache, mu_cache], dim=-1)

      t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
      if self.t_scheduler == 'cosine':
          t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-     return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
+     return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache

  def solve_euler(self, x, t_span, mu, mask, spks, cond):
      """
@@ -89,36 +90,29 @@ class ConditionalCFM(BASECFM):
  # Or in future might add like a return_all_steps flag
  sol = []

- if self.inference_cfg_rate > 0:
-     # Do not use concat, it may cause memory format changed and trt infer with wrong results!
-     x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-     mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
-     mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-     t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
-     spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
-     cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
- else:
-     x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+ # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+ x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+ mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+ mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+ t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+ spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+ cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  for step in range(1, len(t_span)):
      # Classifier-Free Guidance inference introduced in VoiceBox
-     if self.inference_cfg_rate > 0:
-         x_in[:] = x
-         mask_in[:] = mask
-         mu_in[0] = mu
-         t_in[:] = t.unsqueeze(0)
-         spks_in[0] = spks
-         cond_in[0] = cond
-     else:
-         x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+     x_in[:] = x
+     mask_in[:] = mask
+     mu_in[0] = mu
+     t_in[:] = t.unsqueeze(0)
+     spks_in[0] = spks
+     cond_in[0] = cond
      dphi_dt = self.forward_estimator(
          x_in, mask_in,
          mu_in, t_in,
          spks_in,
          cond_in
      )
-     if self.inference_cfg_rate > 0:
-         dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
-         dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
+     dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+     dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
      x = x + dt * dphi_dt
      t = t + dt
      sol.append(x)
@@ -129,36 +123,26 @@ class ConditionalCFM(BASECFM):

  def forward_estimator(self, x, mask, mu, t, spks, cond):
      if isinstance(self.estimator, torch.nn.Module):
-         return self.estimator.forward(x, mask, mu, t, spks, cond)
-     elif isinstance(self.estimator, onnxruntime.InferenceSession):
-         ort_inputs = {
-             'x': x.cpu().numpy(),
-             'mask': mask.cpu().numpy(),
-             'mu': mu.cpu().numpy(),
-             't': t.cpu().numpy(),
-             'spks': spks.cpu().numpy(),
-             'cond': cond.cpu().numpy()
-         }
-         output = self.estimator.run(None, ort_inputs)[0]
-         return torch.tensor(output, dtype=x.dtype, device=x.device)
+         return self.estimator(x, mask, mu, t, spks, cond)
      else:
-         self.estimator.set_input_shape('x', (2, 80, x.size(2)))
-         self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
-         self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
-         self.estimator.set_input_shape('t', (2,))
-         self.estimator.set_input_shape('spks', (2, 80))
-         self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
-         # run trt engine
-         self.estimator.execute_v2([x.contiguous().data_ptr(),
-                                    mask.contiguous().data_ptr(),
-                                    mu.contiguous().data_ptr(),
-                                    t.contiguous().data_ptr(),
-                                    spks.contiguous().data_ptr(),
-                                    cond.contiguous().data_ptr(),
-                                    x.data_ptr()])
+         with self.lock:
+             self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+             self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+             self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+             self.estimator.set_input_shape('t', (2,))
+             self.estimator.set_input_shape('spks', (2, 80))
+             self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+             # run trt engine
+             assert self.estimator.execute_v2([x.contiguous().data_ptr(),
+                                               mask.contiguous().data_ptr(),
+                                               mu.contiguous().data_ptr(),
+                                               t.contiguous().data_ptr(),
+                                               spks.contiguous().data_ptr(),
+                                               cond.contiguous().data_ptr(),
+                                               x.data_ptr()]) is True
      return x

- def compute_loss(self, x1, mask, mu, spks=None, cond=None):
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
      """Computes diffusion loss

      Args:
@@ -195,7 +179,7 @@ class ConditionalCFM(BASECFM):
  spks = spks * cfg_mask.view(-1, 1)
  cond = cond * cfg_mask.view(-1, 1, 1)

- pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
  loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
  return loss, y

@@ -206,7 +190,7 @@ class CausalConditionalCFM(ConditionalCFM):
      self.rand_noise = torch.randn([1, 80, 50 * 300])

  @torch.inference_mode()
- def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
      """Forward diffusion

      Args:
@@ -225,11 +209,131 @@ class CausalConditionalCFM(ConditionalCFM):
              shape: (batch_size, n_feats, mel_timesteps)
      """

-     z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
-     if self.fp16 is True:
-         z = z.half()
+     offset = cache.pop('offset')
+     z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
+     z = z[:, :, offset:]
+     offset += mu.size(2)
      # fix prompt and overlap part mu and z
      t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
      if self.t_scheduler == 'cosine':
          t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-     return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
+     mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache)
+     cache['offset'] = offset
+     return mel, cache
+
+ def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
+     """
+     Fixed euler solver for ODEs.
+     Args:
+         x (torch.Tensor): random noise
+         t_span (torch.Tensor): n_timesteps interpolated
+             shape: (n_timesteps + 1,)
+         mu (torch.Tensor): output of encoder
+             shape: (batch_size, n_feats, mel_timesteps)
+         mask (torch.Tensor): output_mask
+             shape: (batch_size, 1, mel_timesteps)
+         spks (torch.Tensor, optional): speaker ids. Defaults to None.
+             shape: (batch_size, spk_emb_dim)
+         cond: Not used but kept for future purposes
+     """
+     t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+     t = t.unsqueeze(dim=0)
+
+     # I am storing this because I can later plot it by putting a debugger here and saving it to a file
+     # Or in future might add like a return_all_steps flag
+     sol = []
+
+     # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+     x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+     mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+     mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+     t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+     spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+     cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+     flow_cache_size = cache['down_blocks_kv_cache'].shape[4]
+     for step in range(1, len(t_span)):
+         # Classifier-Free Guidance inference introduced in VoiceBox
+         x_in[:] = x
+         mask_in[:] = mask
+         mu_in[0] = mu
+         t_in[:] = t.unsqueeze(0)
+         spks_in[0] = spks
+         cond_in[0] = cond
+         cache_step = {k: v[step - 1] for k, v in cache.items()}
+         dphi_dt, cache_step = self.forward_estimator(
+             x_in, mask_in,
+             mu_in, t_in,
+             spks_in,
+             cond_in,
+             cache_step
+         )
+         # NOTE if smaller than flow_cache_size, means last chunk, no need to cache
+         if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size:
+             cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
+             cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:]
+             cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
+             cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:]
+             cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
+             cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:]
+             cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
+         dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+         dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
+         x = x + dt * dphi_dt
+         t = t + dt
+         sol.append(x)
+         if step < len(t_span) - 1:
+             dt = t_span[step + 1] - t
+     return sol[-1].float(), cache
+
+ def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
+     if isinstance(self.estimator, torch.nn.Module):
+         x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
+         cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
+     else:
+         with self.lock:
+             self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+             self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+             self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+             self.estimator.set_input_shape('t', (2,))
+             self.estimator.set_input_shape('spks', (2, 80))
+             self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+             self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
+             self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
+             self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
+             self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
+             self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
+             self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
+             self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
+             # run trt engine
+             down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
+             mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
+             up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
+             assert self.estimator.execute_v2([x.contiguous().data_ptr(),
+                                               mask.contiguous().data_ptr(),
+                                               mu.contiguous().data_ptr(),
+                                               t.contiguous().data_ptr(),
+                                               spks.contiguous().data_ptr(),
+                                               cond.contiguous().data_ptr(),
+                                               cache['down_blocks_conv_cache'].contiguous().data_ptr(),
+                                               cache['down_blocks_kv_cache'].contiguous().data_ptr(),
+                                               cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
+                                               cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
+                                               cache['up_blocks_conv_cache'].contiguous().data_ptr(),
+                                               cache['up_blocks_kv_cache'].contiguous().data_ptr(),
+                                               cache['final_blocks_conv_cache'].contiguous().data_ptr(),
+                                               x.data_ptr(),
+                                               cache['down_blocks_conv_cache'].data_ptr(),
+                                               down_blocks_kv_cache_out.data_ptr(),
+                                               cache['mid_blocks_conv_cache'].data_ptr(),
+                                               mid_blocks_kv_cache_out.data_ptr(),
+                                               cache['up_blocks_conv_cache'].data_ptr(),
+                                               up_blocks_kv_cache_out.data_ptr(),
+                                               cache['final_blocks_conv_cache'].data_ptr()]) is True
+             cache = (cache['down_blocks_conv_cache'],
+                      down_blocks_kv_cache_out,
+                      cache['mid_blocks_conv_cache'],
+                      mid_blocks_kv_cache_out,
+                      cache['up_blocks_conv_cache'],
+                      up_blocks_kv_cache_out,
+                      cache['final_blocks_conv_cache'])
+     return x, cache
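
For context, the solve_euler loops above always push a batch of two through the estimator (conditional inputs at index 0, unconditional at index 1) and then blend the two predictions with classifier-free guidance. A minimal sketch of that blending step, using illustrative names and shapes that are not part of the package:

    import torch

    def cfg_mix(cond_pred: torch.Tensor, uncond_pred: torch.Tensor, cfg_rate: float) -> torch.Tensor:
        # Extrapolate from the unconditional prediction toward the conditional one,
        # mirroring the dphi_dt update in the loops above.
        return (1.0 + cfg_rate) * cond_pred - cfg_rate * uncond_pred

    estimator_out = torch.randn(2, 80, 120)                        # stand-in for one estimator call on the batch of two
    cond_pred, uncond_pred = torch.split(estimator_out, [1, 1], dim=0)
    dphi_dt = cfg_mix(cond_pred, uncond_pred, cfg_rate=0.7)
    print(dphi_dt.shape)                                           # torch.Size([1, 80, 120])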
xinference/thirdparty/cosyvoice/flow/length_regulator.py
@@ -51,6 +51,7 @@ class InterpolateRegulator(nn.Module):

  def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
      # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
+     # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
      # x in (B, T, D)
      if x2.shape[1] > 40:
          x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
xinference/thirdparty/cosyvoice/hifigan/discriminator.py
@@ -1,10 +1,16 @@
  import torch
  import torch.nn as nn
- from torch.nn.utils import weight_norm
+ import torch.nn.functional as F
+ try:
+     from torch.nn.utils.parametrizations import weight_norm, spectral_norm
+ except ImportError:
+     from torch.nn.utils import weight_norm, spectral_norm
  from typing import List, Optional, Tuple
  from einops import rearrange
  from torchaudio.transforms import Spectrogram

+ LRELU_SLOPE = 0.1
+

  class MultipleDiscriminator(nn.Module):
      def __init__(
@@ -138,3 +144,87 @@ class DiscriminatorR(nn.Module):
          x += h

          return x, fmap
+
+
+ class MultiResSpecDiscriminator(torch.nn.Module):
+
+     def __init__(self,
+                  fft_sizes=[1024, 2048, 512],
+                  hop_sizes=[120, 240, 50],
+                  win_lengths=[600, 1200, 240],
+                  window="hann_window"):
+
+         super(MultiResSpecDiscriminator, self).__init__()
+         self.discriminators = nn.ModuleList([
+             SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
+             SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
+             SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)])
+
+     def forward(self, y, y_hat):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for _, d in enumerate(self.discriminators):
+             y_d_r, fmap_r = d(y)
+             y_d_g, fmap_g = d(y_hat)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ def stft(x, fft_size, hop_size, win_length, window):
+     """Perform STFT and convert to magnitude spectrogram.
+     Args:
+         x (Tensor): Input signal tensor (B, T).
+         fft_size (int): FFT size.
+         hop_size (int): Hop size.
+         win_length (int): Window length.
+         window (str): Window function type.
+     Returns:
+         Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+     """
+     x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
+
+     # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
+     return torch.abs(x_stft).transpose(2, 1)
+
+
+ class SpecDiscriminator(nn.Module):
+     """docstring for Discriminator."""
+
+     def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
+         super(SpecDiscriminator, self).__init__()
+         norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+         self.fft_size = fft_size
+         self.shift_size = shift_size
+         self.win_length = win_length
+         self.window = getattr(torch, window)(win_length)
+         self.discriminators = nn.ModuleList([
+             norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
+             norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+             norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+             norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+             norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
+         ])
+
+         self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
+
+     def forward(self, y):
+
+         fmap = []
+         y = y.squeeze(1)
+         y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
+         y = y.unsqueeze(1)
+         for _, d in enumerate(self.discriminators):
+             y = d(y)
+             y = F.leaky_relu(y, LRELU_SLOPE)
+             fmap.append(y)
+
+         y = self.out(y)
+         fmap.append(y)
+
+         return torch.flatten(y, 1, -1), fmap
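
The stft helper added above returns a magnitude spectrogram shaped (B, #frames, fft_size // 2 + 1). A quick shape check under the same defaults SpecDiscriminator uses (illustrative values only, not taken from the package's tests):

    import torch

    x = torch.randn(2, 22050)                      # (B, T) batch of 1-second waveforms
    window = torch.hann_window(600)
    x_stft = torch.stft(x, 1024, 120, 600, window, return_complex=True)
    mag = torch.abs(x_stft).transpose(2, 1)        # matches the helper's return value
    print(mag.shape)                               # (2, #frames, 513)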
xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py
@@ -13,7 +13,10 @@
  # limitations under the License.
  import torch
  import torch.nn as nn
- from torch.nn.utils import weight_norm
+ try:
+     from torch.nn.utils.parametrizations import weight_norm
+ except ImportError:
+     from torch.nn.utils import weight_norm


  class ConvRNNF0Predictor(nn.Module):
xinference/thirdparty/cosyvoice/hifigan/generator.py
@@ -23,7 +23,10 @@ import torch.nn.functional as F
  from torch.nn import Conv1d
  from torch.nn import ConvTranspose1d
  from torch.nn.utils import remove_weight_norm
- from torch.nn.utils import weight_norm
+ try:
+     from torch.nn.utils.parametrizations import weight_norm
+ except ImportError:
+     from torch.nn.utils import weight_norm
  from torch.distributions.uniform import Uniform

  from cosyvoice.transformer.activation import Snake
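
The try/except blocks above (discriminator.py, f0_predictor.py, generator.py) all guard the same relocation: recent PyTorch releases expose weight_norm under torch.nn.utils.parametrizations, while older ones only ship torch.nn.utils.weight_norm. A minimal standalone sketch of that compatibility pattern, assuming PyTorch >= 2.1 for the first branch:

    import torch
    import torch.nn as nn

    try:
        from torch.nn.utils.parametrizations import weight_norm   # newer PyTorch
    except ImportError:
        from torch.nn.utils import weight_norm                    # older PyTorch fallback

    conv = weight_norm(nn.Conv1d(80, 256, kernel_size=3, padding=1))
    y = conv(torch.randn(1, 80, 100))
    print(y.shape)                                                 # torch.Size([1, 256, 100])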
xinference/thirdparty/cosyvoice/hifigan/hifigan.py
@@ -41,7 +41,7 @@ class HiFiGan(nn.Module):
  loss_fm = feature_loss(fmap_rs, fmap_gs)
  loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
  if self.tpr_loss_weight != 0:
-     loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
+     loss_tpr = tpr_loss(y_d_gs, y_d_rs, self.tpr_loss_tau)
  else:
      loss_tpr = torch.zeros(1).to(device)
  loss_f0 = F.l1_loss(generated_f0, pitch_feat)
@@ -56,7 +56,7 @@ class HiFiGan(nn.Module):
  with torch.no_grad():
      generated_speech, generated_f0 = self.generator(batch, device)
  # 2. calculate discriminator outputs
- y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach())
  # 3. calculate discriminator losses, tpr losses [Optional]
  loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
  if self.tpr_loss_weight != 0:
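
The .detach() added above follows the usual GAN training split: the discriminator loss must not backpropagate into the generator. A generic illustration of that pattern, not the package's actual trainer:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    G, D = nn.Linear(8, 8), nn.Linear(8, 1)
    real, z = torch.randn(4, 8), torch.randn(4, 8)
    fake = G(z)

    # Discriminator step: detach the generated sample so only D receives gradients.
    d_loss = F.binary_cross_entropy_with_logits(
        torch.cat([D(real), D(fake.detach())]),
        torch.cat([torch.ones(4, 1), torch.zeros(4, 1)]),
    )
    d_loss.backward()
    assert all(p.grad is None for p in G.parameters())   # generator untouched by the D step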