xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (191) hide show
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +60 -44
  6. xinference/model/audio/chattts.py +25 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/cosyvoice.py +4 -3
  9. xinference/model/audio/custom.py +4 -5
  10. xinference/model/audio/fish_speech.py +228 -0
  11. xinference/model/audio/model_spec.json +8 -0
  12. xinference/model/embedding/core.py +25 -1
  13. xinference/model/embedding/custom.py +4 -5
  14. xinference/model/flexible/core.py +5 -1
  15. xinference/model/image/custom.py +4 -5
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +66 -3
  19. xinference/model/llm/__init__.py +6 -0
  20. xinference/model/llm/llm_family.json +54 -9
  21. xinference/model/llm/llm_family.py +7 -6
  22. xinference/model/llm/llm_family_modelscope.json +56 -10
  23. xinference/model/llm/lmdeploy/__init__.py +0 -0
  24. xinference/model/llm/lmdeploy/core.py +557 -0
  25. xinference/model/llm/sglang/core.py +7 -1
  26. xinference/model/llm/transformers/cogvlm2.py +4 -45
  27. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  28. xinference/model/llm/transformers/core.py +3 -0
  29. xinference/model/llm/transformers/glm4v.py +2 -23
  30. xinference/model/llm/transformers/intern_vl.py +94 -11
  31. xinference/model/llm/transformers/minicpmv25.py +2 -23
  32. xinference/model/llm/transformers/minicpmv26.py +2 -22
  33. xinference/model/llm/transformers/yi_vl.py +2 -24
  34. xinference/model/llm/utils.py +13 -1
  35. xinference/model/llm/vllm/core.py +1 -34
  36. xinference/model/rerank/custom.py +4 -5
  37. xinference/model/utils.py +41 -1
  38. xinference/model/video/core.py +3 -1
  39. xinference/model/video/diffusers.py +41 -38
  40. xinference/model/video/model_spec.json +24 -1
  41. xinference/model/video/model_spec_modelscope.json +25 -1
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/matcha/__init__.py +0 -0
  137. xinference/thirdparty/matcha/app.py +357 -0
  138. xinference/thirdparty/matcha/cli.py +419 -0
  139. xinference/thirdparty/matcha/data/__init__.py +0 -0
  140. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  141. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  142. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  143. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  144. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  145. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  146. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  147. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  148. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  149. xinference/thirdparty/matcha/models/__init__.py +0 -0
  150. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  151. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  152. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  153. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  154. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  155. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  156. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  157. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  158. xinference/thirdparty/matcha/onnx/export.py +181 -0
  159. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  160. xinference/thirdparty/matcha/text/__init__.py +53 -0
  161. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  162. xinference/thirdparty/matcha/text/numbers.py +71 -0
  163. xinference/thirdparty/matcha/text/symbols.py +17 -0
  164. xinference/thirdparty/matcha/train.py +122 -0
  165. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  166. xinference/thirdparty/matcha/utils/audio.py +82 -0
  167. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  168. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  169. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  170. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  171. xinference/thirdparty/matcha/utils/model.py +90 -0
  172. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  173. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  174. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  175. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  176. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  177. xinference/thirdparty/matcha/utils/utils.py +259 -0
  178. xinference/web/ui/build/asset-manifest.json +3 -3
  179. xinference/web/ui/build/index.html +1 -1
  180. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  181. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  182. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  183. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
  184. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
  185. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  186. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  187. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  188. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  189. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  190. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  191. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,181 @@
1
+ import argparse
2
+ import random
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ from lightning import LightningModule
8
+
9
+ from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder
10
+
11
DEFAULT_OPSET = 15

SEED = 1234


def _seed_everything(seed):
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN kernels,
    so the dummy tracing inputs (and thus the export) are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Prefer reproducible cuDNN kernels over autotuned (possibly faster) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


