xinference 1.4.1__py3-none-any.whl → 1.5.0.post1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Note: this release of xinference has been flagged as potentially problematic.

Files changed (104)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +50 -1
  3. xinference/client/restful/restful_client.py +82 -2
  4. xinference/constants.py +3 -0
  5. xinference/core/chat_interface.py +297 -83
  6. xinference/core/model.py +1 -0
  7. xinference/core/progress_tracker.py +16 -8
  8. xinference/core/supervisor.py +45 -1
  9. xinference/core/worker.py +262 -37
  10. xinference/deploy/cmdline.py +33 -1
  11. xinference/model/audio/core.py +11 -1
  12. xinference/model/audio/megatts.py +105 -0
  13. xinference/model/audio/model_spec.json +24 -1
  14. xinference/model/audio/model_spec_modelscope.json +26 -1
  15. xinference/model/core.py +14 -0
  16. xinference/model/embedding/core.py +6 -1
  17. xinference/model/flexible/core.py +6 -1
  18. xinference/model/image/core.py +6 -1
  19. xinference/model/image/model_spec.json +17 -1
  20. xinference/model/image/model_spec_modelscope.json +17 -1
  21. xinference/model/llm/__init__.py +0 -4
  22. xinference/model/llm/core.py +4 -0
  23. xinference/model/llm/llama_cpp/core.py +40 -16
  24. xinference/model/llm/llm_family.json +415 -84
  25. xinference/model/llm/llm_family.py +24 -1
  26. xinference/model/llm/llm_family_modelscope.json +449 -0
  27. xinference/model/llm/mlx/core.py +16 -2
  28. xinference/model/llm/transformers/__init__.py +14 -0
  29. xinference/model/llm/transformers/core.py +30 -6
  30. xinference/model/llm/transformers/gemma3.py +17 -2
  31. xinference/model/llm/transformers/intern_vl.py +28 -18
  32. xinference/model/llm/transformers/minicpmv26.py +21 -2
  33. xinference/model/llm/transformers/qwen-omni.py +308 -0
  34. xinference/model/llm/transformers/qwen2_audio.py +1 -1
  35. xinference/model/llm/transformers/qwen2_vl.py +20 -4
  36. xinference/model/llm/utils.py +11 -1
  37. xinference/model/llm/vllm/core.py +35 -0
  38. xinference/model/llm/vllm/distributed_executor.py +8 -2
  39. xinference/model/rerank/core.py +6 -1
  40. xinference/model/utils.py +118 -1
  41. xinference/model/video/core.py +6 -1
  42. xinference/thirdparty/megatts3/__init__.py +0 -0
  43. xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
  44. xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
  45. xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
  46. xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
  47. xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
  48. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
  49. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
  50. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
  51. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
  52. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
  53. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
  54. xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
  55. xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
  56. xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
  57. xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
  58. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
  59. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
  60. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
  61. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
  62. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
  63. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
  64. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
  65. xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
  66. xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
  67. xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
  68. xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
  69. xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
  70. xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
  71. xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
  72. xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
  73. xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
  74. xinference/types.py +10 -0
  75. xinference/utils.py +54 -0
  76. xinference/web/ui/build/asset-manifest.json +6 -6
  77. xinference/web/ui/build/index.html +1 -1
  78. xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
  79. xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
  80. xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
  81. xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
  82. xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
  83. xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
  84. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
  86. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
  87. xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
  88. xinference/web/ui/src/locales/en.json +2 -1
  89. xinference/web/ui/src/locales/zh.json +2 -1
  90. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/METADATA +129 -114
  91. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/RECORD +96 -60
  92. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/WHEEL +1 -1
  93. xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
  94. xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
  95. xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
  96. xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
  97. xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
  99. xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
  100. xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
  101. /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
  102. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/entry_points.txt +0 -0
  103. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info/licenses}/LICENSE +0 -0
  104. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py
@@ -0,0 +1,309 @@
+ # MIT License
+
+ # Copyright (c) 2023 Alexander Tong
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # Copyright (c) [2023] [Alexander Tong]
+ # Copyright (c) [2025] [Ziyue Jiang]
+ # SPDX-License-Identifier: MIT
+ # This file has been modified by Ziyue Jiang on 2025/03/19
+ # Original file was released under MIT, with the full license text available at
+ # https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE.
+ # This modified file is released under the same license.
+
+ import math
+ import torch
+ from typing import Union
+ from torch.distributions import LogisticNormal
+
+
+ class LogitNormalTrainingTimesteps:
+     def __init__(self, T=1000.0, loc=0.0, scale=1.0):
+         assert T > 0
+         self.T = T
+         self.dist = LogisticNormal(loc, scale)
+
+     def sample(self, size, device):
+         t = self.dist.sample(size)[..., 0].to(device)
+         return t
+
+
+ def pad_t_like_x(t, x):
+     """Function to reshape the time vector t by the number of dimensions of x.
+
+     Parameters
+     ----------
+     x : Tensor, shape (bs, *dim)
+         represents the source minibatch
+     t : FloatTensor, shape (bs)
+
+     Returns
+     -------
+     t : Tensor, shape (bs, number of x dimensions)
+
+     Example
+     -------
+     x: Tensor (bs, C, W, H)
+     t: Vector (bs)
+     pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
+     """
+     if isinstance(t, (float, int)):
+         return t
+     return t.reshape(-1, *([1] * (x.dim() - 1)))
+
+
+ class ConditionalFlowMatcher:
+     r"""Base class for conditional flow matching methods. This class implements the independent
+     conditional flow matching methods from [1] and serves as a parent class for all other flow
+     matching methods.
+
+     It implements:
+     - Drawing data from the gaussian probability path N(t * x1 + (1 - t) * x0, sigma)
+     - conditional flow matching ut(x1|x0) = x1 - x0
+     - score function $\nabla log p_t(x|x0, x1)$
+     """
+
+     def __init__(self, sigma: Union[float, int] = 0.0):
+         r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.
+
+         Parameters
+         ----------
+         sigma : Union[float, int]
+         """
+         self.sigma = sigma
+         self.time_sampler = LogitNormalTrainingTimesteps()
+
+     def compute_mu_t(self, x0, x1, t):
+         """
+         Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
+
+         Parameters
+         ----------
+         x0 : Tensor, shape (bs, *dim)
+             represents the source minibatch
+         x1 : Tensor, shape (bs, *dim)
+             represents the target minibatch
+         t : FloatTensor, shape (bs)
+
+         Returns
+         -------
+         mean mu_t: t * x1 + (1 - t) * x0
+
+         References
+         ----------
+         [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+         """
+         t = pad_t_like_x(t, x0)
+         return t * x1 + (1 - t) * x0
+
+     def compute_sigma_t(self, t):
+         """
+         Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
+
+         Parameters
+         ----------
+         t : FloatTensor, shape (bs)
+
+         Returns
+         -------
+         standard deviation sigma
+
+         References
+         ----------
+         [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+         """
+         del t
+         return self.sigma
+
+     def sample_xt(self, x0, x1, t, epsilon):
+         """
+         Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
+
+         Parameters
+         ----------
+         x0 : Tensor, shape (bs, *dim)
+             represents the source minibatch
+         x1 : Tensor, shape (bs, *dim)
+             represents the target minibatch
+         t : FloatTensor, shape (bs)
+         epsilon : Tensor, shape (bs, *dim)
+             noise sample from N(0, 1)
+
+         Returns
+         -------
+         xt : Tensor, shape (bs, *dim)
+
+         References
+         ----------
+         [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+         """
+         mu_t = self.compute_mu_t(x0, x1, t)
+         sigma_t = self.compute_sigma_t(t)
+         sigma_t = pad_t_like_x(sigma_t, x0)
+         return mu_t + sigma_t * epsilon
+
+     def compute_conditional_flow(self, x0, x1, t, xt):
+         """
+         Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].
+
+         Parameters
+         ----------
+         x0 : Tensor, shape (bs, *dim)
+             represents the source minibatch
+         x1 : Tensor, shape (bs, *dim)
+             represents the target minibatch
+         t : FloatTensor, shape (bs)
+         xt : Tensor, shape (bs, *dim)
+             represents the samples drawn from probability path pt
+
+         Returns
+         -------
+         ut : conditional vector field ut(x1|x0) = x1 - x0
+
+         References
+         ----------
+         [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+         """
+         del t, xt
+         return x1 - x0
+
+     def sample_noise_like(self, x):
+         return torch.randn_like(x)
+
+     def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
+         """
+         Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
+         and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].
+
+         Parameters
+         ----------
+         x0 : Tensor, shape (bs, *dim)
+             represents the source minibatch
+         x1 : Tensor, shape (bs, *dim)
+             represents the target minibatch
+         (optionally) t : Tensor, shape (bs)
+             represents the time levels; if None, drawn from the logit-normal time sampler
+         return_noise : bool
+             return the noise sample epsilon
+
+         Returns
+         -------
+         t : FloatTensor, shape (bs)
+         xt : Tensor, shape (bs, *dim)
+             represents the samples drawn from probability path pt
+         ut : conditional vector field ut(x1|x0) = x1 - x0
+         (optionally) eps : Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon
+
+         References
+         ----------
+         [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+         """
+         if t is None:
+             # t = torch.rand(x0.shape[0]).type_as(x0)
+             t = self.time_sampler.sample([x0.shape[0]], x0.device).type_as(x0)
+
+         assert len(t) == x0.shape[0], "t has to have batch size dimension"
+
+         eps = self.sample_noise_like(x0)
+         xt = self.sample_xt(x0, x1, t, eps)
+         ut = self.compute_conditional_flow(x0, x1, t, xt)
+         if return_noise:
+             return t, xt, ut, eps
+         else:
+             return t, xt, ut
+
+     def compute_lambda(self, t):
+         """Compute the lambda function, see Eq.(23) [4].
+
+         Parameters
+         ----------
+         t : FloatTensor, shape (bs)
+
+         Returns
+         -------
+         lambda : score weighting function
+
+         References
+         ----------
+         [4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
+         """
+         sigma_t = self.compute_sigma_t(t)
+         return 2 * sigma_t / (self.sigma**2 + 1e-8)
+
+
+ class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
+     """Albergo et al. 2023 trigonometric interpolants class. This class inherits from
+     ConditionalFlowMatcher and overrides the compute_mu_t and compute_conditional_flow functions
+     in order to compute [3]'s trigonometric interpolants.
+
+     [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
+     """
+
+     def compute_mu_t(self, x0, x1, t):
+         r"""Compute the mean of the probability path (Eq.5) from [3].
+
+         Parameters
+         ----------
+         x0 : Tensor, shape (bs, *dim)
+             represents the source minibatch
+         x1 : Tensor, shape (bs, *dim)
+             represents the target minibatch
+         t : FloatTensor, shape (bs)
+
+         Returns
+         -------
+         mean mu_t: cos(pi t/2) x0 + sin(pi t/2) x1
+
+         References
+         ----------
+         [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
+         """
+         t = pad_t_like_x(t, x0)
+         return torch.cos(math.pi / 2 * t) * x0 + torch.sin(math.pi / 2 * t) * x1
+
+     def compute_conditional_flow(self, x0, x1, t, xt):
+         r"""Compute the conditional vector field similar to [3].
+
+         ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0),
+         see Eq.(21) [3].
+
+         Parameters
+         ----------
+         x0 : Tensor, shape (bs, *dim)
+             represents the source minibatch
+         x1 : Tensor, shape (bs, *dim)
+             represents the target minibatch
+         t : FloatTensor, shape (bs)
+         xt : Tensor, shape (bs, *dim)
+             represents the samples drawn from probability path pt
+
+         Returns
+         -------
+         ut : conditional vector field
+             ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0)
+
+         References
+         ----------
+         [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
+         """
+         del xt
+         t = pad_t_like_x(t, x0)
+         return math.pi / 2 * (torch.cos(math.pi / 2 * t) * x1 - torch.sin(math.pi / 2 * t) * x0)
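
For orientation (not part of the diff): a minimal sketch of how a ConditionalFlowMatcher is typically consumed in a training step. The toy regression network and tensor shapes are illustrative assumptions, not xinference code.

    import torch
    from torch import nn

    matcher = ConditionalFlowMatcher(sigma=0.0)
    # Toy vector-field network v(x_t, t); input is x_t concatenated with t.
    model = nn.Sequential(nn.Linear(33, 64), nn.SiLU(), nn.Linear(64, 32))

    x1 = torch.randn(8, 32)      # target minibatch (random stand-in data)
    x0 = torch.randn_like(x1)    # source noise minibatch
    # t comes from the logit-normal sampler; xt ~ N(t*x1 + (1-t)*x0, sigma); ut = x1 - x0.
    t, xt, ut = matcher.sample_location_and_conditional_flow(x0, x1)

    vt = model(torch.cat([xt, t[:, None]], dim=-1))
    loss = ((vt - ut) ** 2).mean()   # regress the predicted field onto the conditional target
    loss.backward()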
xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py
@@ -0,0 +1,180 @@
+ # Copyright 2025 ByteDance and/or its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ from torch import nn
+
+ from tts.modules.llm_dit.cfm import ConditionalFlowMatcher
+ from tts.modules.ar_dur.commons.layers import Embedding
+ from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
+ from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
+ from tts.modules.ar_dur.ar_dur_predictor import expand_states
+ from tts.modules.llm_dit.transformer import Transformer
+ from tts.modules.llm_dit.time_embedding import TimestepEmbedding
+
+
+ class Diffusion(nn.Module):
+     def __init__(self):
+         super().__init__()
+         # Hparams
+         # cond dim
+         self.local_cond_dim = 512
+         self.ctx_mask_dim = 16
+         self.in_channels = 32
+         self.out_channels = 32
+         # LLM
+         self.encoder_dim = 1024
+         self.encoder_n_layers = 24
+         self.encoder_n_heads = 16
+         self.max_seq_len = 16384
+         self.multiple_of = 256
+
+         self.ctx_mask_proj = nn.Linear(1, self.ctx_mask_dim)
+         self.local_cond_project = nn.Linear(
+             self.out_channels + self.ctx_mask_dim, self.local_cond_dim)
+
+         self.encoder = Transformer(self.encoder_n_layers, self.encoder_dim, self.encoder_n_heads, self.max_seq_len)
+
+         self.x_prenet = nn.Linear(self.in_channels, self.encoder_dim)
+         self.prenet = nn.Linear(self.local_cond_dim, self.encoder_dim)
+         self.postnet = nn.Linear(self.encoder_dim, self.out_channels)
+
+         self.flow_matcher = ConditionalFlowMatcher(sigma=0.0)
+         # The implementation of TimestepEmbedding is a modified version from F5-TTS (https://github.com/SWivid/F5-TTS),
+         # which is licensed under the MIT License.
+         self.f5_time_embed = TimestepEmbedding(self.encoder_dim)
+
+         # text encoder
+         self.ph_encoder = RelTransformerEncoder(
+             302, self.encoder_dim, self.encoder_dim,
+             self.encoder_dim * 2, 4, 6,
+             3, 0.0, prenet=True, pre_ln=True)
+         self.tone_embed = Embedding(32, self.encoder_dim, padding_idx=0)
+         self.ph_pos_embed = PosEmb(self.encoder_dim)
+         self.ling_pre_net = torch.nn.Sequential(*[
+             torch.nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=s * 2, stride=s, padding=s // 2)
+             for i, s in enumerate([2, 2])
+         ])
+
+     def forward(self, inputs, sigmas=None, x_noisy=None):
+         ctx_mask = inputs['ctx_mask']
+         ctx_feature = inputs['lat_ctx'] * ctx_mask
+
+         """ local conditioning (prompt_latent + spk_embed) """
+         ctx_mask_emb = self.ctx_mask_proj(ctx_mask)
+         # ctx_feature = ctx_feature * (1 - inputs["spk_cfg_mask"][:, :, None])
+         local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
+         local_cond = self.local_cond_project(local_cond)
+
+         """ diffusion target latent """
+         x = inputs['lat']
+
+         # Here, x is x1 in CFM
+         x0 = torch.randn_like(x)
+         t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x)
+
+         # define noisy_input and target
+         t = t.bfloat16()
+         x_noisy = (xt * (1 - ctx_mask)).bfloat16()
+         target = ut
+
+         # concat condition.
+         x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
+         x_ling = self.ling_pre_net(expand_states(x_ling, inputs['mel2ph']).transpose(1, 2)).transpose(1, 2)
+         x_noisy = self.x_prenet(x_noisy) + self.prenet(local_cond) + x_ling
+         encoder_out = self.encoder(x_noisy, self.f5_time_embed(t), attn_mask=inputs["text_mel_mask"], do_checkpoint=False)
+         pred = self.postnet(encoder_out)
+
+         return pred, target
+
+     def forward_ling_encoder(self, txt_tokens, tone_tokens):
+         ph_tokens = txt_tokens
+         ph_nonpadding = (ph_tokens > 0).float()[:, :, None]  # [B, T_phone, 1]
+
+         # enc_ph
+         ph_enc_oembed = self.tone_embed(tone_tokens)
+         ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
+             torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
+         ph_enc_oembed = ph_enc_oembed * ph_nonpadding
+         x_ling = self.ph_encoder(ph_tokens, other_embeds=ph_enc_oembed) * ph_nonpadding
+         return x_ling
+
+     def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=None, seq_cfg_w=[1.0, 1.0]):
+         """ When we use torchdiffeq, we need to include the CFG process inside _forward() """
+         x = x * (1 - ctx_mask)
+         x = self.x_prenet(x) + self.prenet(local_cond) + x_ling
+         pred_v = self.encoder(x, self.f5_time_embed(timesteps), attn_mask=torch.ones((x.size(0), x.size(1)), device=x.device))
+         pred = self.postnet(pred_v)
+
+         """ Perform multi-cond CFG """
+         cond_spk_txt, cond_txt, uncond = pred.chunk(3)
+         pred = uncond + seq_cfg_w[0] * (cond_txt - uncond) + seq_cfg_w[1] * (cond_spk_txt - cond_txt)
+         return pred
+
+     @torch.no_grad()
+     def inference(self, inputs, timesteps=20, seq_cfg_w=[1.0, 1.0], **kwargs):
+         # txt embedding
+         x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
+         x_ling = self.ling_pre_net(expand_states(x_ling, inputs['dur']).transpose(1, 2)).transpose(1, 2)
+
+         # speaker embedding
+         ctx_feature = inputs['lat_ctx']
+         ctx_feature[1:, :, :] = 0  # prefix spk cfg
+         ctx_mask_emb = self.ctx_mask_proj(inputs['ctx_mask'])
+
+         # local conditioning.
+         local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
+         local_cond = self.local_cond_project(local_cond)
+
+         ''' Euler ODE solver '''
+         bsz, device, frm_len = (local_cond.size(0), local_cond.device, local_cond.size(1))
+         # Sway sampling from F5-TTS (https://github.com/SWivid/F5-TTS),
+         # which is licensed under the MIT License.
+         sway_sampling_coef = -1.0
+         t_schedule = torch.linspace(0, 1, timesteps + 1, device=device, dtype=x_ling.dtype)
+         if sway_sampling_coef is not None:
+             t_schedule = t_schedule + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_schedule) - 1 + t_schedule)
+
+         # AMO sampling implementation for "AMO Sampler: Enhancing Text Rendering with Overshooting" (https://arxiv.org/pdf/2411.19415)
+         def amo_sampling(z_t, t, t_next, v):
+             # Upcast to avoid precision issues when computing prev_sample
+             z_t = z_t.to(torch.float32)
+
+             # Constant definition in Algorithm 1
+             s = t_next
+             c = 3
+
+             # Line 7 in Algorithm 1
+             o = min(t_next + c * (t_next - t), 1)
+             pred_z_o = z_t + (o - t) * v
+
+             # Line 11 in Algorithm 1
+             a = s / o
+             b = ((1 - s) ** 2 - (a * (1 - o)) ** 2) ** 0.5
+             noise_i = torch.randn(size=z_t.shape, device=z_t.device)
+             z_t_next = a * pred_z_o + b * noise_i
+             return z_t_next.to(v.dtype)
+
+         x = torch.randn([1, frm_len, self.out_channels], device=device)
+         for step_index in range(timesteps):
+             x = x.to(torch.float32)
+             sigma = t_schedule[step_index].to(x_ling.dtype)
+             sigma_next = t_schedule[step_index + 1]
+             model_out = self._forward(torch.cat([x] * bsz), local_cond, x_ling, timesteps=sigma.unsqueeze(0), ctx_mask=inputs['ctx_mask'], dur=inputs['dur'], seq_cfg_w=seq_cfg_w)
+             x = amo_sampling(x, sigma, sigma_next, model_out)
+             # Cast sample back to model compatible dtype
+             x = x.to(model_out.dtype)
+
+         return x
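
The classifier-free guidance step in _forward assumes the batch is stacked as [cond_spk_txt, cond_txt, uncond], matching torch.cat([x] * bsz) in inference with bsz == 3. A self-contained sketch of the same multi-condition guidance arithmetic, on hypothetical tensors (not from the diff):

    import torch

    # Hypothetical stacked predictions: (3, frames, channels).
    pred = torch.randn(3, 120, 32)
    cond_spk_txt, cond_txt, uncond = pred.chunk(3)   # each (1, 120, 32)

    w_txt, w_spk = 1.0, 1.0                          # seq_cfg_w
    # Text guidance relative to unconditional, then speaker guidance relative to text-only.
    guided = uncond + w_txt * (cond_txt - uncond) + w_spk * (cond_spk_txt - cond_txt)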
xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py
@@ -0,0 +1,44 @@
+ # Copyright 2025 ByteDance and/or its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ import torch
+ from torch import nn
+
+
+ class SinusPositionEmbedding(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+
+     def forward(self, x, scale=1000):
+         device = x.device
+         half_dim = self.dim // 2
+         emb = math.log(10000) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
+         emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
+         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+         return emb
+
+
+ class TimestepEmbedding(nn.Module):
+     def __init__(self, dim, freq_embed_dim=256):
+         super().__init__()
+         self.time_embed = SinusPositionEmbedding(freq_embed_dim)
+         self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
+
+     def forward(self, timestep):  # noqa: F821
+         time_hidden = self.time_embed(timestep)
+         time_hidden = time_hidden.to(timestep.dtype)
+         time = self.time_mlp(time_hidden)  # b d
+         return time
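
Usage sketch (illustrative, not from the diff): TimestepEmbedding maps a batch of scalar flow times to conditioning vectors. The dim=1024 below mirrors Diffusion.encoder_dim in dit.py; the batch size is an arbitrary assumption.

    import torch

    embed = TimestepEmbedding(dim=1024)  # sinusoidal features (256) -> MLP -> (1024)
    t = torch.rand(4)                    # one time level per batch element, in [0, 1]
    cond = embed(t)                      # shape (4, 1024), used as per-step conditioning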