xinference 0.16.3__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their public registries, and is provided for informational purposes only.

Potentially problematic release.



Files changed (69)
  1. xinference/_compat.py +22 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +148 -12
  4. xinference/client/restful/restful_client.py +47 -2
  5. xinference/constants.py +1 -0
  6. xinference/core/model.py +45 -15
  7. xinference/core/supervisor.py +8 -2
  8. xinference/core/utils.py +67 -2
  9. xinference/model/audio/__init__.py +12 -0
  10. xinference/model/audio/core.py +21 -4
  11. xinference/model/audio/fish_speech.py +70 -35
  12. xinference/model/audio/model_spec.json +81 -1
  13. xinference/model/audio/whisper_mlx.py +208 -0
  14. xinference/model/embedding/core.py +259 -4
  15. xinference/model/embedding/model_spec.json +1 -1
  16. xinference/model/embedding/model_spec_modelscope.json +1 -1
  17. xinference/model/image/stable_diffusion/core.py +5 -2
  18. xinference/model/llm/__init__.py +2 -0
  19. xinference/model/llm/llm_family.json +485 -6
  20. xinference/model/llm/llm_family_modelscope.json +519 -0
  21. xinference/model/llm/mlx/core.py +45 -3
  22. xinference/model/llm/sglang/core.py +1 -0
  23. xinference/model/llm/transformers/core.py +1 -0
  24. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  25. xinference/model/llm/utils.py +19 -0
  26. xinference/model/llm/vllm/core.py +84 -2
  27. xinference/model/rerank/core.py +11 -4
  28. xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
  29. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  30. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  31. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  32. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  33. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  34. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
  35. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  36. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  37. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
  38. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  39. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  40. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  41. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  42. xinference/thirdparty/fish_speech/tools/api.py +578 -75
  43. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  44. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  45. xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
  46. xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
  47. xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
  48. xinference/thirdparty/fish_speech/tools/schema.py +187 -0
  49. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  50. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  51. xinference/thirdparty/fish_speech/tools/webui.py +138 -75
  52. xinference/types.py +2 -1
  53. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/METADATA +30 -6
  54. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/RECORD +58 -63
  55. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/WHEEL +1 -1
  56. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  57. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  58. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  61. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  62. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  63. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  64. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  65. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  67. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/LICENSE +0 -0
  68. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/entry_points.txt +0 -0
  69. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/top_level.txt +0 -0
xinference/core/utils.py CHANGED
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
 import logging
 import os
 import random
 import string
 import uuid
+import weakref
 from enum import Enum
 from typing import Dict, Generator, List, Optional, Tuple, Union
 
@@ -23,7 +25,10 @@ import orjson
 from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown
 
 from .._compat import BaseModel
-from ..constants import XINFERENCE_LOG_ARG_MAX_LENGTH
+from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    XINFERENCE_LOG_ARG_MAX_LENGTH,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -49,13 +54,20 @@ def log_async(
 ):
     import time
     from functools import wraps
+    from inspect import signature
 
     def decorator(func):
         func_name = func.__name__
+        sig = signature(func)
 
         @wraps(func)
         async def wrapped(*args, **kwargs):
-            request_id_str = kwargs.get("request_id", "")
+            try:
+                bound_args = sig.bind_partial(*args, **kwargs)
+                arguments = bound_args.arguments
+            except TypeError:
+                arguments = {}
+            request_id_str = arguments.get("request_id", "")
             if not request_id_str:
                 request_id_str = uuid.uuid1()
             if func_name == "text_to_image":
@@ -269,3 +281,56 @@ def assign_replica_gpu(
     if isinstance(gpu_idx, list) and gpu_idx:
         return gpu_idx[rep_id::replica]
     return gpu_idx
+
+
+class CancelMixin:
+    _CANCEL_TASK_NAME = "abort_block"
+
+    def __init__(self):
+        self._running_tasks: weakref.WeakValueDictionary[
+            str, asyncio.Task
+        ] = weakref.WeakValueDictionary()
+
+    def _add_running_task(self, request_id: Optional[str]):
+        """Add current asyncio task to the running task.
+        :param request_id: The corresponding request id.
+        """
+        if request_id is None:
+            return
+        running_task = self._running_tasks.get(request_id)
+        if running_task is not None:
+            if running_task.get_name() == self._CANCEL_TASK_NAME:
+                raise Exception(f"The request has been aborted: {request_id}")
+            raise Exception(f"Duplicate request id: {request_id}")
+        current_task = asyncio.current_task()
+        assert current_task is not None
+        self._running_tasks[request_id] = current_task
+
+    def _cancel_running_task(
+        self,
+        request_id: Optional[str],
+        block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    ):
+        """Cancel the running asyncio task.
+        :param request_id: The request id to cancel.
+        :param block_duration: The duration seconds to ensure the request can't be executed.
+        """
+        if request_id is None:
+            return
+        running_task = self._running_tasks.pop(request_id, None)
+        if running_task is not None:
+            running_task.cancel()
+
+        async def block_task():
+            """This task is for blocking the request for a duration."""
+            try:
+                await asyncio.sleep(block_duration)
+                logger.info("Abort block end for request: %s", request_id)
+            except asyncio.CancelledError:
+                logger.info("Abort block is cancelled for request: %s", request_id)
+
+        if block_duration > 0:
+            logger.info("Abort block start for request: %s", request_id)
+            self._running_tasks[request_id] = asyncio.create_task(
+                block_task(), name=self._CANCEL_TASK_NAME
+            )
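
Note: CancelMixin tracks in-flight asyncio tasks in a WeakValueDictionary keyed by request id, so cancelling a request aborts the matching task and then parks a named "abort_block" task that keeps the id blocked for a grace period. A minimal usage sketch follows (illustrative only; MyModelActor and its methods are hypothetical, not part of this diff, and rely only on CancelMixin as defined above):

import asyncio
from typing import Optional

class MyModelActor(CancelMixin):
    async def generate(self, prompt: str, request_id: Optional[str] = None):
        # Register the current task; raises if the id is duplicated or was just aborted.
        self._add_running_task(request_id)
        await asyncio.sleep(10)  # stand-in for the actual inference work
        return f"result for {prompt}"

    async def abort_request(self, request_id: str):
        # Cancel the in-flight task and keep the id blocked for the default duration.
        self._cancel_running_task(request_id)
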
xinference/model/audio/__init__.py CHANGED
@@ -15,6 +15,8 @@
 import codecs
 import json
 import os
+import platform
+import sys
 import warnings
 from typing import Any, Dict
 
@@ -55,6 +57,14 @@ def register_custom_model():
             warnings.warn(f"{user_defined_audio_dir}/{f} has error, {e}")
 
 
+def _need_filter(spec: dict):
+    if (sys.platform != "darwin" or platform.processor() != "arm") and spec.get(
+        "engine", ""
+    ).upper() == "MLX":
+        return True
+    return False
+
+
 def _install():
     _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
     _model_spec_modelscope_json = os.path.join(
@@ -64,6 +74,7 @@ def _install():
         dict(
             (spec["model_name"], AudioModelFamilyV1(**spec))
            for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
+            if not _need_filter(spec)
         )
     )
     for model_name, model_spec in BUILTIN_AUDIO_MODELS.items():
@@ -75,6 +86,7 @@ def _install():
             for spec in json.load(
                 codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
             )
+            if not _need_filter(spec)
         )
     )
     for model_name, model_spec in MODELSCOPE_AUDIO_MODELS.items():
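
Note: the new _need_filter helper hides MLX-engine audio specs on every platform except macOS on Apple silicon, while specs without an MLX engine are always kept. Roughly, following the logic above (illustrative values, not captured output):

# On a Linux/x86 host (sys.platform != "darwin"):
_need_filter({"model_name": "whisper-tiny-mlx", "engine": "mlx"})  # True, spec is dropped
_need_filter({"model_name": "whisper-tiny"})                       # False, spec is kept
# On macOS with platform.processor() == "arm", both calls return False.
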
xinference/model/audio/core.py CHANGED
@@ -24,6 +24,7 @@ from .cosyvoice import CosyVoiceModel
 from .fish_speech import FishSpeechModel
 from .funasr import FunASRModel
 from .whisper import WhisperModel
+from .whisper_mlx import WhisperMLXModel
 
 logger = logging.getLogger(__name__)
 
@@ -43,11 +44,12 @@ class AudioModelFamilyV1(CacheableModelSpec):
     model_family: str
     model_name: str
     model_id: str
-    model_revision: str
+    model_revision: Optional[str]
     multilingual: bool
     model_ability: Optional[str]
     default_model_config: Optional[Dict[str, Any]]
     default_transcription_config: Optional[Dict[str, Any]]
+    engine: Optional[str]
 
 
 class AudioModelDescription(ModelDescription):
@@ -160,17 +162,32 @@ def create_audio_model_instance(
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[
-    Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel],
+    Union[
+        WhisperModel,
+        WhisperMLXModel,
+        FunASRModel,
+        ChatTTSModel,
+        CosyVoiceModel,
+        FishSpeechModel,
+    ],
     AudioModelDescription,
 ]:
     model_spec = match_audio(model_name, download_hub)
     if model_path is None:
         model_path = cache(model_spec)
     model: Union[
-        WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel
+        WhisperModel,
+        WhisperMLXModel,
+        FunASRModel,
+        ChatTTSModel,
+        CosyVoiceModel,
+        FishSpeechModel,
     ]
     if model_spec.model_family == "whisper":
-        model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
+        if not model_spec.engine:
+            model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
+        else:
+            model = WhisperMLXModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "funasr":
         model = FunASRModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "ChatTTS":
xinference/model/audio/fish_speech.py CHANGED
@@ -81,12 +81,14 @@ class FishSpeechModel:
         if not is_device_available(self._device):
             raise ValueError(f"Device {self._device} is not available!")
 
-        logger.info("Loading Llama model...")
+        enable_compile = self._kwargs.get("compile", False)
+        precision = self._kwargs.get("precision", torch.bfloat16)
+        logger.info("Loading Llama model, compile=%s...", enable_compile)
         self._llama_queue = launch_thread_safe_queue(
             checkpoint_path=self._model_path,
             device=self._device,
-            precision=torch.bfloat16,
-            compile=False,
+            precision=precision,
+            compile=enable_compile,
         )
         logger.info("Llama model loaded, loading VQ-GAN model...")
 
@@ -112,9 +114,10 @@
         top_p,
         repetition_penalty,
         temperature,
+        seed="0",
         streaming=False,
     ):
-        from fish_speech.utils import autocast_exclude_mps
+        from fish_speech.utils import autocast_exclude_mps, set_seed
         from tools.api import decode_vq_tokens, encode_reference
         from tools.llama.generate import (
             GenerateRequest,
@@ -122,6 +125,11 @@
             WrappedGenerateResponse,
         )
 
+        seed = int(seed)
+        if seed != 0:
+            set_seed(seed)
+            logger.warning(f"set seed: {seed}")
+
         # Parse reference audio aka prompt
         prompt_tokens = encode_reference(
             decoder_model=self._model,
@@ -137,7 +145,7 @@
             top_p=top_p,
             repetition_penalty=repetition_penalty,
             temperature=temperature,
-            compile=False,
+            compile=self._kwargs.get("compile", False),
             iterative_prompt=chunk_length > 0,
             chunk_length=chunk_length,
             max_length=2048,
@@ -153,22 +161,20 @@
             )
         )
 
-        if streaming:
-            yield wav_chunk_header(), None, None
-
         segments = []
 
         while True:
-            result: WrappedGenerateResponse = response_queue.get()  # type: ignore
+            result: WrappedGenerateResponse = response_queue.get()
             if result.status == "error":
-                raise Exception(str(result.response))
+                raise result.response
 
-            result: GenerateResponse = result.response  # type: ignore
+            result: GenerateResponse = result.response
             if result.action == "next":
                 break
 
             with autocast_exclude_mps(
-                device_type=self._model.device.type, dtype=torch.bfloat16
+                device_type=self._model.device.type,
+                dtype=self._kwargs.get("precision", torch.bfloat16),
             ):
                 fake_audios = decode_vq_tokens(
                     decoder_model=self._model,
@@ -179,7 +185,7 @@
             segments.append(fake_audios)
 
             if streaming:
-                yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
+                yield fake_audios, None, None
 
         if len(segments) == 0:
             raise Exception("No audio generated, please check the input text.")
@@ -204,29 +210,58 @@
             logger.warning("Fish speech does not support setting voice: %s.", voice)
         if speed != 1.0:
             logger.warning("Fish speech does not support setting speed: %s.", speed)
-        if stream is True:
-            logger.warning("stream mode is not implemented.")
         import torchaudio
 
-        result = list(
-            self._inference(
-                text=input,
-                enable_reference_audio=False,
-                reference_audio=None,
-                reference_text=kwargs.get("reference_text", ""),
-                max_new_tokens=kwargs.get("max_new_tokens", 1024),
-                chunk_length=kwargs.get("chunk_length", 200),
-                top_p=kwargs.get("top_p", 0.7),
-                repetition_penalty=kwargs.get("repetition_penalty", 1.2),
-                temperature=kwargs.get("temperature", 0.7),
-            )
+        prompt_speech = kwargs.get("prompt_speech")
+        result = self._inference(
+            text=input,
+            enable_reference_audio=kwargs.get(
+                "enable_reference_audio", prompt_speech is not None
+            ),
+            reference_audio=prompt_speech,
+            reference_text=kwargs.get("reference_text", ""),
+            max_new_tokens=kwargs.get("max_new_tokens", 1024),
+            chunk_length=kwargs.get("chunk_length", 200),
+            top_p=kwargs.get("top_p", 0.7),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.2),
+            temperature=kwargs.get("temperature", 0.7),
+            streaming=stream,
         )
-        sample_rate, audio = result[0][1]
-        audio = np.array([audio])
 
-        # Save the generated audio
-        with BytesIO() as out:
-            torchaudio.save(
-                out, torch.from_numpy(audio), sample_rate, format=response_format
-            )
-            return out.getvalue()
+        if stream:
+
+            def _stream_generator():
+                with BytesIO() as out:
+                    writer = torchaudio.io.StreamWriter(out, format=response_format)
+                    writer.add_audio_stream(
+                        sample_rate=self._model.spec_transform.sample_rate,
+                        num_channels=1,
+                    )
+                    i = 0
+                    last_pos = 0
+                    with writer.open():
+                        for chunk in result:
+                            chunk = chunk[0]
+                            if chunk is not None:
+                                chunk = chunk.reshape((chunk.shape[0], 1))
+                                trans_chunk = torch.from_numpy(chunk)
+                                writer.write_audio_chunk(i, trans_chunk)
+                                new_last_pos = out.tell()
+                                if new_last_pos != last_pos:
+                                    out.seek(last_pos)
+                                    encoded_bytes = out.read()
+                                    yield encoded_bytes
+                                    last_pos = new_last_pos
+
+            return _stream_generator()
+        else:
+            result = list(result)
+            sample_rate, audio = result[0][1]
+            audio = np.array([audio])
+
+            # Save the generated audio
+            with BytesIO() as out:
+                torchaudio.save(
+                    out, torch.from_numpy(audio), sample_rate, format=response_format
+                )
+                return out.getvalue()
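
Note: with the streaming path above, speech(..., stream=True) returns a generator of encoded audio chunks instead of warning that stream mode is unimplemented. A rough client-side sketch (illustrative only; it assumes a server at http://localhost:9997 with FishSpeech-1.4 launched under the uid "fish-speech", and the exact RESTful client signature and streaming return type may differ between versions):

from xinference.client import Client

client = Client("http://localhost:9997")
model = client.get_model("fish-speech")

# Write chunks out as they arrive instead of waiting for the whole clip.
with open("out.mp3", "wb") as f:
    for chunk in model.speech("Hello from Fish Speech", stream=True):
        f.write(chunk)
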
xinference/model/audio/model_spec.json CHANGED
@@ -103,6 +103,86 @@
     "model_ability": "audio-to-text",
     "multilingual": false
   },
+  {
+    "model_name": "whisper-tiny-mlx",
+    "model_family": "whisper",
+    "model_id": "mlx-community/whisper-tiny",
+    "model_ability": "audio-to-text",
+    "multilingual": true,
+    "engine": "mlx"
+  },
+  {
+    "model_name": "whisper-tiny.en-mlx",
+    "model_family": "whisper",
+    "model_id": "mlx-community/whisper-tiny.en-mlx",
+    "model_ability": "audio-to-text",
+    "multilingual": false,
+    "engine": "mlx"
+  },
+  {
+    "model_name": "whisper-base-mlx",
+    "model_family": "whisper",
+    "model_id": "mlx-community/whisper-base-mlx",
+    "model_ability": "audio-to-text",
+    "multilingual": true,
+    "engine": "mlx"
+  },
+  {
+    "model_name": "whisper-base.en-mlx",
+    "model_family": "whisper",
+    "model_id": "mlx-community/whisper-base.en-mlx",
+    "model_ability": "audio-to-text",
+    "multilingual": false,
+    "engine": "mlx"
+  },
+  {
+    "model_name": "whisper-small-mlx",
+    "model_family": "whisper",
+    "model_id": "mlx-community/whisper-small-mlx",
+    "model_ability": "audio-to-text",
+    "multilingual": true,
+    "engine": "mlx"
+  },
+  {
+    "model_name": "whisper-small.en-mlx",
+    "model_family": "whisper",
+    "model_id": "mlx-community/whisper-small.en-mlx",
+    "model_ability": "audio-to-text",
+    "multilingual": false,
+    "engine": "mlx"
+  },
+  {
+    "model_name": "whisper-medium-mlx",
+    "model_family": "whisper",
+    "model_id": "mlx-community/whisper-medium-mlx",
+    "model_ability": "audio-to-text",
+    "multilingual": true,
+    "engine": "mlx"
+  },
+  {
+    "model_name": "whisper-medium.en-mlx",
+    "model_family": "whisper",
+    "model_id": "mlx-community/whisper-medium.en-mlx",
+    "model_ability": "audio-to-text",
+    "multilingual": false,
+    "engine": "mlx"
+  },
+  {
+    "model_name": "whisper-large-v3-mlx",
+    "model_family": "whisper",
+    "model_id": "mlx-community/whisper-large-v3-mlx",
+    "model_ability": "audio-to-text",
+    "multilingual": true,
+    "engine": "mlx"
+  },
+  {
+    "model_name": "whisper-large-v3-turbo-mlx",
+    "model_family": "whisper",
+    "model_id": "mlx-community/whisper-large-v3-turbo",
+    "model_ability": "audio-to-text",
+    "multilingual": true,
+    "engine": "mlx"
+  },
   {
     "model_name": "SenseVoiceSmall",
     "model_family": "funasr",
@@ -159,7 +239,7 @@
     "model_name": "FishSpeech-1.4",
     "model_family": "FishAudio",
     "model_id": "fishaudio/fish-speech-1.4",
-    "model_revision": "3c49651b8e583b6b13f55e375432e0d57e1aa84d",
+    "model_revision": "069c573759936b35191d3380deb89183c0656f59",
     "model_ability": "text-to-audio",
     "multilingual": true
   }
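
Note: the ten new entries register MLX builds of Whisper (tiny through large-v3-turbo) from the mlx-community repositories and mark them with "engine": "mlx", which routes them to the new WhisperMLXModel. A rough launch-and-transcribe sketch for an Apple-silicon Mac (illustrative only; assumes a running server at http://localhost:9997 and that the client API matches this release):

from xinference.client import Client

client = Client("http://localhost:9997")
model_uid = client.launch_model(
    model_name="whisper-large-v3-turbo-mlx",
    model_type="audio",
)
model = client.get_model(model_uid)

with open("speech.wav", "rb") as f:
    print(model.transcriptions(f.read()))  # e.g. {"text": "..."}
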
xinference/model/audio/whisper_mlx.py ADDED
@@ -0,0 +1,208 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import functools
+import itertools
+import logging
+import tempfile
+from typing import TYPE_CHECKING, List, Optional
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class WhisperMLXModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._kwargs = kwargs
+        self._use_lighting = False
+
+    @property
+    def model_ability(self):
+        return self._model_spec.model_ability
+
+    def load(self):
+        use_lightning = self._kwargs.get("use_lightning", "auto")
+        if use_lightning not in ("auto", True, False, None):
+            raise ValueError("use_lightning can only be True, False, None or auto")
+
+        if use_lightning == "auto" or use_lightning is True:
+            try:
+                import mlx.core as mx
+                from lightning_whisper_mlx.transcribe import ModelHolder
+            except ImportError:
+                if use_lightning == "auto":
+                    use_lightning = False
+                else:
+                    error_message = "Failed to import module 'lightning_whisper_mlx'"
+                    installation_guide = [
+                        "Please make sure 'lightning_whisper_mlx' is installed.\n",
+                    ]
+
+                    raise ImportError(
+                        f"{error_message}\n\n{''.join(installation_guide)}"
+                    )
+            else:
+                use_lightning = True
+        if not use_lightning:
+            try:
+                import mlx.core as mx  # noqa: F811
+                from mlx_whisper.transcribe import ModelHolder  # noqa: F811
+            except ImportError:
+                error_message = "Failed to import module 'mlx_whisper'"
+                installation_guide = [
+                    "Please make sure 'mlx_whisper' is installed.\n",
+                ]
+
+                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+            else:
+                use_lightning = False
+
+        logger.info(
+            "Loading MLX whisper from %s, use lightning: %s",
+            self._model_path,
+            use_lightning,
+        )
+        self._use_lighting = use_lightning
+        self._model = ModelHolder.get_model(self._model_path, mx.float16)
+
+    def transcriptions(
+        self,
+        audio: bytes,
+        language: Optional[str] = None,
+        prompt: Optional[str] = None,
+        response_format: str = "json",
+        temperature: float = 0,
+        timestamp_granularities: Optional[List[str]] = None,
+    ):
+        return self._call(
+            audio,
+            language=language,
+            prompt=prompt,
+            response_format=response_format,
+            temperature=temperature,
+            timestamp_granularities=timestamp_granularities,
+            task="transcribe",
+        )
+
+    def translations(
+        self,
+        audio: bytes,
+        language: Optional[str] = None,
+        prompt: Optional[str] = None,
+        response_format: str = "json",
+        temperature: float = 0,
+        timestamp_granularities: Optional[List[str]] = None,
+    ):
+        if not self._model_spec.multilingual:
+            raise RuntimeError(
+                f"Model {self._model_spec.model_name} is not suitable for translations."
+            )
+        return self._call(
+            audio,
+            language=language,
+            prompt=prompt,
+            response_format=response_format,
+            temperature=temperature,
+            timestamp_granularities=timestamp_granularities,
+            task="translate",
+        )
+
+    def _call(
+        self,
+        audio: bytes,
+        language: Optional[str] = None,
+        prompt: Optional[str] = None,
+        response_format: str = "json",
+        temperature: float = 0,
+        timestamp_granularities: Optional[List[str]] = None,
+        task: str = "transcribe",
+    ):
+        if self._use_lighting:
+            from lightning_whisper_mlx.transcribe import transcribe_audio
+
+            transcribe = functools.partial(
+                transcribe_audio, batch_size=self._kwargs.get("batch_size", 12)
+            )
+        else:
+            from mlx_whisper import transcribe  # type: ignore
+
+        with tempfile.NamedTemporaryFile(delete=True) as f:
+            f.write(audio)
+
+            kwargs = {"task": task}
+            if response_format == "verbose_json":
+                if timestamp_granularities == ["word"]:
+                    kwargs["word_timestamps"] = True  # type: ignore
+
+            result = transcribe(
+                f.name,
+                path_or_hf_repo=self._model_path,
+                language=language,
+                temperature=temperature,
+                initial_prompt=prompt,
+                **kwargs,
+            )
+            text = result["text"]
+            segments = result["segments"]
+            language = result["language"]
+
+            if response_format == "json":
+                return {"text": text}
+            elif response_format == "verbose_json":
+                if not timestamp_granularities or timestamp_granularities == [
+                    "segment"
+                ]:
+                    return {
+                        "task": task,
+                        "language": language,
+                        "duration": segments[-1]["end"] if segments else 0,
+                        "text": text,
+                        "segments": segments,
+                    }
+                else:
+                    assert timestamp_granularities == ["word"]
+
+                    def _extract_word(word: dict) -> dict:
+                        return {
+                            "start": word["start"].item(),
+                            "end": word["end"].item(),
+                            "word": word["word"],
+                        }
+
+                    words = [
+                        _extract_word(w)
+                        for w in itertools.chain(*[s["words"] for s in segments])
+                    ]
+                    return {
+                        "task": task,
+                        "language": language,
+                        "duration": words[-1]["end"] if words else 0,
+                        "text": text,
+                        "words": words,
+                    }
+            else:
+                raise ValueError(f"Unsupported response format: {response_format}")
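
Note: for reference, combining response_format="verbose_json" with timestamp_granularities=["word"] makes _call return a word-level payload shaped like the following (placeholder values), while segment granularity returns "segments" instead of "words":

{
    "task": "transcribe",
    "language": "en",
    "duration": 0.9,
    "text": "hello world",
    "words": [
        {"start": 0.0, "end": 0.4, "word": "hello"},
        {"start": 0.5, "end": 0.9, "word": "world"},
    ],
}
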