_seed_everything(SEED)
20
+
21
+
22
class MatchaWithVocoder(LightningModule):
    """Chain a Matcha model (tokens -> mel) with a vocoder (mel -> waveform)
    so that both can be exported together as a single ONNX graph."""

    def __init__(self, matcha, vocoder):
        super().__init__()
        self.matcha = matcha
        self.vocoder = vocoder

    def forward(self, x, x_lengths, scales, spks=None):
        """Return ``(waveforms, waveform_lengths)`` for a batch of token sequences.

        The vocoder output is clamped to [-1, 1] and axis 1 is squeezed away.
        Waveform lengths are the mel lengths times 256 -- presumably the
        vocoder's hop size in samples per mel frame (TODO confirm per vocoder).
        """
        mel, mel_lengths = self.matcha(x, x_lengths, scales, spks)
        waveform = self.vocoder(mel).clamp(-1, 1).squeeze(1)
        return waveform, mel_lengths * 256
33
+
34
+
35
def get_exportable_module(matcha, vocoder, n_timesteps):
    """
    Return an appropriate `LightningModule` and output-node names
    based on whether the vocoder is embedded in the final graph.

    Side effect: ``matcha.forward`` is replaced (monkey-patched) by a function
    that takes the sampling scalars packed into one ``scales`` tensor
    (``scales[0]`` = temperature, ``scales[1]`` = length scale) and calls
    ``matcha.synthesise`` with the fixed ``n_timesteps``.

    Returns:
        ``(matcha, ["mel", "mel_lengths"])`` when ``vocoder`` is None, otherwise
        ``(MatchaWithVocoder(matcha, vocoder), ["wav", "wav_lengths"])``.
    """

    def _forward_with_scales(x, x_lengths, scales, spks=None):
        # ONNX graph inputs must be tensors, so the two scalar knobs arrive
        # packed together and are unpacked here.
        temperature, length_scale = scales[0], scales[1]
        synthesised = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale)
        return synthesised["mel"], synthesised["mel_lengths"]

    matcha.forward = _forward_with_scales

    if vocoder is not None:
        return MatchaWithVocoder(matcha, vocoder), ["wav", "wav_lengths"]
    return matcha, ["mel", "mel_lengths"]
61
+
62
+
63
def get_inputs(is_multi_speaker):
    """
    Build dummy inputs for ONNX tracing.

    Returns:
        ``(inputs, names)`` where ``inputs`` is a tuple of tensors -- a random
        ``(1, 50)`` long tensor of token IDs in [0, 20), its length, and a float
        ``scales`` tensor ``[temperature=0.667, length_scale=1.0]`` -- plus a
        speaker-ID tensor when ``is_multi_speaker`` is true; ``names`` lists the
        matching graph input names.
    """
    seq_len = 50
    tokens = torch.randint(low=0, high=20, size=(1, seq_len), dtype=torch.long)
    token_lengths = torch.LongTensor([seq_len])

    default_temperature, default_length_scale = 0.667, 1.0
    scale_tensor = torch.Tensor([default_temperature, default_length_scale])

    names = ["x", "x_lengths", "scales"]
    values = [tokens, token_lengths, scale_tensor]
    if is_multi_speaker:
        values.append(torch.LongTensor([1]))
        names.append("spks")

    return tuple(values), names
