xinference-0.14.2-py3-none-any.whl → xinference-0.14.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic; see the registry's page for this release for more details.

Files changed (191)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +60 -44
  6. xinference/model/audio/chattts.py +25 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/cosyvoice.py +4 -3
  9. xinference/model/audio/custom.py +4 -5
  10. xinference/model/audio/fish_speech.py +228 -0
  11. xinference/model/audio/model_spec.json +8 -0
  12. xinference/model/embedding/core.py +25 -1
  13. xinference/model/embedding/custom.py +4 -5
  14. xinference/model/flexible/core.py +5 -1
  15. xinference/model/image/custom.py +4 -5
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +66 -3
  19. xinference/model/llm/__init__.py +6 -0
  20. xinference/model/llm/llm_family.json +54 -9
  21. xinference/model/llm/llm_family.py +7 -6
  22. xinference/model/llm/llm_family_modelscope.json +56 -10
  23. xinference/model/llm/lmdeploy/__init__.py +0 -0
  24. xinference/model/llm/lmdeploy/core.py +557 -0
  25. xinference/model/llm/sglang/core.py +7 -1
  26. xinference/model/llm/transformers/cogvlm2.py +4 -45
  27. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  28. xinference/model/llm/transformers/core.py +3 -0
  29. xinference/model/llm/transformers/glm4v.py +2 -23
  30. xinference/model/llm/transformers/intern_vl.py +94 -11
  31. xinference/model/llm/transformers/minicpmv25.py +2 -23
  32. xinference/model/llm/transformers/minicpmv26.py +2 -22
  33. xinference/model/llm/transformers/yi_vl.py +2 -24
  34. xinference/model/llm/utils.py +13 -1
  35. xinference/model/llm/vllm/core.py +1 -34
  36. xinference/model/rerank/custom.py +4 -5
  37. xinference/model/utils.py +41 -1
  38. xinference/model/video/core.py +3 -1
  39. xinference/model/video/diffusers.py +41 -38
  40. xinference/model/video/model_spec.json +24 -1
  41. xinference/model/video/model_spec_modelscope.json +25 -1
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/matcha/__init__.py +0 -0
  137. xinference/thirdparty/matcha/app.py +357 -0
  138. xinference/thirdparty/matcha/cli.py +419 -0
  139. xinference/thirdparty/matcha/data/__init__.py +0 -0
  140. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  141. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  142. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  143. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  144. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  145. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  146. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  147. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  148. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  149. xinference/thirdparty/matcha/models/__init__.py +0 -0
  150. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  151. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  152. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  153. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  154. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  155. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  156. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  157. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  158. xinference/thirdparty/matcha/onnx/export.py +181 -0
  159. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  160. xinference/thirdparty/matcha/text/__init__.py +53 -0
  161. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  162. xinference/thirdparty/matcha/text/numbers.py +71 -0
  163. xinference/thirdparty/matcha/text/symbols.py +17 -0
  164. xinference/thirdparty/matcha/train.py +122 -0
  165. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  166. xinference/thirdparty/matcha/utils/audio.py +82 -0
  167. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  168. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  169. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  170. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  171. xinference/thirdparty/matcha/utils/model.py +90 -0
  172. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  173. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  174. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  175. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  176. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  177. xinference/thirdparty/matcha/utils/utils.py +259 -0
  178. xinference/web/ui/build/asset-manifest.json +3 -3
  179. xinference/web/ui/build/index.html +1 -1
  180. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  181. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  182. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  183. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
  184. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
  185. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  186. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  187. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  188. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  189. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  190. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  191. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,419 @@
1
+ import argparse
2
+ import datetime as dt
3
+ import os
4
+ import warnings
5
+ from pathlib import Path
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import torch
11
+
12
+ from matcha.hifigan.config import v1
13
+ from matcha.hifigan.denoiser import Denoiser
14
+ from matcha.hifigan.env import AttrDict
15
+ from matcha.hifigan.models import Generator as HiFiGAN
16
+ from matcha.models.matcha_tts import MatchaTTS
17
+ from matcha.text import sequence_to_text, text_to_sequence
18
+ from matcha.utils.utils import assert_model_downloaded, get_user_data_dir, intersperse
19
+
20
# Pretrained Matcha-TTS acoustic-model checkpoints, keyed by the --model choice.
MATCHA_URLS = dict(
    matcha_ljspeech="https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_ljspeech.ckpt",
    matcha_vctk="https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_vctk.ckpt",
)

# Vocoder checkpoints, keyed by the --vocoder choice.
VOCODER_URLS = dict(
    # Old url: https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link
    hifigan_T2_v1="https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1",
    # Old url: https://drive.google.com/file/d/1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW/view?usp=drive_link
    hifigan_univ_v1="https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/g_02500000",
)

