xinference-1.0.0-py3-none-any.whl → xinference-1.1.0-py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.

Potentially problematic release.



Files changed (94)
  1. xinference/_compat.py +22 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +91 -6
  4. xinference/client/restful/restful_client.py +39 -0
  5. xinference/core/model.py +41 -13
  6. xinference/deploy/cmdline.py +3 -1
  7. xinference/deploy/test/test_cmdline.py +56 -0
  8. xinference/isolation.py +24 -0
  9. xinference/model/audio/__init__.py +12 -0
  10. xinference/model/audio/core.py +26 -4
  11. xinference/model/audio/f5tts.py +195 -0
  12. xinference/model/audio/fish_speech.py +71 -35
  13. xinference/model/audio/model_spec.json +88 -0
  14. xinference/model/audio/model_spec_modelscope.json +9 -0
  15. xinference/model/audio/whisper_mlx.py +208 -0
  16. xinference/model/embedding/core.py +322 -6
  17. xinference/model/embedding/model_spec.json +8 -1
  18. xinference/model/embedding/model_spec_modelscope.json +9 -1
  19. xinference/model/llm/__init__.py +4 -2
  20. xinference/model/llm/llm_family.json +479 -53
  21. xinference/model/llm/llm_family_modelscope.json +423 -17
  22. xinference/model/llm/mlx/core.py +230 -50
  23. xinference/model/llm/sglang/core.py +2 -0
  24. xinference/model/llm/transformers/chatglm.py +9 -5
  25. xinference/model/llm/transformers/core.py +1 -0
  26. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  27. xinference/model/llm/transformers/utils.py +16 -8
  28. xinference/model/llm/utils.py +23 -1
  29. xinference/model/llm/vllm/core.py +89 -2
  30. xinference/thirdparty/f5_tts/__init__.py +0 -0
  31. xinference/thirdparty/f5_tts/api.py +166 -0
  32. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  33. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  34. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  35. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  36. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  37. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  38. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  39. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  40. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  41. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  42. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  43. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  44. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  45. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  46. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  47. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  48. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  49. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  50. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  51. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  52. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  53. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  54. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  55. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  56. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  57. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  58. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  59. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  60. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  61. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  62. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  63. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  64. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  65. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  66. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  67. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  68. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  69. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  70. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  71. xinference/thirdparty/f5_tts/train/README.md +77 -0
  72. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  73. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  74. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  75. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  76. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  77. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  78. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  79. xinference/thirdparty/f5_tts/train/train.py +75 -0
  80. xinference/types.py +2 -1
  81. xinference/web/ui/build/asset-manifest.json +3 -3
  82. xinference/web/ui/build/index.html +1 -1
  83. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  84. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  86. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/METADATA +39 -18
  87. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/RECORD +92 -39
  88. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/WHEEL +1 -1
  89. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  91. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  92. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/LICENSE +0 -0
  93. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/entry_points.txt +0 -0
  94. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/top_level.txt +0 -0
xinference/model/audio/f5tts.py
@@ -0,0 +1,195 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import os
+ import re
+ from io import BytesIO
+ from typing import TYPE_CHECKING, Optional
+
+ if TYPE_CHECKING:
+     from .core import AudioModelFamilyV1
+
+ logger = logging.getLogger(__name__)
+
+
+ class F5TTSModel:
+     def __init__(
+         self,
+         model_uid: str,
+         model_path: str,
+         model_spec: "AudioModelFamilyV1",
+         device: Optional[str] = None,
+         **kwargs,
+     ):
+         self._model_uid = model_uid
+         self._model_path = model_path
+         self._model_spec = model_spec
+         self._device = device
+         self._model = None
+         self._vocoder = None
+         self._kwargs = kwargs
+
+     @property
+     def model_ability(self):
+         return self._model_spec.model_ability
+
+     def load(self):
+         import os
+         import sys
+
+         # The yaml config loaded from model has hard-coded the import paths. please refer to: load_hyperpyyaml
+         sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../thirdparty"))
+
+         from f5_tts.infer.utils_infer import load_model, load_vocoder
+         from f5_tts.model import DiT
+
+         vocoder_name = self._kwargs.get("vocoder_name", "vocos")
+         vocoder_path = self._kwargs.get("vocoder_path")
+
+         if vocoder_name not in ["vocos", "bigvgan"]:
+             raise Exception(f"Unsupported vocoder name: {vocoder_name}")
+
+         if vocoder_path is not None:
+             self._vocoder = load_vocoder(
+                 vocoder_name=vocoder_name, is_local=True, local_path=vocoder_path
+             )
+         else:
+             self._vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=False)
+
+         model_cls = DiT
+         model_cfg = dict(
+             dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
+         )
+         if vocoder_name == "vocos":
+             exp_name = "F5TTS_Base"
+             ckpt_step = 1200000
+         elif vocoder_name == "bigvgan":
+             exp_name = "F5TTS_Base_bigvgan"
+             ckpt_step = 1250000
+         else:
+             assert False
+         ckpt_file = os.path.join(
+             self._model_path, exp_name, f"model_{ckpt_step}.safetensors"
+         )
+         logger.info(f"Loading %s...", ckpt_file)
+         self._model = load_model(
+             model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name
+         )
+
+     def _infer(self, ref_audio, ref_text, text_gen, model_obj, mel_spec_type, speed):
+         import numpy as np
+         from f5_tts.infer.utils_infer import infer_process, preprocess_ref_audio_text
+
+         config = {}
+         main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
+         if "voices" not in config:
+             voices = {"main": main_voice}
+         else:
+             voices = config["voices"]
+             voices["main"] = main_voice
+         for voice in voices:
+             (
+                 voices[voice]["ref_audio"],
+                 voices[voice]["ref_text"],
+             ) = preprocess_ref_audio_text(
+                 voices[voice]["ref_audio"], voices[voice]["ref_text"]
+             )
+             print("Voice:", voice)
+             print("Ref_audio:", voices[voice]["ref_audio"])
+             print("Ref_text:", voices[voice]["ref_text"])
+
+         final_sample_rate = None
+         generated_audio_segments = []
+         reg1 = r"(?=\[\w+\])"
+         chunks = re.split(reg1, text_gen)
+         reg2 = r"\[(\w+)\]"
+         for text in chunks:
+             if not text.strip():
+                 continue
+             match = re.match(reg2, text)
+             if match:
+                 voice = match[1]
+             else:
+                 print("No voice tag found, using main.")
+                 voice = "main"
+             if voice not in voices:
+                 print(f"Voice {voice} not found, using main.")
+                 voice = "main"
+             text = re.sub(reg2, "", text)
+             gen_text = text.strip()
+             ref_audio = voices[voice]["ref_audio"]
+             ref_text = voices[voice]["ref_text"]
+             print(f"Voice: {voice}")
+             audio, final_sample_rate, spectragram = infer_process(
+                 ref_audio,
+                 ref_text,
+                 gen_text,
+                 model_obj,
+                 self._vocoder,
+                 mel_spec_type=mel_spec_type,
+                 speed=speed,
+             )
+             generated_audio_segments.append(audio)
+
+         if generated_audio_segments:
+             final_wave = np.concatenate(generated_audio_segments)
+             return final_sample_rate, final_wave
+         return None, None
+
+     def speech(
+         self,
+         input: str,
+         voice: str,
+         response_format: str = "mp3",
+         speed: float = 1.0,
+         stream: bool = False,
+         **kwargs,
+     ):
+         import f5_tts
+         import soundfile
+         import tomli
+
+         if stream:
+             raise Exception("F5-TTS does not support stream generation.")
+
+         prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
+         prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
+
+         if prompt_speech is None:
+             base = os.path.dirname(f5_tts.__file__)
+             config = os.path.join(base, "infer/examples/basic/basic.toml")
+             with open(config, "rb") as f:
+                 config_dict = tomli.load(f)
+             prompt_speech = os.path.join(base, config_dict["ref_audio"])
+             prompt_text = config_dict["ref_text"]
+
+         assert self._model is not None
+         vocoder_name = self._kwargs.get("vocoder_name", "vocos")
+         sample_rate, wav = self._infer(
+             ref_audio=prompt_speech,
+             ref_text=prompt_text,
+             text_gen=input,
+             model_obj=self._model,
+             mel_spec_type=vocoder_name,
+             speed=speed,
+         )
+
+         # Save the generated audio
+         with BytesIO() as out:
+             with soundfile.SoundFile(
+                 out, "w", sample_rate, 1, format=response_format.upper()
+             ) as f:
+                 f.write(wav)
+             return out.getvalue()
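
The new speech() entry point above accepts an optional reference clip through the prompt_speech and prompt_text kwargs and otherwise falls back to the bundled basic.toml example prompt. The following is a hedged, minimal sketch of exercising the model from the xinference client once it is launched; the endpoint URL, output file names, and passing prompt_speech through the RESTful client are assumptions, not taken from this diff.

# Hypothetical usage sketch for the new F5-TTS audio model (placeholders throughout).
from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed local xinference endpoint

# Launch the newly registered text-to-audio model (see the model_spec.json additions below).
model_uid = client.launch_model(model_name="F5-TTS", model_type="audio")
model = client.get_model(model_uid)

# Optional voice cloning: supply a reference clip and its transcript, mirroring
# the prompt_speech / prompt_text kwargs handled in F5TTSModel.speech().
with open("reference.wav", "rb") as f:
    prompt_speech = f.read()

audio_bytes = model.speech(
    input="Hello from xinference 1.1.0.",
    response_format="wav",
    prompt_speech=prompt_speech,  # assumed to be forwarded as a kwarg by the client
    prompt_text="Transcript of the reference clip.",
)

with open("output.wav", "wb") as f:
    f.write(audio_bytes)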
xinference/model/audio/fish_speech.py
@@ -81,12 +81,14 @@ class FishSpeechModel:
          if not is_device_available(self._device):
              raise ValueError(f"Device {self._device} is not available!")

-         logger.info("Loading Llama model...")
+         enable_compile = self._kwargs.get("compile", False)
+         precision = self._kwargs.get("precision", torch.bfloat16)
+         logger.info("Loading Llama model, compile=%s...", enable_compile)
          self._llama_queue = launch_thread_safe_queue(
              checkpoint_path=self._model_path,
              device=self._device,
-             precision=torch.bfloat16,
-             compile=False,
+             precision=precision,
+             compile=enable_compile,
          )
          logger.info("Llama model loaded, loading VQ-GAN model...")

@@ -112,9 +114,10 @@ class FishSpeechModel:
          top_p,
          repetition_penalty,
          temperature,
+         seed="0",
          streaming=False,
      ):
-         from fish_speech.utils import autocast_exclude_mps
+         from fish_speech.utils import autocast_exclude_mps, set_seed
          from tools.api import decode_vq_tokens, encode_reference
          from tools.llama.generate import (
              GenerateRequest,
@@ -122,6 +125,11 @@ class FishSpeechModel:
              WrappedGenerateResponse,
          )

+         seed = int(seed)
+         if seed != 0:
+             set_seed(seed)
+             logger.warning(f"set seed: {seed}")
+
          # Parse reference audio aka prompt
          prompt_tokens = encode_reference(
              decoder_model=self._model,
@@ -137,7 +145,7 @@ class FishSpeechModel:
              top_p=top_p,
              repetition_penalty=repetition_penalty,
              temperature=temperature,
-             compile=False,
+             compile=self._kwargs.get("compile", False),
              iterative_prompt=chunk_length > 0,
              chunk_length=chunk_length,
              max_length=2048,
@@ -153,22 +161,20 @@ class FishSpeechModel:
              )
          )

-         if streaming:
-             yield wav_chunk_header(), None, None
-
          segments = []

          while True:
-             result: WrappedGenerateResponse = response_queue.get() # type: ignore
+             result: WrappedGenerateResponse = response_queue.get()
              if result.status == "error":
-                 raise Exception(str(result.response))
+                 raise result.response

-             result: GenerateResponse = result.response # type: ignore
+             result: GenerateResponse = result.response
              if result.action == "next":
                  break

              with autocast_exclude_mps(
-                 device_type=self._model.device.type, dtype=torch.bfloat16
+                 device_type=self._model.device.type,
+                 dtype=self._kwargs.get("precision", torch.bfloat16),
              ):
                  fake_audios = decode_vq_tokens(
                      decoder_model=self._model,
@@ -179,7 +185,7 @@ class FishSpeechModel:
              segments.append(fake_audios)

              if streaming:
-                 yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
+                 yield fake_audios, None, None

          if len(segments) == 0:
              raise Exception("No audio generated, please check the input text.")
@@ -204,29 +210,59 @@ class FishSpeechModel:
              logger.warning("Fish speech does not support setting voice: %s.", voice)
          if speed != 1.0:
              logger.warning("Fish speech does not support setting speed: %s.", speed)
-         if stream is True:
-             logger.warning("stream mode is not implemented.")
          import torchaudio

-         result = list(
-             self._inference(
-                 text=input,
-                 enable_reference_audio=False,
-                 reference_audio=None,
-                 reference_text=kwargs.get("reference_text", ""),
-                 max_new_tokens=kwargs.get("max_new_tokens", 1024),
-                 chunk_length=kwargs.get("chunk_length", 200),
-                 top_p=kwargs.get("top_p", 0.7),
-                 repetition_penalty=kwargs.get("repetition_penalty", 1.2),
-                 temperature=kwargs.get("temperature", 0.7),
-             )
+         prompt_speech = kwargs.get("prompt_speech")
+         prompt_text = kwargs.get("prompt_text", kwargs.get("reference_text", ""))
+         result = self._inference(
+             text=input,
+             enable_reference_audio=kwargs.get(
+                 "enable_reference_audio", prompt_speech is not None
+             ),
+             reference_audio=prompt_speech,
+             reference_text=prompt_text,
+             max_new_tokens=kwargs.get("max_new_tokens", 1024),
+             chunk_length=kwargs.get("chunk_length", 200),
+             top_p=kwargs.get("top_p", 0.7),
+             repetition_penalty=kwargs.get("repetition_penalty", 1.2),
+             temperature=kwargs.get("temperature", 0.7),
+             streaming=stream,
          )
-         sample_rate, audio = result[0][1]
-         audio = np.array([audio])

-         # Save the generated audio
-         with BytesIO() as out:
-             torchaudio.save(
-                 out, torch.from_numpy(audio), sample_rate, format=response_format
-             )
-             return out.getvalue()
+         if stream:
+
+             def _stream_generator():
+                 with BytesIO() as out:
+                     writer = torchaudio.io.StreamWriter(out, format=response_format)
+                     writer.add_audio_stream(
+                         sample_rate=self._model.spec_transform.sample_rate,
+                         num_channels=1,
+                     )
+                     i = 0
+                     last_pos = 0
+                     with writer.open():
+                         for chunk in result:
+                             chunk = chunk[0]
+                             if chunk is not None:
+                                 chunk = chunk.reshape((chunk.shape[0], 1))
+                                 trans_chunk = torch.from_numpy(chunk)
+                                 writer.write_audio_chunk(i, trans_chunk)
+                                 new_last_pos = out.tell()
+                                 if new_last_pos != last_pos:
+                                     out.seek(last_pos)
+                                     encoded_bytes = out.read()
+                                     yield encoded_bytes
+                                     last_pos = new_last_pos
+
+             return _stream_generator()
+         else:
+             result = list(result)
+             sample_rate, audio = result[0][1]
+             audio = np.array([audio])
+
+             # Save the generated audio
+             with BytesIO() as out:
+                 torchaudio.save(
+                     out, torch.from_numpy(audio), sample_rate, format=response_format
+                 )
+                 return out.getvalue()
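
With the changes above, FishSpeech gains seed control, configurable compile/precision kwargs, reference-audio prompts, and chunked streaming via torchaudio's StreamWriter. Below is a hedged sketch of consuming that stream through a launched model handle; the endpoint, model uid, output path, and the assumption that the client yields encoded chunks when stream=True are not confirmed by this diff.

# Hypothetical streaming sketch for the updated FishSpeech speech() path (placeholders throughout).
from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed local xinference endpoint
model = client.get_model("fishspeech-1.4")  # assumed uid of an already-launched model

# stream=True now maps to the generator branch added above, so chunks can be
# written (or played) incrementally instead of waiting for the full buffer.
chunks = model.speech(
    input="Streaming synthesis test.",
    response_format="wav",
    stream=True,
)

with open("streamed.wav", "wb") as out:
    for chunk in chunks:
        out.write(chunk)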
xinference/model/audio/model_spec.json
@@ -103,6 +103,86 @@
      "model_ability": "audio-to-text",
      "multilingual": false
    },
+   {
+     "model_name": "whisper-tiny-mlx",
+     "model_family": "whisper",
+     "model_id": "mlx-community/whisper-tiny",
+     "model_ability": "audio-to-text",
+     "multilingual": true,
+     "engine": "mlx"
+   },
+   {
+     "model_name": "whisper-tiny.en-mlx",
+     "model_family": "whisper",
+     "model_id": "mlx-community/whisper-tiny.en-mlx",
+     "model_ability": "audio-to-text",
+     "multilingual": false,
+     "engine": "mlx"
+   },
+   {
+     "model_name": "whisper-base-mlx",
+     "model_family": "whisper",
+     "model_id": "mlx-community/whisper-base-mlx",
+     "model_ability": "audio-to-text",
+     "multilingual": true,
+     "engine": "mlx"
+   },
+   {
+     "model_name": "whisper-base.en-mlx",
+     "model_family": "whisper",
+     "model_id": "mlx-community/whisper-base.en-mlx",
+     "model_ability": "audio-to-text",
+     "multilingual": false,
+     "engine": "mlx"
+   },
+   {
+     "model_name": "whisper-small-mlx",
+     "model_family": "whisper",
+     "model_id": "mlx-community/whisper-small-mlx",
+     "model_ability": "audio-to-text",
+     "multilingual": true,
+     "engine": "mlx"
+   },
+   {
+     "model_name": "whisper-small.en-mlx",
+     "model_family": "whisper",
+     "model_id": "mlx-community/whisper-small.en-mlx",
+     "model_ability": "audio-to-text",
+     "multilingual": false,
+     "engine": "mlx"
+   },
+   {
+     "model_name": "whisper-medium-mlx",
+     "model_family": "whisper",
+     "model_id": "mlx-community/whisper-medium-mlx",
+     "model_ability": "audio-to-text",
+     "multilingual": true,
+     "engine": "mlx"
+   },
+   {
+     "model_name": "whisper-medium.en-mlx",
+     "model_family": "whisper",
+     "model_id": "mlx-community/whisper-medium.en-mlx",
+     "model_ability": "audio-to-text",
+     "multilingual": false,
+     "engine": "mlx"
+   },
+   {
+     "model_name": "whisper-large-v3-mlx",
+     "model_family": "whisper",
+     "model_id": "mlx-community/whisper-large-v3-mlx",
+     "model_ability": "audio-to-text",
+     "multilingual": true,
+     "engine": "mlx"
+   },
+   {
+     "model_name": "whisper-large-v3-turbo-mlx",
+     "model_family": "whisper",
+     "model_id": "mlx-community/whisper-large-v3-turbo",
+     "model_ability": "audio-to-text",
+     "multilingual": true,
+     "engine": "mlx"
+   },
    {
      "model_name": "SenseVoiceSmall",
      "model_family": "funasr",
@@ -162,5 +242,13 @@
      "model_revision": "069c573759936b35191d3380deb89183c0656f59",
      "model_ability": "text-to-audio",
      "multilingual": true
+   },
+   {
+     "model_name": "F5-TTS",
+     "model_family": "F5-TTS",
+     "model_id": "SWivid/F5-TTS",
+     "model_revision": "4dcc16f297f2ff98a17b3726b16f5de5a5e45672",
+     "model_ability": "text-to-audio",
+     "multilingual": true
    }
  ]
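
These entries register the mlx-community Whisper checkpoints and F5-TTS as built-in audio models, so on Apple-silicon hosts the MLX variants can be launched by name. A rough usage sketch follows; the endpoint URL and audio file are assumptions, while model_type="audio" mirrors how the existing whisper entries are launched.

# Hypothetical sketch for the new MLX whisper registrations (placeholders throughout).
from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed local xinference endpoint

model_uid = client.launch_model(
    model_name="whisper-large-v3-turbo-mlx",
    model_type="audio",
)
model = client.get_model(model_uid)

with open("speech_sample.wav", "rb") as f:
    audio = f.read()

# Plain transcription; the MLX backend is selected by the "engine": "mlx"
# field in the spec entries above.
result = model.transcriptions(audio)
print(result["text"])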
xinference/model/audio/model_spec_modelscope.json
@@ -73,5 +73,14 @@
      "model_revision": "master",
      "model_ability": "text-to-audio",
      "multilingual": true
+   },
+   {
+     "model_name": "F5-TTS",
+     "model_family": "F5-TTS",
+     "model_hub": "modelscope",
+     "model_id": "SWivid/F5-TTS_Emilia-ZH-EN",
+     "model_revision": "master",
+     "model_ability": "text-to-audio",
+     "multilingual": true
    }
  ]
xinference/model/audio/whisper_mlx.py
@@ -0,0 +1,208 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import functools
+ import itertools
+ import logging
+ import tempfile
+ from typing import TYPE_CHECKING, List, Optional
+
+ if TYPE_CHECKING:
+     from .core import AudioModelFamilyV1
+
+ logger = logging.getLogger(__name__)
+
+
+ class WhisperMLXModel:
+     def __init__(
+         self,
+         model_uid: str,
+         model_path: str,
+         model_spec: "AudioModelFamilyV1",
+         device: Optional[str] = None,
+         **kwargs,
+     ):
+         self._model_uid = model_uid
+         self._model_path = model_path
+         self._model_spec = model_spec
+         self._device = device
+         self._model = None
+         self._kwargs = kwargs
+         self._use_lighting = False
+
+     @property
+     def model_ability(self):
+         return self._model_spec.model_ability
+
+     def load(self):
+         use_lightning = self._kwargs.get("use_lightning", "auto")
+         if use_lightning not in ("auto", True, False, None):
+             raise ValueError("use_lightning can only be True, False, None or auto")
+
+         if use_lightning == "auto" or use_lightning is True:
+             try:
+                 import mlx.core as mx
+                 from lightning_whisper_mlx.transcribe import ModelHolder
+             except ImportError:
+                 if use_lightning == "auto":
+                     use_lightning = False
+                 else:
+                     error_message = "Failed to import module 'lightning_whisper_mlx'"
+                     installation_guide = [
+                         "Please make sure 'lightning_whisper_mlx' is installed.\n",
+                     ]
+
+                     raise ImportError(
+                         f"{error_message}\n\n{''.join(installation_guide)}"
+                     )
+             else:
+                 use_lightning = True
+         if not use_lightning:
+             try:
+                 import mlx.core as mx  # noqa: F811
+                 from mlx_whisper.transcribe import ModelHolder  # noqa: F811
+             except ImportError:
+                 error_message = "Failed to import module 'mlx_whisper'"
+                 installation_guide = [
+                     "Please make sure 'mlx_whisper' is installed.\n",
+                 ]
+
+                 raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+             else:
+                 use_lightning = False
+
+         logger.info(
+             "Loading MLX whisper from %s, use lightning: %s",
+             self._model_path,
+             use_lightning,
+         )
+         self._use_lighting = use_lightning
+         self._model = ModelHolder.get_model(self._model_path, mx.float16)
+
+     def transcriptions(
+         self,
+         audio: bytes,
+         language: Optional[str] = None,
+         prompt: Optional[str] = None,
+         response_format: str = "json",
+         temperature: float = 0,
+         timestamp_granularities: Optional[List[str]] = None,
+     ):
+         return self._call(
+             audio,
+             language=language,
+             prompt=prompt,
+             response_format=response_format,
+             temperature=temperature,
+             timestamp_granularities=timestamp_granularities,
+             task="transcribe",
+         )
+
+     def translations(
+         self,
+         audio: bytes,
+         language: Optional[str] = None,
+         prompt: Optional[str] = None,
+         response_format: str = "json",
+         temperature: float = 0,
+         timestamp_granularities: Optional[List[str]] = None,
+     ):
+         if not self._model_spec.multilingual:
+             raise RuntimeError(
+                 f"Model {self._model_spec.model_name} is not suitable for translations."
+             )
+         return self._call(
+             audio,
+             language=language,
+             prompt=prompt,
+             response_format=response_format,
+             temperature=temperature,
+             timestamp_granularities=timestamp_granularities,
+             task="translate",
+         )
+
+     def _call(
+         self,
+         audio: bytes,
+         language: Optional[str] = None,
+         prompt: Optional[str] = None,
+         response_format: str = "json",
+         temperature: float = 0,
+         timestamp_granularities: Optional[List[str]] = None,
+         task: str = "transcribe",
+     ):
+         if self._use_lighting:
+             from lightning_whisper_mlx.transcribe import transcribe_audio
+
+             transcribe = functools.partial(
+                 transcribe_audio, batch_size=self._kwargs.get("batch_size", 12)
+             )
+         else:
+             from mlx_whisper import transcribe  # type: ignore
+
+         with tempfile.NamedTemporaryFile(delete=True) as f:
+             f.write(audio)
+
+             kwargs = {"task": task}
+             if response_format == "verbose_json":
+                 if timestamp_granularities == ["word"]:
+                     kwargs["word_timestamps"] = True  # type: ignore
+
+             result = transcribe(
+                 f.name,
+                 path_or_hf_repo=self._model_path,
+                 language=language,
+                 temperature=temperature,
+                 initial_prompt=prompt,
+                 **kwargs,
+             )
+             text = result["text"]
+             segments = result["segments"]
+             language = result["language"]
+
+             if response_format == "json":
+                 return {"text": text}
+             elif response_format == "verbose_json":
+                 if not timestamp_granularities or timestamp_granularities == [
+                     "segment"
+                 ]:
+                     return {
+                         "task": task,
+                         "language": language,
+                         "duration": segments[-1]["end"] if segments else 0,
+                         "text": text,
+                         "segments": segments,
+                     }
+                 else:
+                     assert timestamp_granularities == ["word"]
+
+                     def _extract_word(word: dict) -> dict:
+                         return {
+                             "start": word["start"].item(),
+                             "end": word["end"].item(),
+                             "word": word["word"],
+                         }
+
+                     words = [
+                         _extract_word(w)
+                         for w in itertools.chain(*[s["words"] for s in segments])
+                     ]
+                     return {
+                         "task": task,
+                         "language": language,
+                         "duration": words[-1]["end"] if words else 0,
+                         "text": text,
+                         "words": words,
+                     }
+             else:
+                 raise ValueError(f"Unsupported response format: {response_format}")
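
The loader above prefers lightning_whisper_mlx when it is importable and otherwise falls back to mlx_whisper, controlled by a use_lightning kwarg; word-level timestamps are only produced for response_format="verbose_json" with timestamp_granularities=["word"]. A hedged sketch of requesting word timestamps through a launched model handle follows; the endpoint, model uid, and audio path are assumptions, not taken from this diff.

# Hypothetical word-timestamp request against an MLX whisper model (placeholders throughout).
from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed local xinference endpoint
model = client.get_model("whisper-large-v3-mlx")  # assumed uid of a launched model

with open("speech_sample.wav", "rb") as f:
    audio = f.read()

# verbose_json plus ["word"] maps to word_timestamps=True inside _call() above,
# and the response then carries a flattened "words" list instead of "segments".
result = model.transcriptions(
    audio,
    response_format="verbose_json",
    timestamp_granularities=["word"],
)
for word in result["words"]:
    print(f'{word["start"]:.2f}-{word["end"]:.2f}: {word["word"]}')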