89
+
90
+
91
def main():
    """Command-line entry point: export a Matcha-TTS checkpoint to ONNX.

    Optionally embeds a vocoder (``--vocoder-name`` together with
    ``--vocoder-checkpoint-path``) so the exported graph outputs waveforms
    instead of mel spectrograms. Batch and time axes are exported as dynamic.
    """
    parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX")

    parser.add_argument(
        "checkpoint_path",
        type=str,
        help="Path to the model checkpoint",
    )
    parser.add_argument("output", type=str, help="Path to output `.onnx` file")
    parser.add_argument(
        "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)"
    )
    parser.add_argument(
        "--vocoder-name",
        type=str,
        choices=list(VOCODER_URLS.keys()),
        default=None,
        help="Name of the vocoder to embed in the ONNX graph",
    )
    parser.add_argument(
        "--vocoder-checkpoint-path",
        type=str,
        default=None,
        help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience",
    )
    # Derive the advertised default from DEFAULT_OPSET so the two cannot drift.
    parser.add_argument(
        "--opset", type=int, default=DEFAULT_OPSET, help=f"ONNX opset version to use (default {DEFAULT_OPSET})"
    )

    args = parser.parse_args()

    # Embedding a vocoder needs both its name and its checkpoint. Validate up
    # front (before loading any weights) and with parser.error rather than
    # `assert`, so the check still runs under `python -O`.
    embed_vocoder = bool(args.vocoder_name or args.vocoder_checkpoint_path)
    if embed_vocoder and not (args.vocoder_name and args.vocoder_checkpoint_path):
        parser.error(
            "Both --vocoder-name and --vocoder-checkpoint-path are required "
            "when embedding the vocoder in the ONNX graph."
        )

    print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}")
    print(f"Setting n_timesteps to {args.n_timesteps}")

    checkpoint_path = Path(args.checkpoint_path)
    matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu")

    if embed_vocoder:
        vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu")
    else:
        vocoder = None

    is_multi_speaker = matcha.n_spks > 1

    dummy_input, input_names = get_inputs(is_multi_speaker)
    model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps)

    # Set dynamic shape for inputs/outputs
    dynamic_axes = {
        "x": {0: "batch_size", 1: "time"},
        "x_lengths": {0: "batch_size"},
    }

    if vocoder is None:
        dynamic_axes.update(
            {
                "mel": {0: "batch_size", 2: "time"},
                "mel_lengths": {0: "batch_size"},
            }
        )
    else:
        print("Embedding the vocoder in the ONNX graph")
        dynamic_axes.update(
            {
                "wav": {0: "batch_size", 1: "time"},
                "wav_lengths": {0: "batch_size"},
            }
        )

    if is_multi_speaker:
        dynamic_axes["spks"] = {0: "batch_size"}

    # Create the output directory (if not exists)
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)

    model.to_onnx(
        args.output,
        dummy_input,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=args.opset,
        export_params=True,
        do_constant_folding=True,
    )
    print(f"[🍵] ONNX model exported to {args.output}")
178
+
179
+
180
if __name__ == "__main__":
    # Run the exporter CLI only when executed as a script, not on import.
    main()
