xinference 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference has been flagged as potentially problematic by the registry.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +23 -1
- xinference/core/model.py +1 -6
- xinference/core/utils.py +10 -6
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +15 -10
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +35 -111
- xinference/model/audio/model_spec.json +19 -3
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +127 -4
- xinference/model/image/model_spec_modelscope.json +130 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/llm_family.json +47 -0
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +49 -0
- xinference/model/llm/mlx/core.py +68 -13
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/qwen2_vl.py +2 -0
- xinference/model/llm/utils.py +1 -0
- xinference/model/llm/vllm/core.py +11 -2
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/METADATA +11 -6
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/RECORD +95 -74
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/dataset/processor.py

@@ -23,7 +23,7 @@ import torch.nn.functional as F
 
 torchaudio.set_audio_backend('soundfile')
 
-AUDIO_FORMAT_SETS =
+AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
 
 
 def parquet_opener(data, mode='train', tts_data={}):
@@ -40,20 +40,22 @@ def parquet_opener(data, mode='train', tts_data={}):
         assert 'src' in sample
         url = sample['src']
         try:
-            df
-
-
-
-
-
-
-
-
-
-
+            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
+                df = df.to_pandas()
+                for i in range(len(df)):
+                    if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
+                        continue
+                    sample.update(dict(df.loc[i]))
+                    if mode == 'train':
+                        # NOTE do not return sample directly, must initialize a new dict
+                        yield {**sample}
+                    else:
+                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
+                            yield {**sample, 'tts_index': index, 'tts_text': text}
         except Exception as ex:
             logging.warning('Failed to open {}, ex info {}'.format(url, ex))
 
+
 def filter(data,
            max_length=10240,
            min_length=10,
@@ -84,6 +86,7 @@ def filter(data,
     """
     for sample in data:
         sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
+        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
         del sample['audio_data']
         # sample['wav'] is torch.Tensor, we have 100 frames every second
         num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
@@ -133,6 +136,27 @@ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
         yield sample
 
 
+def truncate(data, truncate_length=24576, mode='train'):
+    """ Truncate data.
+
+        Args:
+            data: Iterable[{key, wav, label, sample_rate}]
+            truncate_length: truncate length
+
+        Returns:
+            Iterable[{key, wav, label, sample_rate}]
+    """
+    for sample in data:
+        waveform = sample['speech']
+        if waveform.shape[1] > truncate_length:
+            start = random.randint(0, waveform.shape[1] - truncate_length)
+            waveform = waveform[:, start: start + truncate_length]
+        else:
+            waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
+        sample['speech'] = waveform
+        yield sample
+
+
 def compute_fbank(data,
                   feat_extractor,
                   mode='train'):
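The new truncate stage normalizes every utterance to a fixed sample count: longer clips are randomly cropped and shorter ones are right-padded with zeros. A minimal standalone sketch of that behaviour, reusing the 24576-sample default from the hunk above (the example tensors are made up):

import random
import torch

def crop_or_pad(waveform: torch.Tensor, target_len: int = 24576) -> torch.Tensor:
    # waveform: (1, num_samples); mirrors the branch logic added in truncate()
    n = waveform.shape[1]
    if n > target_len:
        start = random.randint(0, n - target_len)      # random crop for long clips
        return waveform[:, start:start + target_len]
    # right-pad short clips with silence
    return torch.cat([waveform, torch.zeros(1, target_len - n)], dim=1)

print(crop_or_pad(torch.randn(1, 30000)).shape)  # torch.Size([1, 24576])
print(crop_or_pad(torch.randn(1, 10000)).shape)  # torch.Size([1, 24576])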
@@ -152,7 +176,27 @@ def compute_fbank(data,
         waveform = sample['speech']
         mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
         sample['speech_feat'] = mat
-
+        yield sample
+
+
+def compute_f0(data, pitch_extractor, mode='train'):
+    """ Extract f0
+
+        Args:
+            data: Iterable[{key, wav, label, sample_rate}]
+
+        Returns:
+            Iterable[{key, feat, label}]
+    """
+    for sample in data:
+        assert 'sample_rate' in sample
+        assert 'speech' in sample
+        assert 'utt' in sample
+        assert 'text_token' in sample
+        waveform = sample['speech']
+        mat = pitch_extractor(waveform).transpose(1, 2)
+        mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
+        sample['pitch_feat'] = mat[0, 0]
         yield sample
 
 
@@ -308,7 +352,7 @@ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, m
         logging.fatal('Unsupported batch type {}'.format(batch_type))
 
 
-def padding(data, use_spk_embedding, mode='train'):
+def padding(data, use_spk_embedding, mode='train', gan=False):
     """ Padding the data into training data
 
         Args:
@@ -324,6 +368,9 @@ def padding(data, use_spk_embedding, mode='train'):
         order = torch.argsort(speech_feat_len, descending=True)
 
         utts = [sample[i]['utt'] for i in order]
+        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
+        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
+        speech = pad_sequence(speech, batch_first=True, padding_value=0)
         speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
         speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
         speech_token = pad_sequence(speech_token,
@@ -342,6 +389,8 @@ def padding(data, use_spk_embedding, mode='train'):
         spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
         batch = {
             "utts": utts,
+            "speech": speech,
+            "speech_len": speech_len,
             "speech_token": speech_token,
             "speech_token_len": speech_token_len,
             "speech_feat": speech_feat,
@@ -352,6 +401,19 @@ def padding(data, use_spk_embedding, mode='train'):
             "utt_embedding": utt_embedding,
             "spk_embedding": spk_embedding,
         }
+        if gan is True:
+            # in gan train, we need pitch_feat
+            pitch_feat = [sample[i]['pitch_feat'] for i in order]
+            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
+            pitch_feat = pad_sequence(pitch_feat,
+                                      batch_first=True,
+                                      padding_value=0)
+            batch["pitch_feat"] = pitch_feat
+            batch["pitch_feat_len"] = pitch_feat_len
+        else:
+            # only gan train needs speech, delete it to save memory
+            del batch["speech"]
+            del batch["speech_len"]
         if mode == 'inference':
             tts_text = [sample[i]['tts_text'] for i in order]
             tts_index = [sample[i]['tts_index'] for i in order]
xinference/thirdparty/cosyvoice/flow/decoder.py

@@ -13,16 +13,83 @@
 # limitations under the License.
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import pack, rearrange, repeat
+from cosyvoice.utils.common import mask_to_bias
+from cosyvoice.utils.mask import add_optional_chunk_mask
 from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
 from matcha.models.components.transformer import BasicTransformerBlock
 
 
+class Transpose(torch.nn.Module):
+    def __init__(self, dim0: int, dim1: int):
+        super().__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x: torch.Tensor):
+        x = torch.transpose(x, self.dim0, self.dim1)
+        return x
+
+
+class CausalBlock1D(Block1D):
+    def __init__(self, dim: int, dim_out: int):
+        super(CausalBlock1D, self).__init__(dim, dim_out)
+        self.block = torch.nn.Sequential(
+            CausalConv1d(dim, dim_out, 3),
+            Transpose(1, 2),
+            nn.LayerNorm(dim_out),
+            Transpose(1, 2),
+            nn.Mish(),
+        )
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor):
+        output = self.block(x * mask)
+        return output * mask
+
+
+class CausalResnetBlock1D(ResnetBlock1D):
+    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
+        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
+        self.block1 = CausalBlock1D(dim, dim_out)
+        self.block2 = CausalBlock1D(dim_out, dim_out)
+
+
+class CausalConv1d(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = 'zeros',
+        device=None,
+        dtype=None
+    ) -> None:
+        super(CausalConv1d, self).__init__(in_channels, out_channels,
+                                           kernel_size, stride,
+                                           padding=0, dilation=dilation,
+                                           groups=groups, bias=bias,
+                                           padding_mode=padding_mode,
+                                           device=device, dtype=dtype)
+        assert stride == 1
+        self.causal_padding = (kernel_size - 1, 0)
+
+    def forward(self, x: torch.Tensor):
+        x = F.pad(x, self.causal_padding)
+        x = super(CausalConv1d, self).forward(x)
+        return x
+
+
 class ConditionalDecoder(nn.Module):
     def __init__(
         self,
         in_channels,
         out_channels,
+        causal=False,
         channels=(256, 256),
         dropout=0.05,
         attention_head_dim=64,
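CausalConv1d is what makes the new decoder streamable: instead of symmetric padding it left-pads the input with kernel_size - 1 zeros, so each output frame depends only on the current and past frames. A small sketch, independent of the classes above, showing that a left-padded convolution never looks ahead (the tensors and shapes here are made up for illustration):

import torch
import torch.nn.functional as F

conv = torch.nn.Conv1d(1, 1, kernel_size=3, padding=0, bias=False)
x = torch.randn(1, 1, 10)

# causal: pad (kernel_size - 1) zeros on the left only, as CausalConv1d does
y = conv(F.pad(x, (2, 0)))
print(y.shape)  # torch.Size([1, 1, 10]) -- same length as the input

# changing future samples does not change past outputs
x2 = x.clone()
x2[..., 5:] += 1.0
y2 = conv(F.pad(x2, (2, 0)))
print(torch.allclose(y[..., :5], y2[..., :5]))  # True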
@@ -39,7 +106,7 @@ class ConditionalDecoder(nn.Module):
         channels = tuple(channels)
         self.in_channels = in_channels
         self.out_channels = out_channels
-
+        self.causal = causal
         self.time_embeddings = SinusoidalPosEmb(in_channels)
         time_embed_dim = channels[0] * 4
         self.time_mlp = TimestepEmbedding(
@@ -56,7 +123,8 @@ class ConditionalDecoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet =
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
+                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
                     BasicTransformerBlock(
@@ -70,14 +138,16 @@ class ConditionalDecoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else
+                Downsample1D(output_channel) if not is_last else
+                CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
 
-        for
+        for _ in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]
-            resnet =
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
+                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
 
             transformer_blocks = nn.ModuleList(
                 [
@@ -99,7 +169,11 @@ class ConditionalDecoder(nn.Module):
             input_channel = channels[i] * 2
             output_channel = channels[i + 1]
             is_last = i == len(channels) - 2
-            resnet =
+            resnet = CausalResnetBlock1D(
+                dim=input_channel,
+                dim_out=output_channel,
+                time_emb_dim=time_embed_dim,
+            ) if self.causal else ResnetBlock1D(
                 dim=input_channel,
                 dim_out=output_channel,
                 time_emb_dim=time_embed_dim,
@@ -119,14 +193,13 @@ class ConditionalDecoder(nn.Module):
             upsample = (
                 Upsample1D(output_channel, use_conv_transpose=True)
                 if not is_last
-                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
-        self.final_block = Block1D(channels[-1], channels[-1])
+        self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()
 
-
     def initialize_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Conv1d):
@@ -159,7 +232,7 @@ class ConditionalDecoder(nn.Module):
             _type_: _description_
         """
 
-        t = self.time_embeddings(t)
+        t = self.time_embeddings(t).to(t.dtype)
         t = self.time_mlp(t)
 
         x = pack([x, mu], "b * t")[0]
@@ -176,7 +249,9 @@ class ConditionalDecoder(nn.Module):
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+            # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -193,7 +268,9 @@ class ConditionalDecoder(nn.Module):
         for resnet, transformer_blocks in self.mid_blocks:
             x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+            # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -208,7 +285,9 @@ class ConditionalDecoder(nn.Module):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+            # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
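The attention mask is no longer a plain 0/1 outer product of the padding mask; it is built with add_optional_chunk_mask (so attention can be restricted to static chunks for streaming) and then turned into an additive bias with mask_to_bias. The exact cosyvoice.utils.common.mask_to_bias implementation is not shown in this diff, but a typical boolean-mask-to-bias conversion looks roughly like this sketch:

import torch

def mask_to_bias_sketch(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # True = may attend, False = masked out; returns an additive attention bias
    assert mask.dtype == torch.bool
    mask = mask.to(dtype)
    # 0 where attention is allowed, a large negative value where it is masked
    return (1.0 - mask) * -1.0e9

bias = mask_to_bias_sketch(torch.tensor([[True, True, False]]), torch.float32)
print(bias)  # last position gets -1e9, the rest 0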
xinference/thirdparty/cosyvoice/flow/flow.py

@@ -33,8 +33,13 @@ class MaskedDiffWithXvec(torch.nn.Module):
                  encoder: torch.nn.Module = None,
                  length_regulator: torch.nn.Module = None,
                  decoder: torch.nn.Module = None,
-                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
-
+                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
         super().__init__()
         self.input_size = input_size
         self.output_size = output_size
@@ -104,7 +109,8 @@ class MaskedDiffWithXvec(torch.nn.Module):
                   prompt_token_len,
                   prompt_feat,
                   prompt_feat_len,
-                  embedding
+                  embedding,
+                  flow_cache):
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -113,23 +119,107 @@ class MaskedDiffWithXvec(torch.nn.Module):
         # concat text and prompt_text
         token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
-        mask = (~make_pad_mask(token_len)).
+        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
 
         # text encode
         h, h_lengths = self.encoder(token, token_len)
         h = self.encoder_proj(h)
-        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 /
-        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
+        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
+        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
 
         # get conditions
         conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
         conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)
 
-        # mask = (~make_pad_mask(feat_len)).to(h)
         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
-        feat = self.decoder(
+        feat, flow_cache = self.decoder(
+            mu=h.transpose(1, 2).contiguous(),
+            mask=mask.unsqueeze(1),
+            spks=embedding,
+            cond=conds,
+            n_timesteps=10,
+            prompt_len=mel_len1,
+            flow_cache=flow_cache
+        )
+        feat = feat[:, :, mel_len1:]
+        assert feat.shape[2] == mel_len2
+        return feat, flow_cache
+
+
+class CausalMaskedDiffWithXvec(torch.nn.Module):
+    def __init__(self,
+                 input_size: int = 512,
+                 output_size: int = 80,
+                 spk_embed_dim: int = 192,
+                 output_type: str = "mel",
+                 vocab_size: int = 4096,
+                 input_frame_rate: int = 50,
+                 only_mask_loss: bool = True,
+                 token_mel_ratio: int = 2,
+                 pre_lookahead_len: int = 3,
+                 encoder: torch.nn.Module = None,
+                 decoder: torch.nn.Module = None,
+                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+        super().__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.decoder_conf = decoder_conf
+        self.mel_feat_conf = mel_feat_conf
+        self.vocab_size = vocab_size
+        self.output_type = output_type
+        self.input_frame_rate = input_frame_rate
+        logging.info(f"input frame rate={self.input_frame_rate}")
+        self.input_embedding = nn.Embedding(vocab_size, input_size)
+        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+        self.encoder = encoder
+        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+        self.decoder = decoder
+        self.only_mask_loss = only_mask_loss
+        self.token_mel_ratio = token_mel_ratio
+        self.pre_lookahead_len = pre_lookahead_len
+
+    @torch.inference_mode()
+    def inference(self,
+                  token,
+                  token_len,
+                  prompt_token,
+                  prompt_token_len,
+                  prompt_feat,
+                  prompt_feat_len,
+                  embedding,
+                  finalize):
+        assert token.shape[0] == 1
+        # xvec projection
+        embedding = F.normalize(embedding, dim=1)
+        embedding = self.spk_embed_affine_layer(embedding)
+
+        # concat text and prompt_text
+        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
+        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+        # text encode
+        h, h_lengths = self.encoder(token, token_len)
+        if finalize is False:
+            h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
+        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
+        h = self.encoder_proj(h)
+
+        # get conditions
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds[:, :mel_len1] = prompt_feat
+        conds = conds.transpose(1, 2)
+
+        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+        feat, _ = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
             spks=embedding,
@@ -138,4 +228,4 @@ class MaskedDiffWithXvec(torch.nn.Module):
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat
+        return feat, None
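In MaskedDiffWithXvec.inference the number of mel frames to generate is now derived from the token count: tokens at input_frame_rate per second are converted to seconds, then to samples at 22050 Hz, then to mel frames at a hop size of 256 (both values appear in mel_feat_conf above). A small worked example of that arithmetic, using the input_frame_rate default of 50 from CausalMaskedDiffWithXvec; the token count is made up:

input_frame_rate = 50      # speech tokens per second (default above)
sampling_rate = 22050      # from mel_feat_conf
hop_size = 256             # from mel_feat_conf

token_len2 = 100           # 100 speech tokens ~= 2 seconds of audio
mel_len2 = int(token_len2 / input_frame_rate * sampling_rate / hop_size)
print(mel_len2)            # 172 mel frames for ~2 s of audio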
xinference/thirdparty/cosyvoice/flow/flow_matching.py

@@ -11,10 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import onnxruntime
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
 
+
 class ConditionalCFM(BASECFM):
     def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
         super().__init__(
@@ -31,7 +33,7 @@ class ConditionalCFM(BASECFM):
         self.estimator = estimator
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
         """Forward diffusion
 
         Args:
@@ -49,11 +51,21 @@ class ConditionalCFM(BASECFM):
             sample: generated mel-spectrogram
             shape: (batch_size, n_feats, mel_timesteps)
         """
+
         z = torch.randn_like(mu) * temperature
-
+        cache_size = flow_cache.shape[2]
+        # fix prompt and overlap part mu and z
+        if cache_size != 0:
+            z[:, :, :cache_size] = flow_cache[:, :, :, 0]
+            mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
+        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
+        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
+        flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
+
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
 
     def solve_euler(self, x, t_span, mu, mask, spks, cond):
         """
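The flow_cache keeps streaming synthesis consistent across chunks: the random noise z and the condition mu for the prompt region plus the last 34 overlap frames are saved, and the next call overwrites its freshly drawn noise with the cached values before solving the ODE. A rough illustration of the shape handling only, following the torch.zeros(1, 80, 0, 2) default above; the chunk sizes and the loop are invented for the example and this is not the actual API:

import torch

n_mels, prompt_len, overlap = 80, 20, 34
flow_cache = torch.zeros(1, n_mels, 0, 2)   # empty cache on the first chunk

for chunk_frames in (120, 150):
    z = torch.randn(1, n_mels, chunk_frames)
    mu = torch.randn(1, n_mels, chunk_frames)

    # reuse cached noise/condition for the already-generated prefix
    cache_size = flow_cache.shape[2]
    if cache_size != 0:
        z[:, :, :cache_size] = flow_cache[:, :, :, 0]
        mu[:, :, :cache_size] = flow_cache[:, :, :, 1]

    # remember the prompt region plus the last 34 overlap frames for the next chunk
    z_cache = torch.cat([z[:, :, :prompt_len], z[:, :, -overlap:]], dim=2)
    mu_cache = torch.cat([mu[:, :, :prompt_len], mu[:, :, -overlap:]], dim=2)
    flow_cache = torch.stack([z_cache, mu_cache], dim=-1)

print(flow_cache.shape)  # torch.Size([1, 80, 54, 2]) -- prompt (20) + overlap (34)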
@@ -71,30 +83,80 @@ class ConditionalCFM(BASECFM):
             cond: Not used but kept for future purposes
         """
         t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+        t = t.unsqueeze(dim=0)
 
         # I am storing this because I can later plot it by putting a debugger here and saving it to a file
         # Or in future might add like a return_all_steps flag
         sol = []
 
+        if self.inference_cfg_rate > 0:
+            # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+            x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+            mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+            mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+            t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+            spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+            cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        else:
+            x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
         for step in range(1, len(t_span)):
-            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-
-
-
-
-
-
-
-
+                x_in[:] = x
+                mask_in[:] = mask
+                mu_in[0] = mu
+                t_in[:] = t.unsqueeze(0)
+                spks_in[0] = spks
+                cond_in[0] = cond
+            else:
+                x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+            dphi_dt = self.forward_estimator(
+                x_in, mask_in,
+                mu_in, t_in,
+                spks_in,
+                cond_in
+            )
+            if self.inference_cfg_rate > 0:
+                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
             x = x + dt * dphi_dt
             t = t + dt
             sol.append(x)
             if step < len(t_span) - 1:
                 dt = t_span[step + 1] - t
 
-        return sol[-1]
+        return sol[-1].float()
+
+    def forward_estimator(self, x, mask, mu, t, spks, cond):
+        if isinstance(self.estimator, torch.nn.Module):
+            return self.estimator.forward(x, mask, mu, t, spks, cond)
+        elif isinstance(self.estimator, onnxruntime.InferenceSession):
+            ort_inputs = {
+                'x': x.cpu().numpy(),
+                'mask': mask.cpu().numpy(),
+                'mu': mu.cpu().numpy(),
+                't': t.cpu().numpy(),
+                'spks': spks.cpu().numpy(),
+                'cond': cond.cpu().numpy()
+            }
+            output = self.estimator.run(None, ort_inputs)[0]
+            return torch.tensor(output, dtype=x.dtype, device=x.device)
+        else:
+            self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+            self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('t', (2,))
+            self.estimator.set_input_shape('spks', (2, 80))
+            self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+            # run trt engine
+            self.estimator.execute_v2([x.contiguous().data_ptr(),
+                                       mask.contiguous().data_ptr(),
+                                       mu.contiguous().data_ptr(),
+                                       t.contiguous().data_ptr(),
+                                       spks.contiguous().data_ptr(),
+                                       cond.contiguous().data_ptr(),
+                                       x.data_ptr()])
+            return x
 
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss
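solve_euler now evaluates the conditional and unconditional estimates in a single batch of size 2 and combines them with classifier-free guidance: the guided velocity is (1 + w) * cond - w * uncond, where w is inference_cfg_rate (0.7 in the decoder_conf defaults shown earlier). A tiny numeric sketch of that combination with made-up velocity values:

import torch

w = 0.7                               # inference_cfg_rate
cond_v = torch.tensor([1.0, 2.0])     # velocity predicted with full conditioning
uncond_v = torch.tensor([0.5, 0.5])   # velocity predicted with zeroed conditioning

guided = (1.0 + w) * cond_v - w * uncond_v
print(guided)  # tensor([1.3500, 3.0500]) -- pushed further in the conditional direction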
@@ -136,3 +198,38 @@ class ConditionalCFM(BASECFM):
         pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
         return loss, y
+
+
+class CausalConditionalCFM(ConditionalCFM):
+    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
+        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
+        self.rand_noise = torch.randn([1, 80, 50 * 300])
+
+    @torch.inference_mode()
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+        """Forward diffusion
+
+        Args:
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): output_mask
+                shape: (batch_size, 1, mel_timesteps)
+            n_timesteps (int): number of diffusion steps
+            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+            cond: Not used but kept for future purposes
+
+        Returns:
+            sample: generated mel-spectrogram
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+
+        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
+        if self.fp16 is True:
+            z = z.half()
+        # fix prompt and overlap part mu and z
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+        if self.t_scheduler == 'cosine':
+            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None