xinference 1.4.1__py3-none-any.whl → 1.5.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +50 -1
- xinference/client/restful/restful_client.py +82 -2
- xinference/constants.py +3 -0
- xinference/core/chat_interface.py +297 -83
- xinference/core/model.py +1 -0
- xinference/core/progress_tracker.py +16 -8
- xinference/core/supervisor.py +45 -1
- xinference/core/worker.py +262 -37
- xinference/deploy/cmdline.py +33 -1
- xinference/model/audio/core.py +11 -1
- xinference/model/audio/megatts.py +105 -0
- xinference/model/audio/model_spec.json +24 -1
- xinference/model/audio/model_spec_modelscope.json +26 -1
- xinference/model/core.py +14 -0
- xinference/model/embedding/core.py +6 -1
- xinference/model/flexible/core.py +6 -1
- xinference/model/image/core.py +6 -1
- xinference/model/image/model_spec.json +17 -1
- xinference/model/image/model_spec_modelscope.json +17 -1
- xinference/model/llm/__init__.py +0 -4
- xinference/model/llm/core.py +4 -0
- xinference/model/llm/llama_cpp/core.py +40 -16
- xinference/model/llm/llm_family.json +415 -84
- xinference/model/llm/llm_family.py +24 -1
- xinference/model/llm/llm_family_modelscope.json +449 -0
- xinference/model/llm/mlx/core.py +16 -2
- xinference/model/llm/transformers/__init__.py +14 -0
- xinference/model/llm/transformers/core.py +30 -6
- xinference/model/llm/transformers/gemma3.py +17 -2
- xinference/model/llm/transformers/intern_vl.py +28 -18
- xinference/model/llm/transformers/minicpmv26.py +21 -2
- xinference/model/llm/transformers/qwen-omni.py +308 -0
- xinference/model/llm/transformers/qwen2_audio.py +1 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -4
- xinference/model/llm/utils.py +11 -1
- xinference/model/llm/vllm/core.py +35 -0
- xinference/model/llm/vllm/distributed_executor.py +8 -2
- xinference/model/rerank/core.py +6 -1
- xinference/model/utils.py +118 -1
- xinference/model/video/core.py +6 -1
- xinference/thirdparty/megatts3/__init__.py +0 -0
- xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
- xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
- xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
- xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
- xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
- xinference/types.py +10 -0
- xinference/utils.py +54 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
- xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
- xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
- xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
- xinference/web/ui/src/locales/en.json +2 -1
- xinference/web/ui/src/locales/zh.json +2 -1
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/METADATA +129 -114
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/RECORD +96 -60
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
- xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
- xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
- xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
- /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/entry_points.txt +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info/licenses}/LICENSE +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
# MIT License
|
|
2
|
+
|
|
3
|
+
# Copyright (c) 2023 Alexander Tong
|
|
4
|
+
|
|
5
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
# in the Software without restriction, including without limitation the rights
|
|
8
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
# furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
# copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
# SOFTWARE.
|
|
22
|
+
|
|
23
|
+
# Copyright (c) [2023] [Alexander Tong]
|
|
24
|
+
# Copyright (c) [2025] [Ziyue Jiang]
|
|
25
|
+
# SPDX-License-Identifier: MIT
|
|
26
|
+
# This file has been modified by Ziyue Jiang on 2025/03/19
|
|
27
|
+
# Original file was released under MIT, with the full license text available at https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE.
|
|
28
|
+
# This modified file is released under the same license.
|
|
29
|
+
|
|
30
|
+
import math
|
|
31
|
+
import torch
|
|
32
|
+
from typing import Union
|
|
33
|
+
from torch.distributions import LogisticNormal
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class LogitNormalTrainingTimesteps:
    """Sampler for flow-matching training times in (0, 1).

    Each draw is the first coordinate of a ``LogisticNormal(loc, scale)``
    sample, i.e. presumably a logit-normal value ``sigmoid(z)`` with
    ``z ~ N(loc, scale)`` (TODO confirm against torch's LogisticNormal).

    Args:
        T: must be positive; stored as ``self.T`` but not used by ``sample``.
        loc: location of the underlying normal.
        scale: scale of the underlying normal.
    """

    def __init__(self, T=1000.0, loc=0.0, scale=1.0):
        assert T > 0
        self.T = T
        self.dist = LogisticNormal(loc, scale)

    def sample(self, size, device):
        """Draw ``size`` times and move them to ``device``."""
        simplex_points = self.dist.sample(size)
        first_coordinate = simplex_points[..., 0]
        return first_coordinate.to(device)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def pad_t_like_x(t, x):
    """Reshape a per-sample time vector so it broadcasts against ``x``.

    Python ``int``/``float`` values are returned unchanged, since they already
    broadcast against any tensor. Otherwise ``t`` is reshaped to
    ``(-1, 1, ..., 1)`` with ``x.dim() - 1`` trailing singleton axes.

    Parameters
    ----------
    t : FloatTensor, shape (bs), or a Python scalar
    x : Tensor, shape (bs, *dim)
        represents the source minibatch

    Returns
    -------
    t : Tensor, shape (bs, 1, ..., 1) with as many axes as ``x``

    Example
    -------
    x: Tensor (bs, C, W, H)
    t: Vector (bs)
    pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
    """
    if isinstance(t, (float, int)):
        return t
    singleton_axes = [1] * (x.dim() - 1)
    return t.reshape(-1, *singleton_axes)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class ConditionalFlowMatcher:
    r"""Base class for conditional flow matching methods. This class implements the independent
    conditional flow matching methods from [1] and serves as a parent class for all other flow
    matching methods.

    It implements:
    - Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function
    - conditional flow matching ut(x1|x0) = x1 - x0
    - score function $\nabla log p_t(x|x0, x1)$

    When no time is supplied, training times come from ``self.time_sampler``
    (a ``LogitNormalTrainingTimesteps``), not from U[0, 1].
    """

    def __init__(self, sigma: Union[float, int] = 0.0):
        r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.

        Parameters
        ----------
        sigma : Union[float, int]
            standard deviation of the probability path; it is constant in t
            (see ``compute_sigma_t``).
        """
        self.sigma = sigma
        self.time_sampler = LogitNormalTrainingTimesteps()

    def compute_mu_t(self, x0, x1, t):
        """
        Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: t * x1 + (1 - t) * x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        t = pad_t_like_x(t, x0)
        return t * x1 + (1 - t) * x0

    def compute_sigma_t(self, t):
        """
        Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        t : FloatTensor, shape (bs)
            ignored; the standard deviation does not depend on t here.

        Returns
        -------
        standard deviation sigma (``self.sigma``)

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t
        return self.sigma

    def sample_xt(self, x0, x1, t, epsilon):
        """
        Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        epsilon : Tensor, shape (bs, *dim)
            noise sample from N(0, 1)

        Returns
        -------
        xt : Tensor, shape (bs, *dim)

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        mu_t = self.compute_mu_t(x0, x1, t)
        sigma_t = self.compute_sigma_t(t)
        sigma_t = pad_t_like_x(sigma_t, x0)
        return mu_t + sigma_t * epsilon

    def compute_conditional_flow(self, x0, x1, t, xt):
        """
        Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
            unused by this straight-line path
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt; unused here

        Returns
        -------
        ut : conditional vector field ut(x1|x0) = x1 - x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t, xt
        return x1 - x0

    def sample_noise_like(self, x):
        """Return standard-normal noise with the same shape, dtype and device as ``x``."""
        return torch.randn_like(x)

    def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
        """
        Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
        and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        (optionally) t : Tensor, shape (bs)
            represents the time levels;
            if None, drawn from ``self.time_sampler`` (logit-normal in (0, 1)),
            cast to x0's dtype and placed on x0's device
        return_noise : bool
            return the noise sample epsilon

        Returns
        -------
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt
        ut : conditional vector field ut(x1|x0) = x1 - x0
        (optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon

        Raises
        ------
        AssertionError
            if ``t`` does not have one entry per batch element.

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        if t is None:
            t = self.time_sampler.sample([x0.shape[0]], x0.device).type_as(x0)

        assert len(t) == x0.shape[0], "t has to have batch size dimension"

        eps = self.sample_noise_like(x0)
        xt = self.sample_xt(x0, x1, t, eps)
        ut = self.compute_conditional_flow(x0, x1, t, xt)
        if return_noise:
            return t, xt, ut, eps
        return t, xt, ut

    def compute_lambda(self, t):
        """Compute the lambda function, see Eq.(23) [4].

        Parameters
        ----------
        t : FloatTensor, shape (bs)

        Returns
        -------
        lambda : score weighting function, 2 * sigma_t / (sigma**2 + 1e-8);
            the 1e-8 keeps the division finite when sigma == 0.

        References
        ----------
        [4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
        """
        sigma_t = self.compute_sigma_t(t)
        return 2 * sigma_t / (self.sigma**2 + 1e-8)
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
    r"""Trigonometric interpolant of Albergo et al. 2023.

    Inherits from ConditionalFlowMatcher and replaces ``compute_mu_t`` and
    ``compute_conditional_flow``. The path and its velocity therefore follow the
    trigonometric interpolant of [3] instead of the straight line x0 -> x1.

    [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
    """

    def compute_mu_t(self, x0, x1, t):
        r"""Compute the mean of the probability path (Eq.5) from [3].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: cos(pi t/2) x0 + sin(pi t/2) x1

        References
        ----------
        [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
        """
        angle = math.pi / 2 * pad_t_like_x(t, x0)
        return torch.cos(angle) * x0 + torch.sin(angle) * x1

    def compute_conditional_flow(self, x0, x1, t, xt):
        r"""Compute the conditional vector field of the trigonometric path, see Eq.(21) [3].

        ut(x1|x0) = pi/2 (cos(pi t/2) x1 - sin(pi t/2) x0)

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            samples drawn from probability path pt; not needed for this field

        Returns
        -------
        ut : conditional vector field
            ut(x1|x0) = pi/2 (cos(pi t/2) x1 - sin(pi t/2) x0)

        References
        ----------
        [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
        """
        angle = math.pi / 2 * pad_t_like_x(t, x0)
        velocity = torch.cos(angle) * x1 - torch.sin(angle) * x0
        return math.pi / 2 * velocity
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
# Copyright 2025 ByteDance and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
from torch import nn
|
|
17
|
+
|
|
18
|
+
from tts.modules.llm_dit.cfm import ConditionalFlowMatcher
|
|
19
|
+
from tts.modules.ar_dur.commons.layers import Embedding
|
|
20
|
+
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
|
|
21
|
+
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
|
|
22
|
+
from tts.modules.ar_dur.ar_dur_predictor import expand_states
|
|
23
|
+
from tts.modules.llm_dit.transformer import Transformer
|
|
24
|
+
from tts.modules.llm_dit.time_embedding import TimestepEmbedding
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Diffusion(nn.Module):
    """Flow-matching diffusion transformer over a latent sequence.

    A ``Transformer`` encoder predicts a velocity for a noisy latent. It is
    conditioned on three inputs:
    (a) a linguistic encoding of phone and tone tokens,
    (b) a local conditioning built from the prompt latent ``lat_ctx`` and its
        ``ctx_mask``,
    (c) a timestep embedding.

    Training uses ``ConditionalFlowMatcher(sigma=0.0)``. Inference runs an AMO
    (overshooting) sampler over a sway-warped time schedule and applies
    three-way classifier-free guidance.

    All hyper-parameters are hard-coded. Submodule attribute names determine the
    ``state_dict`` keys, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Hparams
        # cond dim
        self.local_cond_dim = 512
        self.ctx_mask_dim = 16
        self.in_channels = 32
        self.out_channels = 32
        # LLM
        self.encoder_dim = 1024
        self.encoder_n_layers = 24
        self.encoder_n_heads = 16
        self.max_seq_len = 16384
        self.multiple_of = 256

        self.ctx_mask_proj = nn.Linear(1, self.ctx_mask_dim)
        self.local_cond_project = nn.Linear(
            self.out_channels + self.ctx_mask_dim, self.local_cond_dim)

        self.encoder = Transformer(self.encoder_n_layers, self.encoder_dim, self.encoder_n_heads, self.max_seq_len)

        self.x_prenet = nn.Linear(self.in_channels, self.encoder_dim)
        self.prenet = nn.Linear(self.local_cond_dim, self.encoder_dim)
        self.postnet = nn.Linear(self.encoder_dim, self.out_channels)

        self.flow_matcher = ConditionalFlowMatcher(sigma=0.0)
        # The implementation of TimestepEmbedding is a modified version from F5-TTS (https://github.com/SWivid/F5-TTS),
        # which is licensed under the MIT License.
        self.f5_time_embed = TimestepEmbedding(self.encoder_dim)

        # text encoder
        self.ph_encoder = RelTransformerEncoder(
            302, self.encoder_dim, self.encoder_dim,
            self.encoder_dim * 2, 4, 6,
            3, 0.0, prenet=True, pre_ln=True)
        self.tone_embed = Embedding(32, self.encoder_dim, padding_idx=0)
        self.ph_pos_embed = PosEmb(self.encoder_dim)
        # Two strided convolutions (stride 2 each) over the frame-expanded linguistic features.
        self.ling_pre_net = torch.nn.Sequential(*[
            torch.nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=s * 2, stride=s, padding=s // 2)
            for s in [2, 2]
        ])

    def forward(self, inputs, sigmas=None, x_noisy=None):
        """Compute predictions and targets for one flow-matching training step.

        Args:
            inputs: dict read for the keys ``'ctx_mask'``, ``'lat_ctx'``, ``'lat'``,
                ``'phone'``, ``'tone'``, ``'mel2ph'`` and ``'text_mel_mask'``.
            sigmas: unused; kept for interface compatibility.
            x_noisy: unused (it is recomputed internally); kept for interface
                compatibility.

        Returns:
            ``(pred, target)``: the model's predicted velocity and the flow
            matcher's conditional flow ``x1 - x0``.
        """
        ctx_mask = inputs['ctx_mask']
        ctx_feature = inputs['lat_ctx'] * ctx_mask

        # Local conditioning: masked prompt latent + projected mask embedding.
        ctx_mask_emb = self.ctx_mask_proj(ctx_mask)
        local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
        local_cond = self.local_cond_project(local_cond)

        # Diffusion target latent: ``x`` is x1 in CFM, Gaussian noise is x0.
        x = inputs['lat']
        x0 = torch.randn_like(x)
        t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x)

        # Noisy input and target. Context (prompt) frames are zeroed via
        # (1 - ctx_mask), so only the non-context region carries xt.
        t = t.bfloat16()
        x_noisy = (xt * (1 - ctx_mask)).bfloat16()
        target = ut

        # Add the linguistic condition and run the encoder.
        x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
        x_ling = self.ling_pre_net(expand_states(x_ling, inputs['mel2ph']).transpose(1, 2)).transpose(1, 2)
        x_noisy = self.x_prenet(x_noisy) + self.prenet(local_cond) + x_ling
        encoder_out = self.encoder(x_noisy, self.f5_time_embed(t), attn_mask=inputs["text_mel_mask"], do_checkpoint=False)
        pred = self.postnet(encoder_out)

        return pred, target

    def forward_ling_encoder(self, txt_tokens, tone_tokens):
        """Encode phone tokens, with tone and position embeddings, into linguistic features.

        Phone id 0 is treated as padding: those positions are zeroed both in
        the auxiliary embeddings and in the encoder output.
        """
        ph_tokens = txt_tokens
        ph_nonpadding = (ph_tokens > 0).float()[:, :, None]  # [B, T_phone, 1]

        # enc_ph
        ph_enc_oembed = self.tone_embed(tone_tokens)
        positions = torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device)
        ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(positions)
        ph_enc_oembed = ph_enc_oembed * ph_nonpadding
        x_ling = self.ph_encoder(ph_tokens, other_embeds=ph_enc_oembed) * ph_nonpadding
        return x_ling

    def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=None, seq_cfg_w=(1.0, 1.0)):
        """Run one denoiser evaluation, then combine its outputs with multi-condition CFG.

        The batch axis of ``x`` is split by ``pred.chunk(3)`` into
        (speaker+text, text-only, unconditional) branches, in that order.
        ``dur`` is unused.

        When we use torchdiffeq, we need to include the CFG process inside _forward().
        """
        x = x * (1 - ctx_mask)
        x = self.x_prenet(x) + self.prenet(local_cond) + x_ling
        pred_v = self.encoder(x, self.f5_time_embed(timesteps), attn_mask=torch.ones((x.size(0), x.size(1)), device=x.device))
        pred = self.postnet(pred_v)

        # Multi-condition CFG: seq_cfg_w[0] scales text guidance and
        # seq_cfg_w[1] scales the extra speaker guidance on top of it.
        cond_spk_txt, cond_txt, uncond = pred.chunk(3)
        pred = uncond + seq_cfg_w[0] * (cond_txt - uncond) + seq_cfg_w[1] * (cond_spk_txt - cond_txt)
        return pred

    @torch.no_grad()
    def inference(self, inputs, timesteps=20, seq_cfg_w=(1.0, 1.0), **kwargs):
        """Generate a latent sequence with a fixed-step AMO sampling loop.

        Args:
            inputs: dict read for ``'phone'``, ``'tone'``, ``'dur'``, ``'lat_ctx'``
                and ``'ctx_mask'``. The batch axis is expected to hold the
                three CFG rows consumed by ``_forward`` (TODO confirm with the
                caller). ``inputs['lat_ctx']`` is not modified.
            timesteps: number of sampling steps.
            seq_cfg_w: (text, speaker) guidance weights forwarded to ``_forward``.
            **kwargs: ignored.

        Returns:
            The generated latent. With exactly three CFG rows its shape is
            ``[1, frm_len, out_channels]``.
        """
        # txt embedding
        x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
        x_ling = self.ling_pre_net(expand_states(x_ling, inputs['dur']).transpose(1, 2)).transpose(1, 2)

        # Speaker prompt: keep it only in the first CFG row (prefix spk cfg).
        # Clone first so the caller's tensor is not zeroed in place. This also
        # avoids writing into an expanded/shared-storage view.
        ctx_feature = inputs['lat_ctx'].clone()
        ctx_feature[1:, :, :] = 0
        ctx_mask_emb = self.ctx_mask_proj(inputs['ctx_mask'])

        # local conditioning.
        local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
        local_cond = self.local_cond_project(local_cond)

        # Sampling loop over a warped time schedule.
        bsz, device, frm_len = local_cond.size(0), local_cond.device, local_cond.size(1)
        # Sway sampling from F5-TTS (https://github.com/SWivid/F5-TTS),
        # which is licensed under the MIT License.
        sway_sampling_coef = -1.0
        t_schedule = torch.linspace(0, 1, timesteps + 1, device=device, dtype=x_ling.dtype)
        if sway_sampling_coef is not None:
            t_schedule = t_schedule + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_schedule) - 1 + t_schedule)

        # AMO sampling implementation for "AMO Sampler: Enhancing Text Rendering with Overshooting" (https://arxiv.org/pdf/2411.19415)
        def amo_sampling(z_t, t, t_next, v):
            # Upcast to avoid precision issues when computing prev_sample
            z_t = z_t.to(torch.float32)

            # Constant definition in Algorithm 1
            s = t_next
            c = 3

            # Line 7 in Algorithm 1: overshoot past t_next, clipped at 1.
            o = min(t_next + c * (t_next - t), 1)
            pred_z_o = z_t + (o - t) * v

            # Line 11 in Algorithm 1: rescale back to s and re-inject noise.
            a = s / o
            b = ((1 - s) ** 2 - (a * (1 - o)) ** 2) ** 0.5
            noise_i = torch.randn(size=z_t.shape, device=z_t.device)
            z_t_next = a * pred_z_o + b * noise_i
            return z_t_next.to(v.dtype)

        x = torch.randn([1, frm_len, self.out_channels], device=device)
        for step_index in range(timesteps):
            x = x.to(torch.float32)
            # t_schedule already has x_ling's dtype.
            sigma = t_schedule[step_index]
            sigma_next = t_schedule[step_index + 1]
            model_out = self._forward(
                torch.cat([x] * bsz), local_cond, x_ling,
                timesteps=sigma.unsqueeze(0), ctx_mask=inputs['ctx_mask'],
                dur=inputs['dur'], seq_cfg_w=seq_cfg_w)
            x = amo_sampling(x, sigma, sigma_next, model_out)
            # Cast sample back to model compatible dtype
            x = x.to(model_out.dtype)

        return x
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# Copyright 2025 ByteDance and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import math
|
|
16
|
+
import torch
|
|
17
|
+
from torch import nn
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SinusPositionEmbedding(nn.Module):
    """Sinusoidal embedding of scalar positions or timesteps.

    Args:
        dim: requested embedding width. The output has ``2 * (dim // 2)``
            channels: sines first, then cosines. An odd ``dim`` therefore
            yields ``dim - 1`` channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        """Embed the 1-D tensor ``x`` of shape (N,) into shape (N, 2 * (dim // 2)).

        Args:
            x: values to embed; multiplied by ``scale`` before the sinusoids.
            scale: input multiplier (default 1000).
        """
        device = x.device
        half_dim = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000 over half_dim
        # channels. Guard the denominator so dim in {2, 3} (one frequency)
        # does not raise ZeroDivisionError; larger dims are unchanged.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        return torch.cat((emb.sin(), emb.cos()), dim=-1)
|
|
33
|
+
|
|
34
|
+
class TimestepEmbedding(nn.Module):
    """Embed diffusion timesteps as ``dim``-wide vectors.

    A sinusoidal feature of width ``freq_embed_dim`` is cast back to the
    timestep's dtype. It is then passed through ``Linear -> SiLU -> Linear``.
    """

    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        mlp_layers = [
            nn.Linear(freq_embed_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        ]
        self.time_mlp = nn.Sequential(*mlp_layers)

    def forward(self, timestep):  # noqa: F821
        """Return the MLP-projected sinusoidal embedding of ``timestep`` (b d)."""
        sinusoid = self.time_embed(timestep).to(timestep.dtype)
        return self.time_mlp(sinusoid)
|