@@ -0,0 +1,168 @@
1
+ import argparse
2
+ import os
3
+ import warnings
4
+ from pathlib import Path
5
+ from time import perf_counter
6
+
7
+ import numpy as np
8
+ import onnxruntime as ort
9
+ import soundfile as sf
10
+ import torch
11
+
12
+ from matcha.cli import plot_spectrogram_to_numpy, process_text
13
+
14
+
15
def validate_args(args):
    """Sanity-check parsed CLI arguments and return them unchanged.

    Raises:
        AssertionError: if neither ``--text`` nor ``--file`` is given, the
            sampling temperature is negative, or the speaking rate is not
            strictly positive.
    """
    assert (
        args.text or args.file
    ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms."
    assert args.temperature >= 0, "Sampling temperature cannot be negative"
    # The speaking rate is fed to the model as its length scale; the message
    # (and the meaning of a rate) require it to be strictly positive, so 0 is
    # rejected as well.
    assert args.speaking_rate > 0, "Speaking rate must be greater than 0"
    return args
22
+
23
+
24
def write_wavs(model, inputs, output_dir, external_vocoder=None):
    """Synthesise waveforms, write one 24-bit PCM WAV per batch item, and print timings.

    Args:
        model: ONNX Runtime session for the exported Matcha graph.
        inputs: feed dict passed to ``model.run``.
        output_dir: directory for ``output_<n>.wav`` files (created if missing).
        external_vocoder: optional ONNX Runtime session turning mels into
            waveforms. When ``None``, ``model`` is expected to have the vocoder
            embedded and to return ``(wavs, wav_lengths)`` directly.

    Audio is written at 22050 Hz. With an external vocoder the waveform lengths
    are derived as ``mel_lengths * 256``. RTF (real-time factor) is inference
    seconds divided by generated audio seconds.
    """
    if external_vocoder is not None:
        print("[🍵] Generating mel using Matcha")
        start = perf_counter()
        mels, mel_lengths = model.run(None, inputs)
        mel_infer_secs = perf_counter() - start

        print("Generating waveform from mel using external vocoder")
        vocoder_feed = {external_vocoder.get_inputs()[0].name: mels}
        start = perf_counter()
        raw_wavs = external_vocoder.run(None, vocoder_feed)[0]
        vocoder_infer_secs = perf_counter() - start

        wavs = raw_wavs.squeeze(1)
        wav_lengths = mel_lengths * 256
        infer_secs = mel_infer_secs + vocoder_infer_secs
    else:
        print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
        start = perf_counter()
        wavs, wav_lengths = model.run(None, inputs)
        infer_secs = perf_counter() - start
        mel_infer_secs = None
        vocoder_infer_secs = None

    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    for idx, (wav, n_samples) in enumerate(zip(wavs, wav_lengths), start=1):
        wav_file = out_path.joinpath(f"output_{idx}.wav")
        print(f"Writing audio to {wav_file}")
        # Drop the batch padding beyond this item's true length.
        sf.write(wav_file, wav[:n_samples], 22050, "PCM_24")

    wav_secs = wav_lengths.sum() / 22050
    print(f"Inference seconds: {infer_secs}")
    print(f"Generated wav seconds: {wav_secs}")
    rtf = infer_secs / wav_secs
    for label, secs in (("Matcha RTF", mel_infer_secs), ("Vocoder RTF", vocoder_infer_secs)):
        if secs is not None:
            print(f"{label}: {secs / wav_secs}")
    print(f"Overall RTF: {rtf}")
64
+
65
+
66
def write_mels(model, inputs, output_dir):
    """Run the mel-only Matcha graph, save each mel as PNG + ``.npy``, and print timings.

    Args:
        model: ONNX Runtime session returning ``(mels, mel_lengths)``.
        inputs: feed dict passed to ``model.run``.
        output_dir: directory for ``output_<n>.png`` / ``output_<n>.npy``
            (created if missing).

    Each saved mel is trimmed to its own length, mirroring how ``write_wavs``
    trims audio, so batch padding is not persisted. The "wav seconds" estimate
    assumes 256 samples per mel frame at 22050 Hz, as ``write_wavs`` does.
    """
    t0 = perf_counter()
    mels, mel_lengths = model.run(None, inputs)
    infer_secs = perf_counter() - t0

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for i, (mel, mel_length) in enumerate(zip(mels, mel_lengths)):
        # Time is the last axis of each mel (see the exporter's dynamic axes).
        mel = mel[..., :mel_length]
        output_stem = output_dir.joinpath(f"output_{i + 1}")
        plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png"))
        # np.save appends ".npy" to any other suffix, so ".numpy" used to yield
        # "output_<n>.numpy.npy"; use the canonical extension directly.
        np.save(output_stem.with_suffix(".npy"), mel)

    wav_secs = (mel_lengths * 256).sum() / 22050
    print(f"Inference seconds: {infer_secs}")
    print(f"Generated wav seconds: {wav_secs}")
    rtf = infer_secs / wav_secs
    print(f"RTF: {rtf}")
83
+
84
+
85
def main():
    """Command-line entry point: synthesise speech with an exported Matcha-TTS ONNX model.

    Writes WAV files when a vocoder is available (embedded in the graph or
    given with ``--vocoder``); otherwise writes the mel spectrograms.
    """
    parser = argparse.ArgumentParser(
        description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
    )
    parser.add_argument(
        "model",
        type=str,
        help="ONNX model to use",
    )
    parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)")
    parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
    parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
    parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.667,
        help="Variance of the x0 noise (default: 0.667)",
    )
    parser.add_argument(
        "--speaking-rate",
        type=float,
        default=1.0,
        help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
    )
    # The flag enables the GPU; the previous help text said the opposite.
    parser.add_argument("--gpu", action="store_true", help="Use GPU (CUDA) for inference (default: CPU)")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=os.getcwd(),
        help="Output folder to save results (default: current dir)",
    )

    args = parser.parse_args()
    args = validate_args(args)

    if args.gpu:
        # ONNX Runtime has no "GPUExecutionProvider"; its CUDA provider is
        # "CUDAExecutionProvider". CPU is listed second as a fallback provider.
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    model = ort.InferenceSession(args.model, providers=providers)

    model_inputs = model.get_inputs()
    model_outputs = list(model.get_outputs())

    if args.text:
        text_lines = args.text.splitlines()
    else:
        with open(args.file, encoding="utf-8") as file:
            text_lines = file.read().splitlines()

    # An empty file would otherwise crash later inside pad_sequence.
    if not text_lines:
        parser.error("No text to synthesize: the provided input is empty.")

    processed_lines = [process_text(0, line, "cpu") for line in text_lines]
    x = [line["x"].squeeze() for line in processed_lines]
    # Pad
    x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
    x = x.detach().cpu().numpy()
    x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64)
    inputs = {
        "x": x,
        "x_lengths": x_lengths,
        "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32),
    }
    # Multi-speaker graphs are exported with a fourth input, "spks".
    is_multi_speaker = len(model_inputs) == 4
    if is_multi_speaker:
        if args.spk is None:
            args.spk = 0
            warn = "[!] Speaker ID not provided! Using speaker ID 0"
            warnings.warn(warn, UserWarning)
        inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64)

    has_vocoder_embedded = model_outputs[0].name == "wav"
    if has_vocoder_embedded:
        write_wavs(model, inputs, args.output_dir)
    elif args.vocoder:
        external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
        write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
    else:
        warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
        warnings.warn(warn, UserWarning)
        write_mels(model, inputs, args.output_dir)
