xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +28 -6
- xinference/core/utils.py +10 -6
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/core.py +10 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +200 -0
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +36 -111
- xinference/model/audio/model_spec.json +27 -3
- xinference/model/audio/model_spec_modelscope.json +18 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/embedding/core.py +203 -142
- xinference/model/embedding/model_spec.json +7 -0
- xinference/model/embedding/model_spec_modelscope.json +8 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +127 -4
- xinference/model/image/model_spec_modelscope.json +130 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/__init__.py +2 -2
- xinference/model/llm/llm_family.json +219 -53
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +167 -20
- xinference/model/llm/mlx/core.py +287 -51
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/qwen2_vl.py +2 -0
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +5 -1
- xinference/model/llm/vllm/core.py +16 -2
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/f5_tts/model/trainer.py (new file)
@@ -0,0 +1,366 @@
+from __future__ import annotations
+
+import gc
+import os
+
+import torch
+import torchaudio
+import wandb
+from accelerate import Accelerator
+from accelerate.utils import DistributedDataParallelKwargs
+from ema_pytorch import EMA
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import LinearLR, SequentialLR
+from torch.utils.data import DataLoader, Dataset, SequentialSampler
+from tqdm import tqdm
+
+from f5_tts.model import CFM
+from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
+from f5_tts.model.utils import default, exists
+
+# trainer
+
+
+class Trainer:
+    def __init__(
+        self,
+        model: CFM,
+        epochs,
+        learning_rate,
+        num_warmup_updates=20000,
+        save_per_updates=1000,
+        checkpoint_path=None,
+        batch_size=32,
+        batch_size_type: str = "sample",
+        max_samples=32,
+        grad_accumulation_steps=1,
+        max_grad_norm=1.0,
+        noise_scheduler: str | None = None,
+        duration_predictor: torch.nn.Module | None = None,
+        logger: str | None = "wandb",  # "wandb" | "tensorboard" | None
+        wandb_project="test_e2-tts",
+        wandb_run_name="test_run",
+        wandb_resume_id: str = None,
+        log_samples: bool = False,
+        last_per_steps=None,
+        accelerate_kwargs: dict = dict(),
+        ema_kwargs: dict = dict(),
+        bnb_optimizer: bool = False,
+        mel_spec_type: str = "vocos",  # "vocos" | "bigvgan"
+        is_local_vocoder: bool = False,  # use local path vocoder
+        local_vocoder_path: str = "",  # local vocoder path
+    ):
+        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+
+        if logger == "wandb" and not wandb.api.api_key:
+            logger = None
+        print(f"Using logger: {logger}")
+        self.log_samples = log_samples
+
+        self.accelerator = Accelerator(
+            log_with=logger if logger == "wandb" else None,
+            kwargs_handlers=[ddp_kwargs],
+            gradient_accumulation_steps=grad_accumulation_steps,
+            **accelerate_kwargs,
+        )
+
+        self.logger = logger
+        if self.logger == "wandb":
+            if exists(wandb_resume_id):
+                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
+            else:
+                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
+
+            self.accelerator.init_trackers(
+                project_name=wandb_project,
+                init_kwargs=init_kwargs,
+                config={
+                    "epochs": epochs,
+                    "learning_rate": learning_rate,
+                    "num_warmup_updates": num_warmup_updates,
+                    "batch_size": batch_size,
+                    "batch_size_type": batch_size_type,
+                    "max_samples": max_samples,
+                    "grad_accumulation_steps": grad_accumulation_steps,
+                    "max_grad_norm": max_grad_norm,
+                    "gpus": self.accelerator.num_processes,
+                    "noise_scheduler": noise_scheduler,
+                },
+            )
+
+        elif self.logger == "tensorboard":
+            from torch.utils.tensorboard import SummaryWriter
+
+            self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")
+
+        self.model = model
+
+        if self.is_main:
+            self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
+            self.ema_model.to(self.accelerator.device)
+
+        self.epochs = epochs
+        self.num_warmup_updates = num_warmup_updates
+        self.save_per_updates = save_per_updates
+        self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
+        self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
+
+        self.batch_size = batch_size
+        self.batch_size_type = batch_size_type
+        self.max_samples = max_samples
+        self.grad_accumulation_steps = grad_accumulation_steps
+        self.max_grad_norm = max_grad_norm
+
+        # mel vocoder config
+        self.vocoder_name = mel_spec_type
+        self.is_local_vocoder = is_local_vocoder
+        self.local_vocoder_path = local_vocoder_path
+
+        self.noise_scheduler = noise_scheduler
+
+        self.duration_predictor = duration_predictor
+
+        if bnb_optimizer:
+            import bitsandbytes as bnb
+
+            self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
+        else:
+            self.optimizer = AdamW(model.parameters(), lr=learning_rate)
+        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+
+    @property
+    def is_main(self):
+        return self.accelerator.is_main_process
+
+    def save_checkpoint(self, step, last=False):
+        self.accelerator.wait_for_everyone()
+        if self.is_main:
+            checkpoint = dict(
+                model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
+                optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
+                ema_model_state_dict=self.ema_model.state_dict(),
+                scheduler_state_dict=self.scheduler.state_dict(),
+                step=step,
+            )
+            if not os.path.exists(self.checkpoint_path):
+                os.makedirs(self.checkpoint_path)
+            if last:
+                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
+                print(f"Saved last checkpoint at step {step}")
+            else:
+                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
+
+    def load_checkpoint(self):
+        if (
+            not exists(self.checkpoint_path)
+            or not os.path.exists(self.checkpoint_path)
+            or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
+        ):
+            return 0
+
+        self.accelerator.wait_for_everyone()
+        if "model_last.pt" in os.listdir(self.checkpoint_path):
+            latest_checkpoint = "model_last.pt"
+        else:
+            latest_checkpoint = sorted(
+                [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
+                key=lambda x: int("".join(filter(str.isdigit, x))),
+            )[-1]
+        # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
+        checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
+
+        # patch for backward compatibility, 305e3ea
+        for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
+            if key in checkpoint["ema_model_state_dict"]:
+                del checkpoint["ema_model_state_dict"][key]
+
+        if self.is_main:
+            self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
+
+        if "step" in checkpoint:
+            # patch for backward compatibility, 305e3ea
+            for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
+                if key in checkpoint["model_state_dict"]:
+                    del checkpoint["model_state_dict"][key]
+
+            self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
+            self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
+            if self.scheduler:
+                self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+            step = checkpoint["step"]
+        else:
+            checkpoint["model_state_dict"] = {
+                k.replace("ema_model.", ""): v
+                for k, v in checkpoint["ema_model_state_dict"].items()
+                if k not in ["initted", "step"]
+            }
+            self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
+            step = 0
+
+        del checkpoint
+        gc.collect()
+        return step
+
+    def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
+        if self.log_samples:
+            from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
+
+            vocoder = load_vocoder(
+                vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path
+            )
+            target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
+            log_samples_path = f"{self.checkpoint_path}/samples"
+            os.makedirs(log_samples_path, exist_ok=True)
+
+        if exists(resumable_with_seed):
+            generator = torch.Generator()
+            generator.manual_seed(resumable_with_seed)
+        else:
+            generator = None
+
+        if self.batch_size_type == "sample":
+            train_dataloader = DataLoader(
+                train_dataset,
+                collate_fn=collate_fn,
+                num_workers=num_workers,
+                pin_memory=True,
+                persistent_workers=True,
+                batch_size=self.batch_size,
+                shuffle=True,
+                generator=generator,
+            )
+        elif self.batch_size_type == "frame":
+            self.accelerator.even_batches = False
+            sampler = SequentialSampler(train_dataset)
+            batch_sampler = DynamicBatchSampler(
+                sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
+            )
+            train_dataloader = DataLoader(
+                train_dataset,
+                collate_fn=collate_fn,
+                num_workers=num_workers,
+                pin_memory=True,
+                persistent_workers=True,
+                batch_sampler=batch_sampler,
+            )
+        else:
+            raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
+
+        # accelerator.prepare() dispatches batches to devices;
+        # which means the length of dataloader calculated before, should consider the number of devices
+        warmup_steps = (
+            self.num_warmup_updates * self.accelerator.num_processes
+        )  # consider a fixed warmup steps while using accelerate multi-gpu ddp
+        # otherwise by default with split_batches=False, warmup steps change with num_processes
+        total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
+        decay_steps = total_steps - warmup_steps
+        warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
+        decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
+        self.scheduler = SequentialLR(
+            self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
+        )
+        train_dataloader, self.scheduler = self.accelerator.prepare(
+            train_dataloader, self.scheduler
+        )  # actual steps = 1 gpu steps / gpus
+        start_step = self.load_checkpoint()
+        global_step = start_step
+
+        if exists(resumable_with_seed):
+            orig_epoch_step = len(train_dataloader)
+            skipped_epoch = int(start_step // orig_epoch_step)
+            skipped_batch = start_step % orig_epoch_step
+            skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
+        else:
+            skipped_epoch = 0
+
+        for epoch in range(skipped_epoch, self.epochs):
+            self.model.train()
+            if exists(resumable_with_seed) and epoch == skipped_epoch:
+                progress_bar = tqdm(
+                    skipped_dataloader,
+                    desc=f"Epoch {epoch+1}/{self.epochs}",
+                    unit="step",
+                    disable=not self.accelerator.is_local_main_process,
+                    initial=skipped_batch,
+                    total=orig_epoch_step,
+                )
+            else:
+                progress_bar = tqdm(
+                    train_dataloader,
+                    desc=f"Epoch {epoch+1}/{self.epochs}",
+                    unit="step",
+                    disable=not self.accelerator.is_local_main_process,
+                )
+
+            for batch in progress_bar:
+                with self.accelerator.accumulate(self.model):
+                    text_inputs = batch["text"]
+                    mel_spec = batch["mel"].permute(0, 2, 1)
+                    mel_lengths = batch["mel_lengths"]
+
+                    # TODO. add duration predictor training
+                    if self.duration_predictor is not None and self.accelerator.is_local_main_process:
+                        dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
+                        self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
+
+                    loss, cond, pred = self.model(
+                        mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
+                    )
+                    self.accelerator.backward(loss)
+
+                    if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
+                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+
+                    self.optimizer.step()
+                    self.scheduler.step()
+                    self.optimizer.zero_grad()
+
+                if self.is_main:
+                    self.ema_model.update()
+
+                global_step += 1
+
+                if self.accelerator.is_local_main_process:
+                    self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
+                    if self.logger == "tensorboard":
+                        self.writer.add_scalar("loss", loss.item(), global_step)
+                        self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)
+
+                progress_bar.set_postfix(step=str(global_step), loss=loss.item())
+
+                if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
+                    self.save_checkpoint(global_step)
+
+                    if self.log_samples and self.accelerator.is_local_main_process:
+                        ref_audio_len = mel_lengths[0]
+                        infer_text = [
+                            text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
+                        ]
+                        with torch.inference_mode():
+                            generated, _ = self.accelerator.unwrap_model(self.model).sample(
+                                cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
+                                text=infer_text,
+                                duration=ref_audio_len * 2,
+                                steps=nfe_step,
+                                cfg_strength=cfg_strength,
+                                sway_sampling_coef=sway_sampling_coef,
+                            )
+                            generated = generated.to(torch.float32)
+                            gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
+                            ref_mel_spec = batch["mel"][0].unsqueeze(0)
+                            if self.vocoder_name == "vocos":
+                                gen_audio = vocoder.decode(gen_mel_spec).cpu()
+                                ref_audio = vocoder.decode(ref_mel_spec).cpu()
+                            elif self.vocoder_name == "bigvgan":
+                                gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
+                                ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
+
+                        torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
+                        torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
+
+                if global_step % self.last_per_steps == 0:
+                    self.save_checkpoint(global_step, last=True)
+
+        self.save_checkpoint(global_step, last=True)
+
+        self.accelerator.end_training()
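This trainer drives F5-TTS fine-tuning end to end: it wraps the model with Hugging Face Accelerate, keeps an EMA copy on the main process, uses a linear warmup/decay learning-rate schedule, and writes `model_<step>.pt` / `model_last.pt` checkpoints. A hedged sketch of how it is typically driven follows; it is not part of the diff, the hyperparameter values and paths are illustrative, and the dataset is left as a placeholder.

```python
# Hedged sketch (not part of the diff): driving the Trainer above.
# Hyperparameters are illustrative; `train_dataset` is a placeholder for a Dataset
# prepared with f5_tts.model.dataset so that collate_fn yields "text"/"mel"/"mel_lengths".
from f5_tts.model import CFM, DiT
from f5_tts.model.trainer import Trainer

transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model = CFM(transformer=transformer)  # same construction as in scripts/count_params_gflops.py below

trainer = Trainer(
    model,
    epochs=10,
    learning_rate=7.5e-5,            # illustrative value
    num_warmup_updates=20000,
    batch_size=38400,                # max frames per update when batch_size_type="frame"
    batch_size_type="frame",
    grad_accumulation_steps=1,
    logger=None,                     # skip wandb/tensorboard for a dry run
    checkpoint_path="ckpts/my_run",  # hypothetical output directory
)

train_dataset = ...  # placeholder: build with the f5_tts.model.dataset utilities
trainer.train(train_dataset, num_workers=16)
```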
xinference/thirdparty/f5_tts/model/utils.py (new file)
@@ -0,0 +1,185 @@
+from __future__ import annotations
+
+import os
+import random
+from collections import defaultdict
+from importlib.resources import files
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+import jieba
+from pypinyin import lazy_pinyin, Style
+
+
+# seed everything
+
+
+def seed_everything(seed=0):
+    random.seed(seed)
+    os.environ["PYTHONHASHSEED"] = str(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+# helpers
+
+
+def exists(v):
+    return v is not None
+
+
+def default(v, d):
+    return v if exists(v) else d
+
+
+# tensor helpers
+
+
+def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]:  # noqa: F722 F821
+    if not exists(length):
+        length = t.amax()
+
+    seq = torch.arange(length, device=t.device)
+    return seq[None, :] < t[:, None]
+
+
+def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]):  # noqa: F722 F821
+    max_seq_len = seq_len.max().item()
+    seq = torch.arange(max_seq_len, device=start.device).long()
+    start_mask = seq[None, :] >= start[:, None]
+    end_mask = seq[None, :] < end[:, None]
+    return start_mask & end_mask
+
+
+def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):  # noqa: F722 F821
+    lengths = (frac_lengths * seq_len).long()
+    max_start = seq_len - lengths
+
+    rand = torch.rand_like(frac_lengths)
+    start = (max_start * rand).long().clamp(min=0)
+    end = start + lengths
+
+    return mask_from_start_end_indices(seq_len, start, end)
+
+
+def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]:  # noqa: F722
+    if not exists(mask):
+        return t.mean(dim=1)
+
+    t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
+    num = t.sum(dim=1)
+    den = mask.float().sum(dim=1)
+
+    return num / den.clamp(min=1.0)
+
+
+# simple utf-8 tokenizer, since paper went character based
+def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]:  # noqa: F722
+    list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text]  # ByT5 style
+    text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
+    return text
+
+
+# char tokenizer, based on custom dataset's extracted .txt file
+def list_str_to_idx(
+    text: list[str] | list[list[str]],
+    vocab_char_map: dict[str, int],  # {char: idx}
+    padding_value=-1,
+) -> int["b nt"]:  # noqa: F722
+    list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
+    text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
+    return text
+
+
+# Get tokenizer
+
+
+def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
+    """
+    tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
+                - "char" for char-wise tokenizer, need .txt vocab_file
+                - "byte" for utf-8 tokenizer
+                - "custom" if you're directly passing in a path to the vocab.txt you want to use
+    vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
+                - if use "char", derived from unfiltered character & symbol counts of custom dataset
+                - if use "byte", set to 256 (unicode byte range)
+    """
+    if tokenizer in ["pinyin", "char"]:
+        tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
+        with open(tokenizer_path, "r", encoding="utf-8") as f:
+            vocab_char_map = {}
+            for i, char in enumerate(f):
+                vocab_char_map[char[:-1]] = i
+        vocab_size = len(vocab_char_map)
+        assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
+
+    elif tokenizer == "byte":
+        vocab_char_map = None
+        vocab_size = 256
+
+    elif tokenizer == "custom":
+        with open(dataset_name, "r", encoding="utf-8") as f:
+            vocab_char_map = {}
+            for i, char in enumerate(f):
+                vocab_char_map[char[:-1]] = i
+        vocab_size = len(vocab_char_map)
+
+    return vocab_char_map, vocab_size
+
+
+# convert char to pinyin
+
+
+def convert_char_to_pinyin(text_list, polyphone=True):
+    final_text_list = []
+    god_knows_why_en_testset_contains_zh_quote = str.maketrans(
+        {"“": '"', "”": '"', "‘": "'", "’": "'"}
+    )  # in case librispeech (orig no-pc) test-clean
+    custom_trans = str.maketrans({";": ","})  # add custom trans here, to address oov
+    for text in text_list:
+        char_list = []
+        text = text.translate(god_knows_why_en_testset_contains_zh_quote)
+        text = text.translate(custom_trans)
+        for seg in jieba.cut(text):
+            seg_byte_len = len(bytes(seg, "UTF-8"))
+            if seg_byte_len == len(seg):  # if pure alphabets and symbols
+                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
+                    char_list.append(" ")
+                char_list.extend(seg)
+            elif polyphone and seg_byte_len == 3 * len(seg):  # if pure chinese characters
+                seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
+                for c in seg:
+                    if c not in "。,、;:?!《》【】—…":
+                        char_list.append(" ")
+                    char_list.append(c)
+            else:  # if mixed chinese characters, alphabets and symbols
+                for c in seg:
+                    if ord(c) < 256:
+                        char_list.extend(c)
+                    else:
+                        if c not in "。,、;:?!《》【】—…":
+                            char_list.append(" ")
+                            char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
+                        else:  # if is zh punc
+                            char_list.append(c)
+        final_text_list.append(char_list)
+
+    return final_text_list
+
+
+# filter func for dirty data with many repetitions
+
+
+def repetition_found(text, length=2, tolerance=10):
+    pattern_count = defaultdict(int)
+    for i in range(len(text) - length + 1):
+        pattern = text[i : i + length]
+        pattern_count[pattern] += 1
+    for pattern, count in pattern_count.items():
+        if count > tolerance:
+            return True
+    return False
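These utilities are the text and masking plumbing for F5-TTS: padding masks, a byte-level and a char/pinyin tokenizer, jieba plus pypinyin conversion of Chinese text, and a repetition filter for noisy transcripts. A small illustration of the pure functions follows (not in the diff; it assumes the vendored package is importable as `f5_tts`, and the pinyin segmentation can vary slightly across jieba/pypinyin versions):

```python
import torch

from f5_tts.model.utils import convert_char_to_pinyin, lens_to_mask, list_str_to_tensor

# lens_to_mask: per-sample lengths -> boolean padding mask
print(lens_to_mask(torch.tensor([2, 4])))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])

# list_str_to_tensor: the "byte" tokenizer (UTF-8 bytes, ByT5 style), padded with -1
print(list_str_to_tensor(["ab", "abcd"]))
# tensor([[97, 98, -1, -1],
#         [97, 98, 99, 100]])

# convert_char_to_pinyin: ASCII stays char-by-char, Chinese becomes tone3 pinyin tokens
print(convert_char_to_pinyin(["你好 world"]))
# e.g. [[' ', 'ni3', ' ', 'hao3', ' ', 'w', 'o', 'r', 'l', 'd']]
```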
xinference/thirdparty/f5_tts/scripts/count_max_epoch.py (new file)
@@ -0,0 +1,33 @@
+"""ADAPTIVE BATCH SIZE"""
+
+print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
+print("  -> least padding, gather wavs with accumulated frames in a batch\n")
+
+# data
+total_hours = 95282
+mel_hop_length = 256
+mel_sampling_rate = 24000
+
+# target
+wanted_max_updates = 1000000
+
+# train params
+gpus = 8
+frames_per_gpu = 38400  # 8 * 38400 = 307200
+grad_accum = 1
+
+# intermediate
+mini_batch_frames = frames_per_gpu * grad_accum * gpus
+mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
+updates_per_epoch = total_hours / mini_batch_hours
+steps_per_epoch = updates_per_epoch * grad_accum
+
+# result
+epochs = wanted_max_updates / updates_per_epoch
+print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
+print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
+print(f"  or approx. 0/{steps_per_epoch:.0f} steps")
+
+# others
+print(f"total {total_hours:.0f} hours")
+print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")
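Worked through with the defaults above (the diff does not show the script's output, but it follows directly from the constants): mini_batch_frames = 38400 × 1 × 8 = 307,200 frames per update, which is 307,200 × 256 / 24,000 / 3,600 ≈ 0.91 hours of audio; one epoch over 95,282 hours is therefore ≈ 104,680 updates, and reaching the 1,000,000-update target takes ≈ 9.6 epochs, so the script reports an epoch setting of about 10.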
xinference/thirdparty/f5_tts/scripts/count_params_gflops.py (new file)
@@ -0,0 +1,39 @@
+import sys
+import os
+
+sys.path.append(os.getcwd())
+
+from f5_tts.model import CFM, DiT
+
+import torch
+import thop
+
+
+""" ~155M """
+# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
+# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
+# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
+# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
+# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
+# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
+
+""" ~335M """
+# FLOPs: 622.1 G, Params: 333.2 M
+# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
+# FLOPs: 363.4 G, Params: 335.8 M
+transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+
+
+model = CFM(transformer=transformer)
+target_sample_rate = 24000
+n_mel_channels = 100
+hop_length = 256
+duration = 20
+frame_length = int(duration * target_sample_rate / hop_length)
+text_length = 150
+
+flops, params = thop.profile(
+    model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
+)
+print(f"FLOPs: {flops / 1e9} G")
+print(f"Params: {params / 1e6} M")
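The profiled configuration corresponds to a single forward pass over a 20-second clip and 150 text tokens. A quick check of the input sizes implied by the constants above (not part of the diff):

```python
# Sanity check of the input sizes used by scripts/count_params_gflops.py above.
target_sample_rate = 24000   # Hz
hop_length = 256             # samples per mel frame
duration = 20                # seconds profiled
frame_length = int(duration * target_sample_rate / hop_length)
assert frame_length == 1875  # mel frames in the (1, frame_length, 100) input tensor
print(frame_length)
```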