# Defaults for pretrained multi-speaker checkpoints. "spk" is the speaker ID
# used when --spk is omitted; "spk_range" holds inclusive (low, high) bounds.
MULTISPEAKER_MODEL = dict(
    matcha_vctk=dict(vocoder="hifigan_univ_v1", speaking_rate=0.85, spk=0, spk_range=(0, 107)),
)

# Defaults for pretrained single-speaker checkpoints (no speaker ID applies).
SINGLESPEAKER_MODEL = dict(
    matcha_ljspeech=dict(vocoder="hifigan_T2_v1", speaking_rate=0.95, spk=None),
)
35
+
36
+
37
def plot_spectrogram_to_numpy(spectrogram, filename):
    """Render a mel-spectrogram as an image and save it to ``filename``.

    Despite the name, nothing is returned: the plot is only written to disk.

    Args:
        spectrogram: 2-D array-like passed to ``imshow`` (drawn with
            ``origin="lower"``, so row 0 is at the bottom).
        filename: Destination accepted by ``Figure.savefig`` (str or path-like).
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
        fig.colorbar(im, ax=ax)
        ax.set_xlabel("Frames")
        ax.set_ylabel("Channels")
        ax.set_title("Synthesised Mel-Spectrogram")
        fig.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # synthesising many utterances accumulates one open figure per call.
        plt.close(fig)
46
+
47
+
48
def process_text(i: int, text: str, device: torch.device):
    """Turn raw text into the token tensors fed to ``MatchaTTS.synthesise``.

    Args:
        i: Index used only as a prefix in the progress messages.
        text: Input sentence, cleaned with ``english_cleaners2``.
        device: Device on which the returned tensors are allocated.

    Returns:
        dict with ``x_orig`` (the input text), ``x`` (a ``(1, T)`` long tensor of
        symbol IDs with 0 interspersed), ``x_lengths`` (a 1-element long tensor
        holding ``T``) and ``x_phones`` (``x`` decoded back to symbols).
    """
    print(f"[{i}] - Input text: {text}")
    symbol_ids = text_to_sequence(text, ["english_cleaners2"])[0]
    padded_ids = intersperse(symbol_ids, 0)
    x = torch.tensor(padded_ids, dtype=torch.long, device=device).unsqueeze(0)
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
    x_phones = sequence_to_text(x.squeeze(0).tolist())
    # `[1::2]` skips the interspersed 0 tokens (assuming `intersperse` puts the
    # filler at even positions — TODO confirm against matcha.utils.utils).
    print(f"[{i}] - Phonetised text: {x_phones[1::2]}")
    return dict(x_orig=text, x=x, x_lengths=x_lengths, x_phones=x_phones)
60
+
61
+
62
def get_texts(args):
    """Return the list of texts to synthesise.

    Uses ``[args.text]`` when ``args.text`` is non-empty; otherwise reads every
    line (newlines included) from the UTF-8 file ``args.file``.
    """
    if not args.text:
        with open(args.file, encoding="utf-8") as handle:
            return handle.readlines()
    return [args.text]
69
+
70
+
71
def assert_required_models_available(args):
    """Make sure the acoustic-model and vocoder checkpoints are available locally.

    Pretrained checkpoints live under ``get_user_data_dir()`` and are fetched
    with ``assert_model_downloaded`` from ``MATCHA_URLS`` / ``VOCODER_URLS``.
    When ``args.checkpoint_path`` is set, that file is used as the acoustic
    model and no pretrained Matcha checkpoint is downloaded.

    Returns:
        dict with keys ``"matcha"`` and ``"vocoder"`` mapping to checkpoint paths.
    """
    save_dir = get_user_data_dir()

    custom_checkpoint = getattr(args, "checkpoint_path", None)
    if custom_checkpoint is not None:
        # A user-supplied checkpoint replaces the pretrained model, so fetching
        # the pretrained one would be wasted work.
        model_path = custom_checkpoint
    else:
        model_path = save_dir / f"{args.model}.ckpt"
        assert_model_downloaded(model_path, MATCHA_URLS[args.model])

    vocoder_path = save_dir / f"{args.vocoder}"
    assert_model_downloaded(vocoder_path, VOCODER_URLS[args.vocoder])
    return {"matcha": model_path, "vocoder": vocoder_path}
82
+
83
+
84
def load_hifigan(checkpoint_path, device):
    """Build a HiFi-GAN generator from the ``v1`` config and load its weights.

    The weights come from the ``"generator"`` entry of the checkpoint at
    ``checkpoint_path``. The model is put in eval mode, and
    ``remove_weight_norm()`` is called before it is returned.
    """
    generator = HiFiGAN(AttrDict(v1)).to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(state["generator"])
    generator.eval()
    generator.remove_weight_norm()
    return generator
91
+
92
+
93
def load_vocoder(vocoder_name, checkpoint_path, device):
    """Load the named vocoder and wrap it in a ``Denoiser(mode="zeros")``.

    Only the HiFi-GAN variants ``hifigan_T2_v1`` and ``hifigan_univ_v1`` are
    supported.

    Returns:
        ``(vocoder, denoiser)`` tuple.

    Raises:
        NotImplementedError: for any other ``vocoder_name``.
    """
    print(f"[!] Loading {vocoder_name}!")
    supported = ("hifigan_T2_v1", "hifigan_univ_v1")
    if vocoder_name not in supported:
        raise NotImplementedError(
            f"Vocoder {vocoder_name} not implemented! define a load_<<vocoder_name>> method for it"
        )

    vocoder = load_hifigan(checkpoint_path, device)
    denoiser = Denoiser(vocoder, mode="zeros")
    print(f"[+] {vocoder_name} loaded!")
    return vocoder, denoiser
106
+
107
+
108
def load_matcha(model_name, checkpoint_path, device):
    """Restore a ``MatchaTTS`` model from ``checkpoint_path`` in eval mode.

    ``model_name`` is only used in the progress messages.
    """
    print(f"[!] Loading {model_name}!")
    matcha = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)
    matcha.eval()
    print(f"[+] {model_name} loaded!")
    return matcha
115
+
116
+
117
def to_waveform(mel, vocoder, denoiser=None, strength=0.00025):
    """Vocode a mel-spectrogram and return the audio on the CPU.

    Args:
        mel: Mel-spectrogram tensor accepted by ``vocoder``.
        vocoder: Callable mapping ``mel`` to audio; its output is clamped to
            ``[-1, 1]``.
        denoiser: Optional callable invoked as
            ``denoiser(audio.squeeze(), strength=strength)``.
        strength: Denoiser strength. The default, 0.00025, is the value that
            was previously hard-coded (and matches the ``--denoiser_strength``
            default), so existing callers are unaffected; pass
            ``args.denoiser_strength`` to honour the CLI option.

    Returns:
        The audio tensor moved to CPU with singleton dimensions squeezed.
    """
    audio = vocoder(mel).clamp(-1, 1)
    if denoiser is not None:
        audio = denoiser(audio.squeeze(), strength=strength)
    return audio.cpu().squeeze()
123
+
124
+
125
def save_to_folder(filename: str, output: dict, folder: str):
    """Write one synthesised utterance into ``folder`` (created if missing).

    Writes three files:

    - ``<filename>.png``: a plot of ``output["mel"]``.
    - ``<filename>.npy``: the raw mel array (``np.save`` appends ``.npy``).
    - ``<filename>.wav``: ``output["waveform"]`` written at 22050 Hz as 24-bit PCM.

    Returns:
        Absolute path of the written ``.wav`` file.
    """
    folder = Path(folder)
    folder.mkdir(exist_ok=True, parents=True)
    mel = output["mel"]
    # Keep the plot next to the other artefacts; previously it was written
    # relative to the current working directory, ignoring ``folder``.
    plot_spectrogram_to_numpy(np.array(mel.squeeze().float().cpu()), folder / f"{filename}.png")
    np.save(folder / f"{filename}", mel.cpu().numpy())
    sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24")
    return folder.resolve() / f"{filename}.wav"
132
+
133
+
134
def validate_args(args):
    """Check the parsed CLI arguments and fill in model-dependent defaults.

    For pretrained models (no ``--checkpoint_path``), the per-model validators
    set the vocoder, speaking rate and speaker ID. For a custom checkpoint, the
    vocoder defaults to ``hifigan_univ_v1`` and the speaking rate to 1.0.

    Returns:
        The (mutated) ``args`` namespace.

    Raises:
        AssertionError: on invalid values. These checks use ``assert`` and are
            therefore skipped under ``python -O``.
    """
    assert (
        args.text or args.file
    ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms."
    assert args.temperature >= 0, "Sampling temperature cannot be negative"
    assert args.steps > 0, "Number of ODE steps must be greater than 0"

    if args.checkpoint_path is None:
        # When using pretrained models
        if args.model in SINGLESPEAKER_MODEL:
            args = validate_args_for_single_speaker_model(args)

        if args.model in MULTISPEAKER_MODEL:
            args = validate_args_for_multispeaker_model(args)
    else:
        # When using a custom model
        if args.vocoder is None:
            # Leaving this as None made the later VOCODER_URLS lookup and
            # load_vocoder() fail; default to the suggested universal vocoder.
            warnings.warn(
                "[-] Using custom model checkpoint without --vocoder! Defaulting to hifigan_univ_v1.",
                UserWarning,
            )
            args.vocoder = "hifigan_univ_v1"
        elif args.vocoder != "hifigan_univ_v1":
            warn_ = "[-] Using custom model checkpoint! I would suggest passing --vocoder hifigan_univ_v1, unless the custom model is trained on LJ Speech."
            warnings.warn(warn_, UserWarning)
        if args.speaking_rate is None:
            args.speaking_rate = 1.0

    if args.batched:
        assert args.batch_size > 0, "Batch size must be greater than 0"
    assert args.speaking_rate > 0, "Speaking rate must be greater than 0"

    return args
161
+
162
+
163
def validate_args_for_multispeaker_model(args):
    """Apply ``MULTISPEAKER_MODEL`` defaults to ``args`` for a pretrained model.

    - Sets the vocoder when none was given, and warns when a different one
      than suggested was given.
    - Fills in the speaking rate.
    - Checks ``args.spk`` against the inclusive ``spk_range``, or falls back
      to the default speaker with a warning.

    Returns:
        The mutated ``args``.
    """
    spec = MULTISPEAKER_MODEL[args.model]

    if args.vocoder is None:
        args.vocoder = spec["vocoder"]
    elif args.vocoder != spec["vocoder"]:
        warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {MULTISPEAKER_MODEL[args.model]['vocoder']}"
        warnings.warn(warn_, UserWarning)

    if args.speaking_rate is None:
        args.speaking_rate = spec["speaking_rate"]

    spk_range = spec["spk_range"]
    if args.spk is None:
        available_spk_id = spec["spk"]
        warn_ = f"[!] Speaker ID not provided! Using speaker ID {available_spk_id}"
        warnings.warn(warn_, UserWarning)
        args.spk = available_spk_id
    else:
        assert spk_range[0] <= args.spk <= spk_range[-1], f"Speaker ID must be between {spk_range} for this model."

    return args
186
+
187
+
188
def validate_args_for_single_speaker_model(args):
    """Apply ``SINGLESPEAKER_MODEL`` defaults to ``args`` for a pretrained model.

    - Sets the vocoder when none was given, and warns when a different one
      than suggested was given.
    - Fills in the speaking rate.
    - Replaces any mismatching speaker ID with the model's (warning if so).

    Returns:
        The mutated ``args``.
    """
    defaults = SINGLESPEAKER_MODEL[args.model]

    if args.vocoder is None:
        args.vocoder = defaults["vocoder"]
    elif args.vocoder != defaults["vocoder"]:
        warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {SINGLESPEAKER_MODEL[args.model]['vocoder']}"
        warnings.warn(warn_, UserWarning)

    if args.speaking_rate is None:
        args.speaking_rate = defaults["speaking_rate"]

    expected_spk = defaults["spk"]
    if args.spk != expected_spk:
        warn_ = f"[-] Ignoring speaker id {args.spk} for {args.model}"
        warnings.warn(warn_, UserWarning)
        args.spk = expected_spk

    return args
205
+
206
+
207
def _build_parser():
    """Return the argparse parser used by ``cli``."""
    parser = argparse.ArgumentParser(
        description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
    )
    options = (
        ("--model", dict(type=str, default="matcha_ljspeech", help="Model to use", choices=MATCHA_URLS.keys())),
        ("--checkpoint_path", dict(type=str, default=None, help="Path to the custom model checkpoint")),
        (
            "--vocoder",
            dict(
                type=str,
                default=None,
                help="Vocoder to use (default: will use the one suggested with the pretrained model))",
                choices=VOCODER_URLS.keys(),
            ),
        ),
        ("--text", dict(type=str, default=None, help="Text to synthesize")),
        ("--file", dict(type=str, default=None, help="Text file to synthesize")),
        ("--spk", dict(type=int, default=None, help="Speaker ID")),
        ("--temperature", dict(type=float, default=0.667, help="Variance of the x0 noise (default: 0.667)")),
        (
            "--speaking_rate",
            dict(
                type=float,
                default=None,
                help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
            ),
        ),
        ("--steps", dict(type=int, default=10, help="Number of ODE steps (default: 10)")),
        ("--cpu", dict(action="store_true", help="Use CPU for inference (default: use GPU if available)")),
        (
            "--denoiser_strength",
            dict(type=float, default=0.00025, help="Strength of the vocoder bias denoiser (default: 0.00025)"),
        ),
        (
            "--output_folder",
            dict(type=str, default=os.getcwd(), help="Output folder to save results (default: current dir)"),
        ),
        ("--batched", dict(action="store_true", help="Batched inference (default: False)")),
        ("--batch_size", dict(type=int, default=32, help="Batch size only useful when --batched (default: 32)")),
    )
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)
    return parser


@torch.inference_mode()
def cli():
    """Command-line entry point: parse args, load the models, synthesise, and save.

    Batched synthesis is used only when ``--batched`` is given and there is
    more than one text; otherwise texts are synthesised one by one.
    """
    args = validate_args(_build_parser().parse_args())
    device = get_device(args)
    print_config(args)
    paths = assert_required_models_available(args)

    if args.checkpoint_path is not None:
        print(f"[🍵] Loading custom model from {args.checkpoint_path}")
        paths["matcha"] = args.checkpoint_path
        args.model = "custom_model"

    model = load_matcha(args.model, paths["matcha"], device)
    vocoder, denoiser = load_vocoder(args.vocoder, paths["vocoder"], device)

    texts = get_texts(args)

    spk = None if args.spk is None else torch.tensor([args.spk], device=device, dtype=torch.long)
    use_batches = args.batched and len(texts) != 1
    synthesise = batched_synthesis if use_batches else unbatched_synthesis
    synthesise(args, device, model, vocoder, denoiser, texts, spk)
290
+
291
+
292
class BatchedSynthesisDataset(torch.utils.data.Dataset):
    """Map-style dataset over a sequence of already-processed texts."""

    def __init__(self, processed_texts):
        # Stored as given; items are returned unchanged by __getitem__.
        self.processed_texts = processed_texts

    def __getitem__(self, idx):
        return self.processed_texts[idx]

    def __len__(self):
        return len(self.processed_texts)
301
+
302
+
303
def batched_collate_fn(batch):
    """Collate processed texts into one zero-padded batch.

    Each item's ``"x"`` (shape ``(1, T_i)``) is squeezed and right-padded with 0
    to ``(B, max T_i)``. The per-item ``"x_lengths"`` tensors are concatenated
    along dim 0.
    """
    sequences = [item["x"].squeeze(0) for item in batch]
    lengths = [item["x_lengths"] for item in batch]
    return {
        "x": torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True),
        "x_lengths": torch.concat(lengths, dim=0),
    }
314
+
315
+
316
def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
    """Synthesise ``texts`` in mini-batches of ``args.batch_size`` and save each one.

    Texts are tokenised on the CPU and padded per batch by ``batched_collate_fn``.
    Every utterance is saved under a name built from a running index across all
    batches, so later batches no longer overwrite earlier outputs. Per-batch and
    average real-time factors (RTF) are printed.
    """
    total_rtf = []
    total_rtf_w = []
    # Strip surrounding whitespace (e.g. the newline kept by ``readlines``) so
    # the input matches what ``unbatched_synthesis`` feeds the model.
    processed_text = [process_text(i, text.strip(), "cpu") for i, text in enumerate(texts)]
    dataloader = torch.utils.data.DataLoader(
        BatchedSynthesisDataset(processed_text),
        batch_size=args.batch_size,
        collate_fn=batched_collate_fn,
        num_workers=8,
    )
    utterance_idx = 0  # global position across batches -> unique file names
    for i, batch in enumerate(dataloader, start=1):
        start_t = dt.datetime.now()
        b = batch["x"].shape[0]
        output = model.synthesise(
            batch["x"].to(device),
            batch["x_lengths"].to(device),
            n_timesteps=args.steps,
            temperature=args.temperature,
            spks=spk.expand(b) if spk is not None else spk,
            length_scale=args.speaking_rate,
        )

        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
        t = (dt.datetime.now() - start_t).total_seconds()
        rtf_w = t * 22050 / (output["waveform"].shape[-1])
        print(f"[🍵-Batch: {i}] Matcha-TTS RTF: {output['rtf']:.4f}")
        print(f"[🍵-Batch: {i}] Matcha-TTS + VOCODER RTF: {rtf_w:.4f}")
        total_rtf.append(output["rtf"])
        total_rtf_w.append(rtf_w)
        for j in range(output["mel"].shape[0]):
            base_name = (
                f"utterance_{utterance_idx:03d}_speaker_{args.spk:03d}"
                if args.spk is not None
                else f"utterance_{utterance_idx:03d}"
            )
            length = output["mel_lengths"][j]
            # 256 waveform samples per mel frame — presumably the vocoder hop
            # length; TODO confirm against the model config.
            new_dict = {"mel": output["mel"][j][:, :length], "waveform": output["waveform"][j][: length * 256]}
            location = save_to_folder(base_name, new_dict, args.output_folder)
            print(f"[🍵-{utterance_idx}] Waveform saved: {location}")
            utterance_idx += 1

    print("".join(["="] * 100))
    print(f"[🍵] Average Matcha-TTS RTF: {np.mean(total_rtf):.4f} ± {np.std(total_rtf)}")
    print(f"[🍵] Average Matcha-TTS + VOCODER RTF: {np.mean(total_rtf_w):.4f} ± {np.std(total_rtf_w)}")
    print("[🍵] Enjoy the freshly whisked 🍵 Matcha-TTS!")
357
+
358
+
359
def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
    """Synthesise each text on its own and save every utterance.

    Utterances are numbered from 1. Per-utterance and average real-time
    factors (model only, and model + vocoder) are printed.
    """
    separator = "".join(["="] * 100)
    synth_kwargs = dict(
        n_timesteps=args.steps,
        temperature=args.temperature,
        spks=spk,
        length_scale=args.speaking_rate,
    )
    total_rtf = []
    total_rtf_w = []

    for i, text in enumerate(texts, start=1):
        if args.spk is None:
            base_name = f"utterance_{i:03d}"
        else:
            base_name = f"utterance_{i:03d}_speaker_{args.spk:03d}"

        print(separator)
        text_processed = process_text(i, text.strip(), device)

        print(f"[🍵] Whisking Matcha-T(ea)TS for: {i}")
        start_t = dt.datetime.now()
        output = model.synthesise(text_processed["x"], text_processed["x_lengths"], **synth_kwargs)
        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

        # RTF including the vocoder: wall time divided by audio duration at 22050 Hz.
        elapsed = (dt.datetime.now() - start_t).total_seconds()
        rtf_w = elapsed * 22050 / output["waveform"].shape[-1]
        print(f"[🍵-{i}] Matcha-TTS RTF: {output['rtf']:.4f}")
        print(f"[🍵-{i}] Matcha-TTS + VOCODER RTF: {rtf_w:.4f}")
        total_rtf.append(output["rtf"])
        total_rtf_w.append(rtf_w)

        location = save_to_folder(base_name, output, args.output_folder)
        print(f"[+] Waveform saved: {location}")

    print(separator)
    print(f"[🍵] Average Matcha-TTS RTF: {np.mean(total_rtf):.4f} ± {np.std(total_rtf)}")
    print(f"[🍵] Average Matcha-TTS + VOCODER RTF: {np.mean(total_rtf_w):.4f} ± {np.std(total_rtf_w)}")
    print("[🍵] Enjoy the freshly whisked 🍵 Matcha-TTS!")
396
+
397
+
398
def print_config(args):
    """Print the effective synthesis configuration, one setting per line."""
    print("[!] Configurations: ")
    fields = (
        ("Model", "model"),
        ("Vocoder", "vocoder"),
        ("Temperature", "temperature"),
        ("Speaking rate", "speaking_rate"),
        ("Number of ODE steps", "steps"),
        ("Speaker", "spk"),
    )
    for label, attr in fields:
        print(f"\t- {label}: {getattr(args, attr)}")
406
+
407
+
408
def get_device(args):
    """Return the CUDA device when available and not disabled via ``--cpu``, else CPU."""
    use_gpu = torch.cuda.is_available() and not args.cpu
    if not use_gpu:
        print("[-] GPU not available or forced CPU run! Using CPU")
        return torch.device("cpu")
    print("[+] GPU Available! Using GPU")
    return torch.device("cuda")
416
+
417
+
418
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    cli()
File without changes
@@ -0,0 +1,274 @@
1
+ import random
2
+ from pathlib import Path
3
+ from typing import Any, Dict, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchaudio as ta
8
+ from lightning import LightningDataModule
9
+ from torch.utils.data.dataloader import DataLoader
10
+
11
+ from matcha.text import text_to_sequence
12
+ from matcha.utils.audio import mel_spectrogram
13
+ from matcha.utils.model import fix_len_compatibility, normalize
14
+ from matcha.utils.utils import intersperse
15
+
16
+
17
def parse_filelist(filelist_path, split_char="|"):
    """Read a delimited filelist into a list of field lists.

    Args:
        filelist_path: path to a UTF-8 text file with one record per line.
        split_char: field separator (``"|"`` by default).

    Returns:
        One list of string fields per non-blank line, in file order. Each line
        is stripped of surrounding whitespace before splitting.

    Blank or whitespace-only lines are skipped. Otherwise they would become
    ``[""]`` rows, which fail later when indexed as ``path|text``.
    """
    with open(filelist_path, encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [record.split(split_char) for record in stripped if record]
21
+
22
+
23
class TextMelDataModule(LightningDataModule):
    """Lightning data module providing train/validation ``TextMelDataset`` loaders.

    All constructor arguments are recorded with ``save_hyperparameters`` and read
    back through ``self.hparams``. ``data_statistics`` is forwarded to
    ``TextMelDataset`` as ``data_parameters``, and ``n_feats`` as ``n_mels``.
    """

    def __init__(  # pylint: disable=unused-argument
        self,
        name,
        train_filelist_path,
        valid_filelist_path,
        batch_size,
        num_workers,
        pin_memory,
        cleaners,
        add_blank,
        n_spks,
        n_fft,
        n_feats,
        sample_rate,
        hop_length,
        win_length,
        f_min,
        f_max,
        data_statistics,
        seed,
        load_durations,
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

    def setup(self, stage: Optional[str] = None):  # pylint: disable=unused-argument
        """Build ``self.trainset`` and ``self.validset`` from the configured filelists.

        Lightning calls this for both ``trainer.fit()`` and ``trainer.test()``.
        Every call rebuilds both datasets from scratch; there is no
        "already loaded" check.
        """
        self.trainset = self._build_dataset(  # pylint: disable=attribute-defined-outside-init
            self.hparams.train_filelist_path
        )
        self.validset = self._build_dataset(  # pylint: disable=attribute-defined-outside-init
            self.hparams.valid_filelist_path
        )

    def _build_dataset(self, filelist_path):
        """Create a ``TextMelDataset`` for *filelist_path* from the stored hyperparameters."""
        hp = self.hparams
        # Positional order must match TextMelDataset.__init__.
        return TextMelDataset(
            filelist_path,
            hp.n_spks,
            hp.cleaners,
            hp.add_blank,
            hp.n_fft,
            hp.n_feats,
            hp.sample_rate,
            hp.hop_length,
            hp.win_length,
            hp.f_min,
            hp.f_max,
            hp.data_statistics,
            hp.seed,
            hp.load_durations,
        )

    def _build_dataloader(self, dataset, shuffle):
        """Wrap *dataset* in a ``DataLoader`` using the shared batching settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
            collate_fn=TextMelBatchCollate(self.hparams.n_spks),
        )

    def train_dataloader(self):
        """Return the training ``DataLoader`` (shuffled every epoch)."""
        return self._build_dataloader(self.trainset, shuffle=True)

    def val_dataloader(self):
        """Return the validation ``DataLoader`` (fixed order)."""
        return self._build_dataloader(self.validset, shuffle=False)

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass  # pylint: disable=unnecessary-pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass  # pylint: disable=unnecessary-pass
124
+
125
+
126
class TextMelDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding text-id / mel-spectrogram examples from a filelist.

    Rows are parsed by ``parse_filelist`` (``|``-separated). A row is
    ``path|text`` for single-speaker data, or ``path|speaker_id|text`` when
    ``n_spks > 1``.
    """

    def __init__(
        self,
        filelist_path,
        n_spks,
        cleaners,
        add_blank=True,
        n_fft=1024,
        n_mels=80,
        sample_rate=22050,
        hop_length=256,
        win_length=1024,
        f_min=0.0,
        f_max=8000,
        data_parameters=None,
        seed=None,
        load_durations=False,
    ):
        self.filepaths_and_text = parse_filelist(filelist_path)
        self.n_spks = n_spks
        self.cleaners = cleaners
        self.add_blank = add_blank
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.win_length = win_length
        self.f_min = f_min
        self.f_max = f_max
        self.load_durations = load_durations

        if data_parameters is not None:
            self.data_parameters = data_parameters
        else:
            # Mean 0 / std 1 -- presumably a no-op for `normalize`; confirm.
            self.data_parameters = {"mel_mean": 0, "mel_std": 1}

        # Shuffle with a private RNG so that building a dataset does not reseed
        # the process-wide `random` module. For a given seed the resulting order
        # is the same as `random.seed(seed); random.shuffle(...)`.
        random.Random(seed).shuffle(self.filepaths_and_text)

    def get_datapoint(self, filepath_and_text):
        """Load one example from a parsed filelist row.

        Returns a dict with keys:
            ``x``: symbol-id IntTensor.
            ``y``: normalised mel.
            ``spk``: int speaker id, or ``None`` when ``n_spks <= 1``.
            ``filepath``: the row's file path.
            ``x_text``: the cleaned text.
            ``durations``: a tensor when ``load_durations`` is set, else ``None``.
        """
        if self.n_spks > 1:
            filepath, spk, text = (
                filepath_and_text[0],
                int(filepath_and_text[1]),
                filepath_and_text[2],
            )
        else:
            filepath, text = filepath_and_text[0], filepath_and_text[1]
            spk = None

        text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
        mel = self.get_mel(filepath)

        durations = self.get_durations(filepath, text) if self.load_durations else None

        return {
            "x": text,
            "y": mel,
            "spk": spk,
            "filepath": filepath,
            "x_text": cleaned_text,
            "durations": durations,
        }

    def get_durations(self, filepath, text):
        """Load per-token durations for *filepath*.

        Durations are read from ``<data_dir>/durations/<stem>.npy``, where
        ``data_dir`` is the grandparent directory of *filepath*.

        Raises:
            FileNotFoundError: if the ``.npy`` file does not exist.
            AssertionError: if the number of durations differs from ``len(text)``.
        """
        filepath = Path(filepath)
        data_dir, name = filepath.parent.parent, filepath.stem

        try:
            dur_loc = data_dir / "durations" / f"{name}.npy"
            durs = torch.from_numpy(np.load(dur_loc).astype(int))

        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generate the durations first using: python matcha/utils/get_durations_from_trained_model.py \n"
            ) from e

        assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match"

        return durs

    def get_mel(self, filepath):
        """Load *filepath* with torchaudio and return its normalised mel spectrogram.

        Raises:
            AssertionError: if the file's sample rate differs from ``self.sample_rate``.
        """
        audio, sr = ta.load(filepath)
        assert sr == self.sample_rate, (
            f"Sample rate mismatch for {filepath}: expected {self.sample_rate}, got {sr}"
        )
        mel = mel_spectrogram(
            audio,
            self.n_fft,
            self.n_mels,
            self.sample_rate,
            self.hop_length,
            self.win_length,
            self.f_min,
            self.f_max,
            center=False,
        ).squeeze()
        mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"])
        return mel

    def get_text(self, text, add_blank=None):
        """Convert raw *text* into a tensor of symbol ids.

        Args:
            text: raw transcript string.
            add_blank: whether to intersperse id ``0`` between symbols. ``None``
                (the default) falls back to ``self.add_blank``. The previous
                implementation ignored this argument entirely.

        Returns:
            ``(IntTensor of symbol ids, cleaned text)``.
        """
        if add_blank is None:
            add_blank = self.add_blank
        text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
        if add_blank:
            text_norm = intersperse(text_norm, 0)
        text_norm = torch.IntTensor(text_norm)
        return text_norm, cleaned_text

    def __getitem__(self, index):
        datapoint = self.get_datapoint(self.filepaths_and_text[index])
        return datapoint

    def __len__(self):
        return len(self.filepaths_and_text)
229
+
230
+
231
class TextMelBatchCollate:
    """Collate ``TextMelDataset`` items into zero-padded batch tensors.

    Args:
        n_spks: number of speakers. Speaker ids are stacked into a tensor only
            when ``n_spks > 1``; otherwise ``spks`` is ``None``.
    """

    def __init__(self, n_spks):
        self.n_spks = n_spks

    def __call__(self, batch):
        """Pad and stack a list of datapoint dicts into one batch dict.

        ``y`` is padded to ``fix_len_compatibility`` of the longest mel. ``x`` and
        ``durations`` are padded to the longest symbol sequence. ``durations`` is
        ``None`` when every padded entry is zero, e.g. when no item carried
        durations.
        """
        batch_size = len(batch)
        mel_lens = [item["y"].shape[-1] for item in batch]
        padded_mel_len = fix_len_compatibility(max(mel_lens))
        text_lens = [item["x"].shape[-1] for item in batch]
        longest_text = max(text_lens)
        feat_dim = batch[0]["y"].shape[-2]

        mels = torch.zeros((batch_size, feat_dim, padded_mel_len), dtype=torch.float32)
        tokens = torch.zeros((batch_size, longest_text), dtype=torch.long)
        durations = torch.zeros((batch_size, longest_text), dtype=torch.long)

        for row, item in enumerate(batch):
            mel, symbols = item["y"], item["x"]
            mels[row, :, : mel.shape[-1]] = mel
            tokens[row, : symbols.shape[-1]] = symbols
            item_durations = item["durations"]
            if item_durations is not None:
                durations[row, : item_durations.shape[-1]] = item_durations

        speaker_ids = [item["spk"] for item in batch]
        filepaths = [item["filepath"] for item in batch]
        cleaned_texts = [item["x_text"] for item in batch]

        if self.n_spks > 1:
            spks = torch.tensor(speaker_ids, dtype=torch.long)
        else:
            spks = None

        return {
            "x": tokens,
            "x_lengths": torch.tensor(text_lens, dtype=torch.long),
            "y": mels,
            "y_lengths": torch.tensor(mel_lens, dtype=torch.long),
            "spks": spks,
            "filepaths": filepaths,
            "x_texts": cleaned_texts,
            "durations": durations if durations.any() else None,
        }
File without changes