165
+
166
+
167
if __name__ == "__main__":
    # Run the inference CLI only when executed as a script, not on import.
    main()
@@ -0,0 +1,53 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+ from matcha.text import cleaners
3
+ from matcha.text.symbols import symbols
4
+
5
# Lookup tables between symbols and their integer IDs; an ID is the symbol's
# position in `symbols`. If a symbol occurs more than once in `symbols`,
# _symbol_to_id keeps its last position.
_symbol_to_id = {symbol: position for position, symbol in enumerate(symbols)}
_id_to_symbol = dict(enumerate(symbols))
8
+
9
+
10
def text_to_sequence(text, cleaner_names):
    """Clean raw text and convert it to a sequence of symbol IDs.

    Args:
        text: string to convert to a sequence
        cleaner_names: names of functions in ``matcha.text.cleaners`` to run
            the text through, applied in order
    Returns:
        A tuple ``(sequence, clean_text)``: the list of integer symbol IDs and
        the cleaned text they were derived from.
    Raises:
        KeyError: if the cleaned text contains a character with no symbol ID.
    """
    cleaned = _clean_text(text, cleaner_names)
    ids = [_symbol_to_id[char] for char in cleaned]
    return ids, cleaned
25
+
26
+
27
def cleaned_text_to_sequence(cleaned_text):
    """Map already-cleaned text to symbol IDs, one ID per character.

    Args:
        cleaned_text: text that has already been run through the cleaners
    Returns:
        List of integers corresponding to the symbols in the text
    Raises:
        KeyError: if a character has no symbol ID.
    """
    return list(map(_symbol_to_id.__getitem__, cleaned_text))
36
+
37
+
38
def sequence_to_text(sequence):
    """Convert a sequence of symbol IDs back to a string.

    Raises:
        KeyError: if an ID is not in the symbol table.
    """
    # str.join builds the result in one pass; repeated `+=` on a str can be
    # quadratic in the sequence length.
    return "".join(_id_to_symbol[symbol_id] for symbol_id in sequence)
45
+
46
+
47
def _clean_text(text, cleaner_names):
    """Run ``text`` through each named cleaner from ``matcha.text.cleaners``, in order.

    Raises:
        Exception: ``"Unknown cleaner: <name>"`` if a name does not resolve to
            a callable attribute of the cleaners module.
    """
    for name in cleaner_names:
        # Without a default, getattr raises AttributeError for unknown names,
        # which made the "Unknown cleaner" branch unreachable. Also reject
        # non-callable module attributes (regexes, loggers, imported modules).
        cleaner = getattr(cleaners, name, None)
        if not callable(cleaner):
            raise Exception(f"Unknown cleaner: {name}")
        text = cleaner(text)
    return text
