xinference 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of xinference might be problematic.

Files changed (94)
  1. xinference/_compat.py +22 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +91 -6
  4. xinference/client/restful/restful_client.py +39 -0
  5. xinference/core/model.py +41 -13
  6. xinference/deploy/cmdline.py +3 -1
  7. xinference/deploy/test/test_cmdline.py +56 -0
  8. xinference/isolation.py +24 -0
  9. xinference/model/audio/__init__.py +12 -0
  10. xinference/model/audio/core.py +26 -4
  11. xinference/model/audio/f5tts.py +195 -0
  12. xinference/model/audio/fish_speech.py +71 -35
  13. xinference/model/audio/model_spec.json +88 -0
  14. xinference/model/audio/model_spec_modelscope.json +9 -0
  15. xinference/model/audio/whisper_mlx.py +208 -0
  16. xinference/model/embedding/core.py +322 -6
  17. xinference/model/embedding/model_spec.json +8 -1
  18. xinference/model/embedding/model_spec_modelscope.json +9 -1
  19. xinference/model/llm/__init__.py +4 -2
  20. xinference/model/llm/llm_family.json +479 -53
  21. xinference/model/llm/llm_family_modelscope.json +423 -17
  22. xinference/model/llm/mlx/core.py +230 -50
  23. xinference/model/llm/sglang/core.py +2 -0
  24. xinference/model/llm/transformers/chatglm.py +9 -5
  25. xinference/model/llm/transformers/core.py +1 -0
  26. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  27. xinference/model/llm/transformers/utils.py +16 -8
  28. xinference/model/llm/utils.py +23 -1
  29. xinference/model/llm/vllm/core.py +89 -2
  30. xinference/thirdparty/f5_tts/__init__.py +0 -0
  31. xinference/thirdparty/f5_tts/api.py +166 -0
  32. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  33. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  34. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  35. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  36. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  37. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  38. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  39. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  40. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  41. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  42. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  43. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  44. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  45. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  46. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  47. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  48. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  49. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  50. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  51. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  52. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  53. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  54. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  55. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  56. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  57. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  58. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  59. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  60. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  61. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  62. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  63. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  64. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  65. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  66. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  67. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  68. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  69. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  70. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  71. xinference/thirdparty/f5_tts/train/README.md +77 -0
  72. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  73. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  74. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  75. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  76. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  77. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  78. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  79. xinference/thirdparty/f5_tts/train/train.py +75 -0
  80. xinference/types.py +2 -1
  81. xinference/web/ui/build/asset-manifest.json +3 -3
  82. xinference/web/ui/build/index.html +1 -1
  83. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  84. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  86. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/METADATA +39 -18
  87. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/RECORD +92 -39
  88. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/WHEEL +1 -1
  89. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  91. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  92. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/LICENSE +0 -0
  93. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/entry_points.txt +0 -0
  94. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/top_level.txt +0 -0
xinference/_compat.py CHANGED
@@ -60,6 +60,10 @@ from openai.types.chat.chat_completion_stream_options_param import (
     ChatCompletionStreamOptionsParam,
 )
 from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+from openai.types.shared_params.response_format_json_object import (
+    ResponseFormatJSONObject,
+)
+from openai.types.shared_params.response_format_text import ResponseFormatText
 
 OpenAIChatCompletionStreamOptionsParam = create_model_from_typeddict(
     ChatCompletionStreamOptionsParam
@@ -70,6 +74,23 @@ OpenAIChatCompletionNamedToolChoiceParam = create_model_from_typeddict(
 )
 
 
+class JSONSchema(BaseModel):
+    name: str
+    description: Optional[str] = None
+    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
+    strict: Optional[bool] = None
+
+
+class ResponseFormatJSONSchema(BaseModel):
+    json_schema: JSONSchema
+    type: Literal["json_schema"]
+
+
+ResponseFormat = Union[
+    ResponseFormatText, ResponseFormatJSONObject, ResponseFormatJSONSchema
+]
+
+
 class CreateChatCompletionOpenAI(BaseModel):
     """
     Comes from source code: https://github.com/openai/openai-python/blob/main/src/openai/types/chat/completion_create_params.py
@@ -84,8 +105,7 @@ class CreateChatCompletionOpenAI(BaseModel):
     n: Optional[int]
     parallel_tool_calls: Optional[bool]
     presence_penalty: Optional[float]
-    # we do not support this
-    # response_format: ResponseFormat
+    response_format: Optional[ResponseFormat]
     seed: Optional[int]
     service_tier: Optional[Literal["auto", "default"]]
     stop: Union[Optional[str], List[str]]
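
With `response_format` now accepted by `CreateChatCompletionOpenAI`, an OpenAI-style client can ask the chat endpoint for JSON output. A minimal sketch, assuming a local endpoint on port 9997 and a launched `qwen2.5-instruct` model (both are assumptions, not part of the diff):

    import openai

    # Hypothetical deployment details; adjust host, port and model name.
    client = openai.OpenAI(api_key="not empty", base_url="http://127.0.0.1:9997/v1")
    resp = client.chat.completions.create(
        model="qwen2.5-instruct",
        messages=[{"role": "user", "content": "Return a JSON object with a `city` key."}],
        response_format={"type": "json_object"},
    )
    print(resp.choices[0].message.content)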
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-11-15T17:33:11+0800",
+ "date": "2024-12-13T18:21:03+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "4c96475b8f90e354aa1b47856fda4db098b62b65",
- "version": "1.0.0"
+ "full-revisionid": "b132fca91f3e1b11b111f9b89f68a55e4b7605c6",
+ "version": "1.1.0"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -94,9 +94,9 @@ class CreateCompletionRequest(CreateCompletion):
 
 class CreateEmbeddingRequest(BaseModel):
     model: str
-    input: Union[str, List[str], List[int], List[List[int]]] = Field(
-        description="The input to embed."
-    )
+    input: Union[
+        str, List[str], List[int], List[List[int]], Dict[str, str], List[Dict[str, str]]
+    ] = Field(description="The input to embed.")
     user: Optional[str] = None
 
     class Config:
@@ -489,6 +489,16 @@ class RESTfulAPI(CancelMixin):
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/convert_ids_to_tokens",
+            self.convert_ids_to_tokens,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/rerank",
             self.rerank,
@@ -1219,6 +1229,9 @@
         raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
         kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
+        # guided_decoding params
+        kwargs.update(self.extract_guided_params(raw_body=raw_body))
+
         # TODO: Decide if this default value override is necessary #1061
         if body.max_tokens is None:
             kwargs["max_tokens"] = max_tokens_field.default
@@ -1264,6 +1277,8 @@
                     # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
                     yield dict(data=json.dumps({"error": str(ex)}))
                     return
+                finally:
+                    await model.decrease_serve_count()
 
             return EventSourceResponse(stream_results())
         else:
@@ -1312,6 +1327,41 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def convert_ids_to_tokens(self, request: Request) -> Response:
+        payload = await request.json()
+        body = CreateEmbeddingRequest.parse_obj(payload)
+        model_uid = body.model
+        exclude = {
+            "model",
+            "input",
+            "user",
+        }
+        kwargs = {key: value for key, value in payload.items() if key not in exclude}
+
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            decoded_texts = await model.convert_ids_to_tokens(body.input, **kwargs)
+            return Response(decoded_texts, media_type="application/json")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def rerank(self, request: Request) -> Response:
         payload = await request.json()
         body = RerankRequest.parse_obj(payload)
@@ -1495,8 +1545,16 @@
                 **parsed_kwargs,
             )
             if body.stream:
+
+                async def stream_results():
+                    try:
+                        async for item in out:
+                            yield item
+                    finally:
+                        await model.decrease_serve_count()
+
                 return EventSourceResponse(
-                    media_type="application/octet-stream", content=out
+                    media_type="application/octet-stream", content=stream_results()
                 )
             else:
                 return Response(media_type="application/octet-stream", content=out)
@@ -1916,9 +1974,13 @@
             "logit_bias_type",
             "user",
         }
+
         raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
         kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
+        # guided_decoding params
+        kwargs.update(self.extract_guided_params(raw_body=raw_body))
+
         # TODO: Decide if this default value override is necessary #1061
         if body.max_tokens is None:
             kwargs["max_tokens"] = max_tokens_field.default
@@ -1982,7 +2044,6 @@
         )
         if body.tools and body.stream:
             is_vllm = await model.is_vllm_backend()
-
             if not (
                 (is_vllm and model_family in QWEN_TOOL_CALL_FAMILY)
                 or (not is_vllm and model_family in GLM4_TOOL_CALL_FAMILY)
@@ -1992,7 +2053,8 @@
                     detail="Streaming support for tool calls is available only when using "
                    "Qwen models with vLLM backend or GLM4-chat models without vLLM backend.",
                 )
-
+        if "skip_special_tokens" in raw_kwargs and await model.is_vllm_backend():
+            kwargs["skip_special_tokens"] = raw_kwargs["skip_special_tokens"]
         if body.stream:
 
             async def stream_results():
@@ -2027,6 +2089,8 @@
                     # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
                     yield dict(data=json.dumps({"error": str(ex)}))
                     return
+                finally:
+                    await model.decrease_serve_count()
 
             return EventSourceResponse(stream_results())
         else:
@@ -2279,6 +2343,27 @@
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    @staticmethod
+    def extract_guided_params(raw_body: dict) -> dict:
+        kwargs = {}
+        if raw_body.get("guided_json") is not None:
+            kwargs["guided_json"] = raw_body.get("guided_json")
+        if raw_body.get("guided_regex") is not None:
+            kwargs["guided_regex"] = raw_body.get("guided_regex")
+        if raw_body.get("guided_choice") is not None:
+            kwargs["guided_choice"] = raw_body.get("guided_choice")
+        if raw_body.get("guided_grammar") is not None:
+            kwargs["guided_grammar"] = raw_body.get("guided_grammar")
+        if raw_body.get("guided_json_object") is not None:
+            kwargs["guided_json_object"] = raw_body.get("guided_json_object")
+        if raw_body.get("guided_decoding_backend") is not None:
+            kwargs["guided_decoding_backend"] = raw_body.get("guided_decoding_backend")
+        if raw_body.get("guided_whitespace_pattern") is not None:
+            kwargs["guided_whitespace_pattern"] = raw_body.get(
+                "guided_whitespace_pattern"
+            )
+        return kwargs
+
 
 def run(
     supervisor_address: str,
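
`extract_guided_params` forwards vLLM-style guided-decoding fields from the raw request body into the generation kwargs, so they can be posted alongside standard OpenAI fields. A sketch using `guided_choice`, assuming a vLLM-backed model behind a local endpoint (URL and model name are assumptions):

    import requests

    payload = {
        "model": "qwen2.5-instruct",
        "messages": [{"role": "user", "content": "Is the sky blue? Answer yes or no."}],
        # Picked up by extract_guided_params and passed through to the backend:
        "guided_choice": ["yes", "no"],
    }
    r = requests.post("http://127.0.0.1:9997/v1/chat/completions", json=payload)
    print(r.json()["choices"][0]["message"]["content"])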
xinference/client/restful/restful_client.py CHANGED
@@ -126,6 +126,43 @@ class RESTfulEmbeddingModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def convert_ids_to_tokens(
+        self, input: Union[List, List[List]], **kwargs
+    ) -> List[str]:
+        """
+        Convert token IDs to human readable tokens via RESTful APIs.
+
+        Parameters
+        ----------
+        input: Union[List, List[List]]
+            Input token IDs to convert, can be a single list of token IDs or a list of token ID lists.
+            To convert multiple sequences in a single request, pass a list of token ID lists.
+
+        Returns
+        -------
+        list
+            A list of decoded tokens in human readable format.
+
+        Raises
+        ------
+        RuntimeError
+            Report the failure of token conversion and provide the error message.
+
+        """
+        url = f"{self._base_url}/v1/convert_ids_to_tokens"
+        request_body = {
+            "model": self._model_uid,
+            "input": input,
+        }
+        request_body.update(kwargs)
+        response = requests.post(url, json=request_body, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to decode token ids, detail: {_get_error_string(response)}"
+            )
+        response_data = response.json()
+        return response_data
+
 
 class RESTfulRerankModelHandle(RESTfulModelHandle):
     def rerank(
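
A usage sketch for the new handle method above, assuming an embedding model with a tokenizer is already launched (the model UID and token IDs are illustrative):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    model = client.get_model("my-embedding-model")  # a RESTfulEmbeddingModelHandle
    # Single sequence; pass a list of token ID lists to decode several at once.
    tokens = model.convert_ids_to_tokens([101, 7592, 2088, 102])
    print(tokens)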
@@ -704,6 +741,8 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
             The speed of the generated audio.
         stream: bool
             Use stream or not.
+        prompt_speech: bytes
+            The audio bytes to be provided to the model.
 
         Returns
         -------
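
The documented `prompt_speech` parameter travels through `speech()`'s `**kwargs` and is meant for the reference-audio TTS families added in this release. A sketch, assuming F5-TTS is launched as an audio model and that it consumes `prompt_speech` this way (both assumptions):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    model_uid = client.launch_model(model_name="F5-TTS", model_type="audio")
    model = client.get_model(model_uid)

    with open("ref.wav", "rb") as f:
        reference = f.read()  # reference audio for voice cloning

    audio = model.speech(
        "Hello from Xinference 1.1.0.",
        prompt_speech=reference,  # assumption: forwarded via kwargs to the model
    )
    with open("out.wav", "wb") as f:
        f.write(audio)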
xinference/core/model.py CHANGED
@@ -78,6 +78,7 @@ XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = [
 ]
 
 XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS = ["FLUX.1-dev", "FLUX.1-schnell"]
+XINFERENCE_BATCHING_BLACK_LIST = ["glm4-chat"]
 
 
 def request_limit(fn):
@@ -91,21 +92,26 @@ def request_limit(fn):
         logger.debug(
             f"Request {fn.__name__}, current serve request count: {self._serve_count}, request limit: {self._request_limits} for the model {self.model_uid()}"
         )
-        if self._request_limits is not None:
-            if 1 + self._serve_count <= self._request_limits:
-                self._serve_count += 1
-            else:
-                raise RuntimeError(
-                    f"Rate limit reached for the model. Request limit {self._request_limits} for the model: {self.model_uid()}"
-                )
+        if 1 + self._serve_count <= self._request_limits:
+            self._serve_count += 1
+        else:
+            raise RuntimeError(
+                f"Rate limit reached for the model. Request limit {self._request_limits} for the model: {self.model_uid()}"
+            )
+        ret = None
         try:
             ret = await fn(self, *args, **kwargs)
         finally:
-            if self._request_limits is not None:
+            if ret is not None and (
+                inspect.isasyncgen(ret) or inspect.isgenerator(ret)
+            ):
+                # stream case, let client call model_ref to decrease self._serve_count
+                pass
+            else:
                 self._serve_count -= 1
-            logger.debug(
-                f"After request {fn.__name__}, current serve request count: {self._serve_count} for the model {self.model_uid()}"
-            )
+                logger.debug(
+                    f"After request {fn.__name__}, current serve request count: {self._serve_count} for the model {self.model_uid()}"
+                )
         return ret
 
     return wrapped_func
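
The rewritten decorator always counts requests now (`_request_limits` defaults to `float("inf")`, see the next hunk) and, when the wrapped call returns a generator, defers the decrement to whoever drains the stream; that is why the API layer's `stream_results()` wrappers call `model.decrease_serve_count()` in `finally`. A toy sketch of the same accounting contract, with all names hypothetical:

    import inspect

    class Counter:
        def __init__(self, limit):
            self.limit, self.count = limit, 0

        def serve(self):
            if self.count + 1 > self.limit:
                raise RuntimeError("Rate limit reached")
            self.count += 1

            def stream():
                try:
                    yield "chunk"
                finally:
                    self.count -= 1  # consumer-side release, like decrease_serve_count()

            ret = stream()
            # Generator result: skip the immediate decrement, exactly as the decorator does.
            assert inspect.isgenerator(ret)
            return ret

    c = Counter(limit=1)
    for _ in c.serve():
        pass
    assert c.count == 0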
@@ -215,7 +221,9 @@ class ModelActor(xo.StatelessActor, CancelMixin):
         self._model_description = (
             model_description.to_dict() if model_description else {}
         )
-        self._request_limits = request_limits
+        self._request_limits = (
+            float("inf") if request_limits is None else request_limits
+        )
         self._pending_requests: asyncio.Queue = asyncio.Queue()
         self._handle_pending_requests_task = None
         self._lock = (
@@ -268,6 +276,9 @@
     def __repr__(self) -> str:
         return f"ModelActor({self._replica_model_uid})"
 
+    def decrease_serve_count(self):
+        self._serve_count -= 1
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):
@@ -362,7 +373,11 @@
                     f"Your model {self._model.model_family.model_name} with model family {self._model.model_family.model_family} is disqualified."
                 )
                 return False
-        return condition
+        return (
+            condition
+            and self._model.model_family.model_name
+            not in XINFERENCE_BATCHING_BLACK_LIST
+        )
 
     def allow_batching_for_text_to_image(self) -> bool:
         from ..model.image.stable_diffusion.core import DiffusionModel
@@ -794,6 +809,19 @@
                 f"Model {self._model.model_spec} is not for creating embedding."
             )
 
+    @request_limit
+    @log_async(logger=logger)
+    async def convert_ids_to_tokens(
+        self, input: Union[List, List[List]], *args, **kwargs
+    ):
+        kwargs.pop("request_id", None)
+        if hasattr(self._model, "convert_ids_to_tokens"):
+            return await self._call_wrapper_json(
+                self._model.convert_ids_to_tokens, input, *args, **kwargs
+            )
+
+        raise AttributeError(f"Model {self._model.model_spec} can convert token id.")
+
     @request_limit
     @log_async(logger=logger)
     async def rerank(
xinference/deploy/cmdline.py CHANGED
@@ -846,7 +846,9 @@ def model_launch(
     kwargs = {}
     for i in range(0, len(ctx.args), 2):
         if not ctx.args[i].startswith("--"):
-            raise ValueError("You must specify extra kwargs with `--` prefix.")
+            raise ValueError(
+                f"You must specify extra kwargs with `--` prefix. There is an error in parameter passing that is {ctx.args[i]}."
+            )
         kwargs[ctx.args[i][2:]] = handle_click_args_type(ctx.args[i + 1])
     print(f"Launch model name: {model_name} with kwargs: {kwargs}", file=sys.stderr)
 
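
For reference, the loop above consumes `ctx.args` pairwise, so extra model kwargs must arrive as `--name value` pairs. A standalone sketch of the parsing rule, without the type coercion done by `handle_click_args_type`:

    def parse_extra_kwargs(args):
        kwargs = {}
        for i in range(0, len(args), 2):
            if not args[i].startswith("--"):
                raise ValueError(
                    f"You must specify extra kwargs with `--` prefix. "
                    f"There is an error in parameter passing that is {args[i]}."
                )
            kwargs[args[i][2:]] = args[i + 1]
        return kwargs

    assert parse_extra_kwargs(["--max_model_len", "8192"]) == {"max_model_len": "8192"}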
xinference/deploy/test/test_cmdline.py CHANGED
@@ -23,6 +23,7 @@ from ..cmdline import (
     list_model_registrations,
     model_chat,
     model_generate,
+    model_launch,
     model_list,
     model_terminate,
     register_model,
@@ -311,3 +312,58 @@ def test_remove_cache(setup):
 
     assert result.exit_code == 0
     assert "Cache directory qwen1.5-chat has been deleted."
+
+
+def test_launch_error_in_passing_parameters():
+    runner = CliRunner()
+
+    # Known parameter but not provided with value.
+    result = runner.invoke(
+        model_launch,
+        [
+            "--model-engine",
+            "transformers",
+            "--model-name",
+            "qwen2.5-instruct",
+            "--model-uid",
+            "-s",
+            "0.5",
+            "-f",
+            "gptq",
+            "-q",
+            "INT4",
+            "111",
+            "-l",
+        ],
+    )
+    assert result.exit_code == 1
+    assert (
+        "You must specify extra kwargs with `--` prefix. There is an error in parameter passing that is 0.5."
+        in str(result)
+    )
+
+    # Unknown parameter
+    result = runner.invoke(
+        model_launch,
+        [
+            "--model-engine",
+            "transformers",
+            "--model-name",
+            "qwen2.5-instruct",
+            "--model-uid",
+            "123",
+            "-s",
+            "0.5",
+            "-f",
+            "gptq",
+            "-q",
+            "INT4",
+            "-l",
+            "111",
+        ],
+    )
+    assert result.exit_code == 1
+    assert (
+        "You must specify extra kwargs with `--` prefix. There is an error in parameter passing that is -l."
+        in str(result)
+    )
xinference/isolation.py CHANGED
@@ -37,6 +37,30 @@ class Isolation:
37
37
  asyncio.set_event_loop(self._loop)
38
38
  self._stopped = asyncio.Event()
39
39
  self._loop.run_until_complete(self._stopped.wait())
40
+ self._cancel_all_tasks(self._loop)
41
+
42
+ @staticmethod
43
+ def _cancel_all_tasks(loop):
44
+ to_cancel = asyncio.all_tasks(loop)
45
+ if not to_cancel:
46
+ return
47
+
48
+ for task in to_cancel:
49
+ task.cancel()
50
+
51
+ loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
52
+
53
+ for task in to_cancel:
54
+ if task.cancelled():
55
+ continue
56
+ if task.exception() is not None:
57
+ loop.call_exception_handler(
58
+ {
59
+ "message": "unhandled exception during asyncio.run() shutdown",
60
+ "exception": task.exception(),
61
+ "task": task,
62
+ }
63
+ )
40
64
 
41
65
  def start(self):
42
66
  if self._threaded:
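
`_cancel_all_tasks` mirrors the shutdown logic of `asyncio.run()`: cancel every pending task, then drain them so cancellations are retrieved and real exceptions get reported. A standalone demonstration of the same pattern:

    import asyncio

    async def lingering():
        await asyncio.sleep(3600)

    loop = asyncio.new_event_loop()
    task = loop.create_task(lingering())
    loop.run_until_complete(asyncio.sleep(0))  # let the task start
    task.cancel()
    loop.run_until_complete(asyncio.gather(task, return_exceptions=True))
    assert task.cancelled()
    loop.close()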
xinference/model/audio/__init__.py CHANGED
@@ -15,6 +15,8 @@
 import codecs
 import json
 import os
+import platform
+import sys
 import warnings
 from typing import Any, Dict
 
@@ -55,6 +57,14 @@ def register_custom_model():
             warnings.warn(f"{user_defined_audio_dir}/{f} has error, {e}")
 
 
+def _need_filter(spec: dict):
+    if (sys.platform != "darwin" or platform.processor() != "arm") and spec.get(
+        "engine", ""
+    ).upper() == "MLX":
+        return True
+    return False
+
+
 def _install():
     _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
     _model_spec_modelscope_json = os.path.join(
@@ -64,6 +74,7 @@ def _install():
         dict(
             (spec["model_name"], AudioModelFamilyV1(**spec))
             for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
+            if not _need_filter(spec)
         )
     )
     for model_name, model_spec in BUILTIN_AUDIO_MODELS.items():
@@ -75,6 +86,7 @@ def _install():
             for spec in json.load(
                 codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
             )
+            if not _need_filter(spec)
        )
    )
    for model_name, model_spec in MODELSCOPE_AUDIO_MODELS.items():
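
Given `_need_filter` above, MLX-engine audio specs are registered only on Apple Silicon, where `sys.platform == "darwin"` and `platform.processor() == "arm"`; everywhere else they are skipped. Illustrative spec dicts:

    # On a Linux x86_64 host, for example:
    _need_filter({"model_name": "whisper-large-v3-mlx", "engine": "MLX"})  # True  -> skipped
    _need_filter({"model_name": "whisper-large-v3"})                       # False -> registered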
xinference/model/audio/core.py CHANGED
@@ -21,9 +21,11 @@ from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
 from .cosyvoice import CosyVoiceModel
+from .f5tts import F5TTSModel
 from .fish_speech import FishSpeechModel
 from .funasr import FunASRModel
 from .whisper import WhisperModel
+from .whisper_mlx import WhisperMLXModel
 
 logger = logging.getLogger(__name__)
 
@@ -43,11 +45,12 @@ class AudioModelFamilyV1(CacheableModelSpec):
     model_family: str
     model_name: str
     model_id: str
-    model_revision: str
+    model_revision: Optional[str]
     multilingual: bool
     model_ability: Optional[str]
     default_model_config: Optional[Dict[str, Any]]
     default_transcription_config: Optional[Dict[str, Any]]
+    engine: Optional[str]
 
 
 class AudioModelDescription(ModelDescription):
@@ -160,17 +163,34 @@ def create_audio_model_instance(
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[
-    Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel],
+    Union[
+        WhisperModel,
+        WhisperMLXModel,
+        FunASRModel,
+        ChatTTSModel,
+        CosyVoiceModel,
+        FishSpeechModel,
+        F5TTSModel,
+    ],
     AudioModelDescription,
 ]:
     model_spec = match_audio(model_name, download_hub)
     if model_path is None:
         model_path = cache(model_spec)
     model: Union[
-        WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel
+        WhisperModel,
+        WhisperMLXModel,
+        FunASRModel,
+        ChatTTSModel,
+        CosyVoiceModel,
+        FishSpeechModel,
+        F5TTSModel,
     ]
     if model_spec.model_family == "whisper":
-        model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
+        if not model_spec.engine:
+            model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
+        else:
+            model = WhisperMLXModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "funasr":
         model = FunASRModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "ChatTTS":
@@ -179,6 +199,8 @@
         model = CosyVoiceModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "FishAudio":
         model = FishSpeechModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "F5-TTS":
+        model = F5TTSModel(model_uid, model_path, model_spec, **kwargs)
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(