xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +28 -6
- xinference/core/utils.py +10 -6
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/core.py +10 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +200 -0
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +36 -111
- xinference/model/audio/model_spec.json +27 -3
- xinference/model/audio/model_spec_modelscope.json +18 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/embedding/core.py +203 -142
- xinference/model/embedding/model_spec.json +7 -0
- xinference/model/embedding/model_spec_modelscope.json +8 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +127 -4
- xinference/model/image/model_spec_modelscope.json +130 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/__init__.py +2 -2
- xinference/model/llm/llm_family.json +219 -53
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +167 -20
- xinference/model/llm/mlx/core.py +287 -51
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/qwen2_vl.py +2 -0
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +5 -1
- xinference/model/llm/vllm/core.py +16 -2
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/f5_tts/model/cfm.py
@@ -0,0 +1,285 @@
+"""
+ein notation:
+b - batch
+n - sequence
+nt - text sequence
+nw - raw wave length
+d - dimension
+"""
+
+from __future__ import annotations
+
+from random import random
+from typing import Callable
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
+from torchdiffeq import odeint
+
+from f5_tts.model.modules import MelSpec
+from f5_tts.model.utils import (
+    default,
+    exists,
+    lens_to_mask,
+    list_str_to_idx,
+    list_str_to_tensor,
+    mask_from_frac_lengths,
+)
+
+
+class CFM(nn.Module):
+    def __init__(
+        self,
+        transformer: nn.Module,
+        sigma=0.0,
+        odeint_kwargs: dict = dict(
+            # atol = 1e-5,
+            # rtol = 1e-5,
+            method="euler"  # 'midpoint'
+        ),
+        audio_drop_prob=0.3,
+        cond_drop_prob=0.2,
+        num_channels=None,
+        mel_spec_module: nn.Module | None = None,
+        mel_spec_kwargs: dict = dict(),
+        frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
+        vocab_char_map: dict[str:int] | None = None,
+    ):
+        super().__init__()
+
+        self.frac_lengths_mask = frac_lengths_mask
+
+        # mel spec
+        self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
+        num_channels = default(num_channels, self.mel_spec.n_mel_channels)
+        self.num_channels = num_channels
+
+        # classifier-free guidance
+        self.audio_drop_prob = audio_drop_prob
+        self.cond_drop_prob = cond_drop_prob
+
+        # transformer
+        self.transformer = transformer
+        dim = transformer.dim
+        self.dim = dim
+
+        # conditional flow related
+        self.sigma = sigma
+
+        # sampling related
+        self.odeint_kwargs = odeint_kwargs
+
+        # vocab map for tokenization
+        self.vocab_char_map = vocab_char_map
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @torch.no_grad()
+    def sample(
+        self,
+        cond: float["b n d"] | float["b nw"],  # noqa: F722
+        text: int["b nt"] | list[str],  # noqa: F722
+        duration: int | int["b"],  # noqa: F821
+        *,
+        lens: int["b"] | None = None,  # noqa: F821
+        steps=32,
+        cfg_strength=1.0,
+        sway_sampling_coef=None,
+        seed: int | None = None,
+        max_duration=4096,
+        vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None,  # noqa: F722
+        no_ref_audio=False,
+        duplicate_test=False,
+        t_inter=0.1,
+        edit_mask=None,
+    ):
+        self.eval()
+        # raw wave
+
+        if cond.ndim == 2:
+            cond = self.mel_spec(cond)
+            cond = cond.permute(0, 2, 1)
+            assert cond.shape[-1] == self.num_channels
+
+        cond = cond.to(next(self.parameters()).dtype)
+
+        batch, cond_seq_len, device = *cond.shape[:2], cond.device
+        if not exists(lens):
+            lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
+
+        # text
+
+        if isinstance(text, list):
+            if exists(self.vocab_char_map):
+                text = list_str_to_idx(text, self.vocab_char_map).to(device)
+            else:
+                text = list_str_to_tensor(text).to(device)
+            assert text.shape[0] == batch
+
+        if exists(text):
+            text_lens = (text != -1).sum(dim=-1)
+            lens = torch.maximum(text_lens, lens)  # make sure lengths are at least those of the text characters
+
+        # duration
+
+        cond_mask = lens_to_mask(lens)
+        if edit_mask is not None:
+            cond_mask = cond_mask & edit_mask
+
+        if isinstance(duration, int):
+            duration = torch.full((batch,), duration, device=device, dtype=torch.long)
+
+        duration = torch.maximum(lens + 1, duration)  # just add one token so something is generated
+        duration = duration.clamp(max=max_duration)
+        max_duration = duration.amax()
+
+        # duplicate test corner for inner time step oberservation
+        if duplicate_test:
+            test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
+
+        cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
+        cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
+        cond_mask = cond_mask.unsqueeze(-1)
+        step_cond = torch.where(
+            cond_mask, cond, torch.zeros_like(cond)
+        )  # allow direct control (cut cond audio) with lens passed in
+
+        if batch > 1:
+            mask = lens_to_mask(duration)
+        else:  # save memory and speed up, as single inference need no mask currently
+            mask = None
+
+        # test for no ref audio
+        if no_ref_audio:
+            cond = torch.zeros_like(cond)
+
+        # neural ode
+
+        def fn(t, x):
+            # at each step, conditioning is fixed
+            # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
+
+            # predict flow
+            pred = self.transformer(
+                x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
+            )
+            if cfg_strength < 1e-5:
+                return pred
+
+            null_pred = self.transformer(
+                x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
+            )
+            return pred + (pred - null_pred) * cfg_strength
+
+        # noise input
+        # to make sure batch inference result is same with different batch size, and for sure single inference
+        # still some difference maybe due to convolutional layers
+        y0 = []
+        for dur in duration:
+            if exists(seed):
+                torch.manual_seed(seed)
+            y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
+        y0 = pad_sequence(y0, padding_value=0, batch_first=True)
+
+        t_start = 0
+
+        # duplicate test corner for inner time step oberservation
+        if duplicate_test:
+            t_start = t_inter
+            y0 = (1 - t_start) * y0 + t_start * test_cond
+            steps = int(steps * (1 - t_start))
+
+        t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
+        if sway_sampling_coef is not None:
+            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
+
+        trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
+
+        sampled = trajectory[-1]
+        out = sampled
+        out = torch.where(cond_mask, cond, out)
+
+        if exists(vocoder):
+            out = out.permute(0, 2, 1)
+            out = vocoder(out)
+
+        return out, trajectory
+
+    def forward(
+        self,
+        inp: float["b n d"] | float["b nw"],  # mel or raw wave  # noqa: F722
+        text: int["b nt"] | list[str],  # noqa: F722
+        *,
+        lens: int["b"] | None = None,  # noqa: F821
+        noise_scheduler: str | None = None,
+    ):
+        # handle raw wave
+        if inp.ndim == 2:
+            inp = self.mel_spec(inp)
+            inp = inp.permute(0, 2, 1)
+            assert inp.shape[-1] == self.num_channels
+
+        batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
+
+        # handle text as string
+        if isinstance(text, list):
+            if exists(self.vocab_char_map):
+                text = list_str_to_idx(text, self.vocab_char_map).to(device)
+            else:
+                text = list_str_to_tensor(text).to(device)
+            assert text.shape[0] == batch
+
+        # lens and mask
+        if not exists(lens):
+            lens = torch.full((batch,), seq_len, device=device)
+
+        mask = lens_to_mask(lens, length=seq_len)  # useless here, as collate_fn will pad to max length in batch
+
+        # get a random span to mask out for training conditionally
+        frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
+        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
+
+        if exists(mask):
+            rand_span_mask &= mask
+
+        # mel is x1
+        x1 = inp
+
+        # x0 is gaussian noise
+        x0 = torch.randn_like(x1)
+
+        # time step
+        time = torch.rand((batch,), dtype=dtype, device=self.device)
+        # TODO. noise_scheduler
+
+        # sample xt (φ_t(x) in the paper)
+        t = time.unsqueeze(-1).unsqueeze(-1)
+        φ = (1 - t) * x0 + t * x1
+        flow = x1 - x0
+
+        # only predict what is within the random mask span for infilling
+        cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
+
+        # transformer and cfg training with a drop rate
+        drop_audio_cond = random() < self.audio_drop_prob  # p_drop in voicebox paper
+        if random() < self.cond_drop_prob:  # p_uncond in voicebox paper
+            drop_audio_cond = True
+            drop_text = True
+        else:
+            drop_text = False
+
+        # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
+        # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
+        pred = self.transformer(
+            x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
+        )
+
+        # flow matching loss
+        loss = F.mse_loss(pred, flow, reduction="none")
+        loss = loss[rand_span_mask]
+
+        return loss.mean(), cond, pred
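Note on usage: the CFM module above depends on its transformer backbone only through a `.dim` attribute and the keyword arguments visible in `sample` and `forward`, so its training-loss path can be exercised in isolation. The sketch below is illustrative only and is not part of the release; `ToyBackbone` is a hypothetical stand-in for the DiT/UNetT backbones added under f5_tts/model/backbones, and it assumes the vendored f5_tts package (shipped under xinference/thirdparty in this wheel) is importable as `f5_tts`, with torch and torchdiffeq installed.

# Illustrative sketch only: ToyBackbone merely matches the call signature that
# CFM.forward and CFM.sample expect; it is not part of this package.
import torch
from torch import nn

from f5_tts.model.cfm import CFM  # vendored under xinference/thirdparty in this wheel


class ToyBackbone(nn.Module):
    def __init__(self, dim=64, n_mel_channels=100):
        super().__init__()
        self.dim = dim  # CFM reads transformer.dim in its constructor
        self.proj = nn.Linear(n_mel_channels * 2, n_mel_channels)

    def forward(self, x, cond, text, time, drop_audio_cond, drop_text, mask=None):
        # x, cond: (batch, frames, n_mel_channels); return a flow prediction of the same shape
        return self.proj(torch.cat([x, cond], dim=-1))


model = CFM(transformer=ToyBackbone())
mel = torch.randn(2, 120, 100)             # (batch, frames, n_mel_channels)
text = ["hello there", "general kenobi"]   # tokenized internally via list_str_to_tensor
loss, cond, pred = model(mel, text)        # flow-matching MSE over the random infill span
print(loss.item(), pred.shape)             # pred: torch.Size([2, 120, 100])

At inference time, `sample` follows the same backbone contract but additionally integrates the predicted flow with `torchdiffeq.odeint` and applies classifier-free guidance controlled by `cfg_strength`.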
xinference/thirdparty/f5_tts/model/dataset.py
@@ -0,0 +1,319 @@
+import json
+import random
+from importlib.resources import files
+
+import torch
+import torch.nn.functional as F
+import torchaudio
+from datasets import Dataset as Dataset_
+from datasets import load_from_disk
+from torch import nn
+from torch.utils.data import Dataset, Sampler
+from tqdm import tqdm
+
+from f5_tts.model.modules import MelSpec
+from f5_tts.model.utils import default
+
+
+class HFDataset(Dataset):
+    def __init__(
+        self,
+        hf_dataset: Dataset,
+        target_sample_rate=24_000,
+        n_mel_channels=100,
+        hop_length=256,
+        n_fft=1024,
+        win_length=1024,
+        mel_spec_type="vocos",
+    ):
+        self.data = hf_dataset
+        self.target_sample_rate = target_sample_rate
+        self.hop_length = hop_length
+
+        self.mel_spectrogram = MelSpec(
+            n_fft=n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            n_mel_channels=n_mel_channels,
+            target_sample_rate=target_sample_rate,
+            mel_spec_type=mel_spec_type,
+        )
+
+    def get_frame_len(self, index):
+        row = self.data[index]
+        audio = row["audio"]["array"]
+        sample_rate = row["audio"]["sampling_rate"]
+        return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        row = self.data[index]
+        audio = row["audio"]["array"]
+
+        # logger.info(f"Audio shape: {audio.shape}")
+
+        sample_rate = row["audio"]["sampling_rate"]
+        duration = audio.shape[-1] / sample_rate
+
+        if duration > 30 or duration < 0.3:
+            return self.__getitem__((index + 1) % len(self.data))
+
+        audio_tensor = torch.from_numpy(audio).float()
+
+        if sample_rate != self.target_sample_rate:
+            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
+            audio_tensor = resampler(audio_tensor)
+
+        audio_tensor = audio_tensor.unsqueeze(0)  # 't -> 1 t')
+
+        mel_spec = self.mel_spectrogram(audio_tensor)
+
+        mel_spec = mel_spec.squeeze(0)  # '1 d t -> d t'
+
+        text = row["text"]
+
+        return dict(
+            mel_spec=mel_spec,
+            text=text,
+        )
+
+
+class CustomDataset(Dataset):
+    def __init__(
+        self,
+        custom_dataset: Dataset,
+        durations=None,
+        target_sample_rate=24_000,
+        hop_length=256,
+        n_mel_channels=100,
+        n_fft=1024,
+        win_length=1024,
+        mel_spec_type="vocos",
+        preprocessed_mel=False,
+        mel_spec_module: nn.Module | None = None,
+    ):
+        self.data = custom_dataset
+        self.durations = durations
+        self.target_sample_rate = target_sample_rate
+        self.hop_length = hop_length
+        self.n_fft = n_fft
+        self.win_length = win_length
+        self.mel_spec_type = mel_spec_type
+        self.preprocessed_mel = preprocessed_mel
+
+        if not preprocessed_mel:
+            self.mel_spectrogram = default(
+                mel_spec_module,
+                MelSpec(
+                    n_fft=n_fft,
+                    hop_length=hop_length,
+                    win_length=win_length,
+                    n_mel_channels=n_mel_channels,
+                    target_sample_rate=target_sample_rate,
+                    mel_spec_type=mel_spec_type,
+                ),
+            )
+
+    def get_frame_len(self, index):
+        if (
+            self.durations is not None
+        ):  # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
+            return self.durations[index] * self.target_sample_rate / self.hop_length
+        return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        while True:
+            row = self.data[index]
+            audio_path = row["audio_path"]
+            text = row["text"]
+            duration = row["duration"]
+
+            # filter by given length
+            if 0.3 <= duration <= 30:
+                break  # valid
+
+            index = (index + 1) % len(self.data)
+
+        if self.preprocessed_mel:
+            mel_spec = torch.tensor(row["mel_spec"])
+        else:
+            audio, source_sample_rate = torchaudio.load(audio_path)
+
+            # make sure mono input
+            if audio.shape[0] > 1:
+                audio = torch.mean(audio, dim=0, keepdim=True)
+
+            # resample if necessary
+            if source_sample_rate != self.target_sample_rate:
+                resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
+                audio = resampler(audio)
+
+            # to mel spectrogram
+            mel_spec = self.mel_spectrogram(audio)
+            mel_spec = mel_spec.squeeze(0)  # '1 d t -> d t'
+
+        return {
+            "mel_spec": mel_spec,
+            "text": text,
+        }
+
+
+# Dynamic Batch Sampler
+class DynamicBatchSampler(Sampler[list[int]]):
+    """Extension of Sampler that will do the following:
+    1.  Change the batch size (essentially number of sequences)
+        in a batch to ensure that the total number of frames are less
+        than a certain threshold.
+    2.  Make sure the padding efficiency in the batch is high.
+    """
+
+    def __init__(
+        self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
+    ):
+        self.sampler = sampler
+        self.frames_threshold = frames_threshold
+        self.max_samples = max_samples
+
+        indices, batches = [], []
+        data_source = self.sampler.data_source
+
+        for idx in tqdm(
+            self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
+        ):
+            indices.append((idx, data_source.get_frame_len(idx)))
+        indices.sort(key=lambda elem: elem[1])
+
+        batch = []
+        batch_frames = 0
+        for idx, frame_len in tqdm(
+            indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
+        ):
+            if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
+                batch.append(idx)
+                batch_frames += frame_len
+            else:
+                if len(batch) > 0:
+                    batches.append(batch)
+                if frame_len <= self.frames_threshold:
+                    batch = [idx]
+                    batch_frames = frame_len
+                else:
+                    batch = []
+                    batch_frames = 0
+
+        if not drop_last and len(batch) > 0:
+            batches.append(batch)
+
+        del indices
+
+        # if want to have different batches between epochs, may just set a seed and log it in ckpt
+        # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different
+        # e.g. for epoch n, use (random_seed + n)
+        random.seed(random_seed)
+        random.shuffle(batches)
+
+        self.batches = batches
+
+    def __iter__(self):
+        return iter(self.batches)
+
+    def __len__(self):
+        return len(self.batches)
+
+
+# Load dataset
+
+
+def load_dataset(
+    dataset_name: str,
+    tokenizer: str = "pinyin",
+    dataset_type: str = "CustomDataset",
+    audio_type: str = "raw",
+    mel_spec_module: nn.Module | None = None,
+    mel_spec_kwargs: dict = dict(),
+) -> CustomDataset | HFDataset:
+    """
+    dataset_type    - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
+                    - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
+    """
+
+    print("Loading dataset ...")
+
+    if dataset_type == "CustomDataset":
+        rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}"))
+        if audio_type == "raw":
+            try:
+                train_dataset = load_from_disk(f"{rel_data_path}/raw")
+            except:  # noqa: E722
+                train_dataset = Dataset_.from_file(f"{rel_data_path}/raw.arrow")
+            preprocessed_mel = False
+        elif audio_type == "mel":
+            train_dataset = Dataset_.from_file(f"{rel_data_path}/mel.arrow")
+            preprocessed_mel = True
+        with open(f"{rel_data_path}/duration.json", "r", encoding="utf-8") as f:
+            data_dict = json.load(f)
+        durations = data_dict["duration"]
+        train_dataset = CustomDataset(
+            train_dataset,
+            durations=durations,
+            preprocessed_mel=preprocessed_mel,
+            mel_spec_module=mel_spec_module,
+            **mel_spec_kwargs,
+        )
+
+    elif dataset_type == "CustomDatasetPath":
+        try:
+            train_dataset = load_from_disk(f"{dataset_name}/raw")
+        except:  # noqa: E722
+            train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
+
+        with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f:
+            data_dict = json.load(f)
+        durations = data_dict["duration"]
+        train_dataset = CustomDataset(
+            train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
+        )
+
+    elif dataset_type == "HFDataset":
+        print(
+            "Should manually modify the path of huggingface dataset to your need.\n"
+            + "May also the corresponding script cuz different dataset may have different format."
+        )
+        pre, post = dataset_name.split("_")
+        train_dataset = HFDataset(
+            load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))),
+        )
+
+    return train_dataset
+
+
+# collation
+
+
+def collate_fn(batch):
+    mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
+    mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
+    max_mel_length = mel_lengths.amax()
+
+    padded_mel_specs = []
+    for spec in mel_specs:  # TODO. maybe records mask for attention here
+        padding = (0, max_mel_length - spec.size(-1))
+        padded_spec = F.pad(spec, padding, value=0)
+        padded_mel_specs.append(padded_spec)
+
+    mel_specs = torch.stack(padded_mel_specs)
+
+    text = [item["text"] for item in batch]
+    text_lengths = torch.LongTensor([len(item) for item in text])
+
+    return dict(
+        mel=mel_specs,
+        mel_lengths=mel_lengths,
+        text=text,
+        text_lengths=text_lengths,
+    )