@@ -0,0 +1,121 @@
1
+ """ from https://github.com/keithito/tacotron
2
+
3
+ Cleaners are transformations that run over the input text at both training and eval time.
4
+
5
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
6
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
7
+ 1. "english_cleaners2" for English text
8
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
9
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
10
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
11
+ the symbols in symbols.py to match your data).
12
+ """
13
+
14
+ import logging
15
+ import re
16
+
17
+ import phonemizer
18
+ from unidecode import unidecode
19
+
20
# The phonemizer package logs verbosely; raise its logger's threshold to
# CRITICAL so only critical records get through.
critical_logger = logging.getLogger("phonemizer")
critical_logger.setLevel(logging.CRITICAL)

# Build one espeak backend at import time and reuse it for every call:
# initialising the phonemizer on each call is much slower. The trade-off is a
# fixed configuration -- language "en-us", punctuation preserved, stress marks
# kept, and language_switch="remove-flags" (see the phonemizer docs for its
# exact effect).
global_phonemizer = phonemizer.backend.EspeakBackend(
    logger=critical_logger,
    language_switch="remove-flags",
    with_stress=True,
    preserve_punctuation=True,
    language="en-us",
)
34
+
35
+
36
# Any run of whitespace characters.
_whitespace_re = re.compile(r"\s+")

# (pattern, expansion) pairs. Each pattern matches the abbreviation as a whole
# word followed by a literal period, case-insensitively (e.g. "Dr." -> "doctor").
_abbreviations = [
    (re.compile(rf"\b{abbreviation}\.", re.IGNORECASE), expansion)
    for abbreviation, expansion in (
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    )
]
63
+
64
+
65
def expand_abbreviations(text):
    """Replace every known abbreviation (e.g. "dr.") with its spoken expansion."""
    for pattern, expansion in _abbreviations:
        text = pattern.sub(expansion, text)
    return text
69
+
70
+
71
def lowercase(text):
    """Return ``text`` with every cased character lowercased."""
    lowered = text.lower()
    return lowered
73
+
74
+
75
def collapse_whitespace(text):
    """Collapse each run of whitespace in ``text`` into a single space."""
    return _whitespace_re.sub(" ", text)
77
+
78
+
79
def convert_to_ascii(text):
    """Transliterate ``text`` to ASCII using ``unidecode``."""
    ascii_text = unidecode(text)
    return ascii_text
81
+
82
+
83
def basic_cleaners(text):
    """Lowercase and collapse whitespace, without any transliteration."""
    return collapse_whitespace(lowercase(text))
88
+
89
+
90
def transliteration_cleaners(text):
    """Transliterate non-English text to ASCII, then lowercase and collapse whitespace."""
    return collapse_whitespace(lowercase(convert_to_ascii(text)))
96
+
97
+
98
def english_cleaners2(text):
    """English pipeline producing phonemes.

    Transliterates to ASCII, lowercases and expands abbreviations, then
    phonemizes with the module-level espeak backend (configured to keep
    punctuation and stress marks) and collapses whitespace in the result.
    """
    normalized = expand_abbreviations(lowercase(convert_to_ascii(text)))
    phonemized = global_phonemizer.phonemize([normalized], strip=True, njobs=1)[0]
    return collapse_whitespace(phonemized)
106
+
107
+
108
+ # This cleaner is disabled due to incompatibility with several versions of Python.
109
+ # However, if you want to use it, you can uncomment it
110
+ # and install piper-phonemize with the following command:
111
+ # pip install piper-phonemize
112
+
113
+ # import piper_phonemize
114
+ # def english_cleaners_piper(text):
115
+ # """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
116
+ # text = convert_to_ascii(text)
117
+ # text = lowercase(text)
118
+ # text = expand_abbreviations(text)
119
+ # phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
120
+ # phonemes = collapse_whitespace(phonemes)
121
+ # return phonemes
@@ -0,0 +1,71 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+
3
+ import re
4
+
5
+ import inflect
6
+
7
# Patterns used by normalize_numbers(), listed in the order it applies them.
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")  # digits with thousands separators
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")  # pound-sterling amounts
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")  # dollar amounts, optionally with cents
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")  # decimals such as "3.14"
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")  # ordinals such as "2nd"
_number_re = re.compile(r"[0-9]+")  # any remaining run of digits

# Shared inflect engine used to spell numbers out as words.
_inflect = inflect.engine()
14
+
15
+
16
def _remove_commas(m):
    """Return the first captured group of match ``m`` with all commas removed."""
    digits = m[1]
    return "".join(digits.split(","))
18
+
19
+
20
def _expand_decimal_point(m):
    """Spell the decimal point in group 1 of ``m`` as " point " (e.g. "3.14" -> "3 point 14")."""
    return " point ".join(m[1].split("."))
22
+
23
+
24
def _expand_dollars(m):
    """Turn a dollar amount captured by ``_dollars_re`` into a dollars/cents phrase.

    The amounts stay as digits (e.g. "$1.50" -> "1 dollar, 50 cents");
    normalize_numbers() spells them out in a later pass. An amount with more
    than one "." is returned as "<match> dollars".
    """
    match = m.group(1)
    # _dollars_re admits commas; drop them so int() cannot fail when this is
    # called without normalize_numbers' comma-stripping pass.
    parts = match.replace(",", "").split(".")
    if len(parts) > 2:
        return match + " dollars"  # Unexpected format
    dollars = int(parts[0]) if parts[0] else 0
    # A single fractional digit denotes tenths of a dollar ("$1.5" == $1.50),
    # so right-pad it to two digits before reading it as cents.
    cents = int(parts[1].ljust(2, "0")) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        cent_unit = "cent" if cents == 1 else "cents"
        return f"{dollars} {dollar_unit}, {cents} {cent_unit}"
    elif dollars:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        return f"{dollars} {dollar_unit}"
    elif cents:
        cent_unit = "cent" if cents == 1 else "cents"
        return f"{cents} {cent_unit}"
    else:
        return "zero dollars"
43
+
44
+
45
def _expand_ordinal(m):
    """Spell out the whole ordinal matched by ``m`` (e.g. "2nd") using inflect."""
    ordinal_text = m.group(0)
    return _inflect.number_to_words(ordinal_text)
47
+
48
+
49
def _expand_number(m):
    """Spell out the integer matched by ``m``.

    Values strictly between 1000 and 3000 are read year-style: 2000 is "two
    thousand", 2001-2009 are "two thousand <n>", multiples of 100 are "<n>
    hundred", and the rest are read as digit pairs (inflect ``group=2``).
    All other values use inflect's default wording without "and".
    """
    num = int(m.group(0))
    if not 1000 < num < 3000:
        return _inflect.number_to_words(num, andword="")
    if num == 2000:
        return "two thousand"
    if 2000 < num < 2010:
        return "two thousand " + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + " hundred"
    paired = _inflect.number_to_words(num, andword="", zero="oh", group=2)
    return paired.replace(", ", " ")
62
+
63
+
64
def normalize_numbers(text):
    """Rewrite numbers, currency amounts and ordinals in ``text`` as words.

    Substitutions run in a fixed order: thousands separators are removed
    first, then currencies, decimals and ordinals are expanded, and any
    remaining integers are spelled out last.
    """
    substitutions = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r"\1 pounds"),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in substitutions:
        text = pattern.sub(replacement, text)
    return text
@@ -0,0 +1,17 @@
1
+ """ from https://github.com/keithito/tacotron
2
+
3
+ Defines the set of symbols used in text input to the model.
4
+ """
5
_pad = "_"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

# Full symbol inventory, one character per entry: the padding symbol first,
# then punctuation, ASCII letters and IPA characters.
symbols = [_pad, *_punctuation, *_letters, *_letters_ipa]

# ID of the space character within `symbols`.
SPACE_ID = symbols.index(" ")