xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff shows the changes between these two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.

This release has been flagged as potentially problematic.

Files changed (170)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +28 -6
  4. xinference/core/utils.py +10 -6
  5. xinference/deploy/cmdline.py +3 -1
  6. xinference/deploy/test/test_cmdline.py +56 -0
  7. xinference/isolation.py +24 -0
  8. xinference/model/audio/core.py +10 -0
  9. xinference/model/audio/cosyvoice.py +25 -3
  10. xinference/model/audio/f5tts.py +200 -0
  11. xinference/model/audio/f5tts_mlx.py +260 -0
  12. xinference/model/audio/fish_speech.py +36 -111
  13. xinference/model/audio/model_spec.json +27 -3
  14. xinference/model/audio/model_spec_modelscope.json +18 -0
  15. xinference/model/audio/utils.py +32 -0
  16. xinference/model/embedding/core.py +203 -142
  17. xinference/model/embedding/model_spec.json +7 -0
  18. xinference/model/embedding/model_spec_modelscope.json +8 -0
  19. xinference/model/image/core.py +69 -1
  20. xinference/model/image/model_spec.json +127 -4
  21. xinference/model/image/model_spec_modelscope.json +130 -4
  22. xinference/model/image/stable_diffusion/core.py +45 -13
  23. xinference/model/llm/__init__.py +2 -2
  24. xinference/model/llm/llm_family.json +219 -53
  25. xinference/model/llm/llm_family.py +15 -36
  26. xinference/model/llm/llm_family_modelscope.json +167 -20
  27. xinference/model/llm/mlx/core.py +287 -51
  28. xinference/model/llm/sglang/core.py +1 -0
  29. xinference/model/llm/transformers/chatglm.py +9 -5
  30. xinference/model/llm/transformers/core.py +1 -0
  31. xinference/model/llm/transformers/qwen2_vl.py +2 -0
  32. xinference/model/llm/transformers/utils.py +16 -8
  33. xinference/model/llm/utils.py +5 -1
  34. xinference/model/llm/vllm/core.py +16 -2
  35. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  36. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  37. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  38. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  39. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  40. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  41. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  42. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  43. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  44. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  45. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  46. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  47. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  48. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  49. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  50. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  51. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  52. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  53. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  54. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  55. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  56. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  57. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  58. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  59. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  60. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  61. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  62. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  63. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  64. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  65. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  66. xinference/thirdparty/f5_tts/api.py +166 -0
  67. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  68. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  69. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  70. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  71. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  72. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  73. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  74. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  75. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  76. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  77. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  78. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  79. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  80. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  81. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  82. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  83. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  84. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  85. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  86. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  87. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  88. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  89. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  90. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  91. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  92. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  93. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  94. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  95. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  96. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  97. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  98. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  99. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  100. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  101. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  102. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  103. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  104. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  105. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  106. xinference/thirdparty/f5_tts/train/README.md +77 -0
  107. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  108. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  109. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  110. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  111. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  112. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  113. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  114. xinference/thirdparty/f5_tts/train/train.py +75 -0
  115. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  116. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  117. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  118. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  119. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  120. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  121. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  122. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  123. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  124. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  125. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  126. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  127. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  128. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  129. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  130. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  131. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  132. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  133. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  134. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  135. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  136. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  137. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  138. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  139. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  140. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  141. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  142. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  143. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  144. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  145. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  146. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  147. xinference/thirdparty/matcha/utils/utils.py +2 -2
  148. xinference/web/ui/build/asset-manifest.json +3 -3
  149. xinference/web/ui/build/index.html +1 -1
  150. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  151. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  153. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
  154. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
  155. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  156. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  157. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  158. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  159. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  160. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  161. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  162. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  163. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  164. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  165. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  166. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  167. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
  168. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
  169. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
  170. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/mlx/core.py
@@ -168,9 +168,14 @@ class MLXModel(LLM):
  return False
  if "generate" not in llm_family.model_ability:
  return False
+ if "chat" in llm_family.model_ability or "vision" in llm_family.model_ability:
+ # do not process chat or vision
+ return False
  return True

- def _get_prompt_cache(self, prompt, lora_name: Optional[str] = None):
+ def _get_prompt_cache(
+ self, prompt, lora_name: Optional[str] = None, model: Any = None
+ ):
  from mlx_lm.models.cache import make_prompt_cache

  assert self._prompt_cache is not None
@@ -182,7 +187,9 @@ class MLXModel(LLM):
  or self._prompt_cache.tokens != prompt[:cache_len]
  ):
  self._prompt_cache.model_key = model_key
- self._prompt_cache.cache = make_prompt_cache(self._model, self._max_kv_size)
+ self._prompt_cache.cache = make_prompt_cache(
+ model or self._model, self._max_kv_size
+ )
  self._prompt_cache.tokens = []
  logger.debug("Making new prompt cache for %s", self.model_uid)
  else:
@@ -191,18 +198,35 @@ class MLXModel(LLM):
  self._prompt_cache.tokens.extend(prompt)
  return prompt

- def _generate_stream(self, prompt: str, kwargs: MLXGenerateConfig):
- import mlx.core as mx
- from mlx_lm.utils import generate_step
+ def _generate_stream_inner(self, **kwargs):
+ from mlx_lm.utils import make_sampler, stream_generate

- model = self._model
+ sampler = make_sampler(
+ temp=kwargs.pop("temperature"), top_p=kwargs.pop("top_p")
+ )
+ prompt_token_ids = kwargs.pop("prompt_token_ids")
+ yield from stream_generate(
+ self._model, self._tokenizer, prompt_token_ids, sampler=sampler, **kwargs
+ )
+
+ def _prepare_inputs(
+ self, prompt: Union[str, Dict[str, Any]], kwargs
+ ) -> Tuple[Any, int]:
+ prompt_token_ids = self._tokenizer.encode(prompt)
+ prompt_token_ids = self._get_prompt_cache(
+ prompt_token_ids, kwargs.get("lora_name")
+ )
+ return prompt_token_ids, len(prompt_token_ids)
+
+ def _generate_stream(
+ self, prompt: Union[str, Dict[str, Any]], kwargs: MLXGenerateConfig
+ ):
  model_uid = self.model_uid
  tokenizer = self._tokenizer
  max_tokens = kwargs["max_tokens"]
  chunk_id = str(uuid.uuid4())
  stop_token_ids = kwargs.get("stop_token_ids", [])
  stream = kwargs.get("stream", False)
- lora_name = kwargs.get("lora_name")
  stream_options = kwargs.pop("stream_options", None)
  include_usage = (
  stream_options["include_usage"]
@@ -210,39 +234,28 @@ class MLXModel(LLM):
  else False
  )

- prompt_token_ids = tokenizer.encode(prompt)
- prompt_token_ids = self._get_prompt_cache(prompt_token_ids, lora_name)
- prompt_tokens = mx.array(prompt_token_ids)
- input_echo_len = len(prompt_tokens)
+ prompt_token_ids, input_echo_len = self._prepare_inputs(prompt, kwargs)

  i = 0
  start = time.time()
  output = ""
  tokens = []
- for (token, _), i in zip(
- generate_step(
- prompt_tokens,
- model,
- temp=kwargs["temperature"],
+ for chunk_resp, i in zip(
+ self._generate_stream_inner(
+ prompt_token_ids=prompt_token_ids,
+ max_tokens=max_tokens,
+ temperature=kwargs["temperature"],
+ top_p=kwargs["top_p"],
  repetition_penalty=kwargs["repetition_penalty"],
  repetition_context_size=kwargs["repetition_context_size"],
- top_p=kwargs["top_p"],
- prompt_cache=self._prompt_cache.cache, # type: ignore
+ prompt_cache=self._prompt_cache.cache if self._prompt_cache else None, # type: ignore
  ),
  range(max_tokens),
  ):
+ token = chunk_resp.token
  tokens.append(token)
- if token == tokenizer.eos_token_id or token in stop_token_ids: # type: ignore
- break
-
- # Yield the last segment if streaming
- out = tokenizer.decode(
- token,
- skip_special_tokens=True,
- spaces_between_special_tokens=False,
- clean_up_tokenization_spaces=True,
- )

+ out = chunk_resp.text
  if stream:
  # this special character is mainly for qwen
  out = out.strip("�")
@@ -266,11 +279,15 @@ class MLXModel(LLM):
  total_tokens=(input_echo_len + i),
  ), completion_usage

+ if token == tokenizer.eos_token_id or token in stop_token_ids: # type: ignore
+ break
+
  logger.info(
  f"Average generation speed: {i / (time.time() - start):.2f} tokens/s."
  )

- self._prompt_cache.tokens.extend(tokens) # type: ignore
+ if self._prompt_cache:
+ self._prompt_cache.tokens.extend(tokens) # type: ignore

  if i == max_tokens - 1:
  finish_reason = "length"
@@ -314,10 +331,12 @@ class MLXModel(LLM):
  yield completion_chunk, completion_usage

  def generate(
- self, prompt: str, generate_config: Optional[MLXGenerateConfig] = None
+ self,
+ prompt: Union[str, Dict[str, Any]],
+ generate_config: Optional[MLXGenerateConfig] = None,
  ) -> Union[Completion, Iterator[CompletionChunk]]:
  def generator_wrapper(
- prompt: str, generate_config: MLXGenerateConfig
+ prompt: Union[str, Dict[str, Any]], generate_config: MLXGenerateConfig
  ) -> Iterator[CompletionChunk]:
  for completion_chunk, completion_usage in self._generate_stream(
  prompt,
@@ -356,26 +375,6 @@ class MLXModel(LLM):


  class MLXChatModel(MLXModel, ChatModelMixin):
- def __init__(
- self,
- model_uid: str,
- model_family: "LLMFamilyV1",
- model_spec: "LLMSpecV1",
- quantization: str,
- model_path: str,
- model_config: Optional[MLXModelConfig] = None,
- peft_model: Optional[List[LoRA]] = None,
- ):
- super().__init__(
- model_uid,
- model_family,
- model_spec,
- quantization,
- model_path,
- model_config,
- peft_model,
- )
-
  def _sanitize_generate_config(
  self,
  generate_config: Optional[MLXGenerateConfig],
@@ -402,6 +401,9 @@ class MLXChatModel(MLXModel, ChatModelMixin):
  return False
  if "chat" not in llm_family.model_ability:
  return False
+ if "vision" in llm_family.model_ability:
+ # do not process vision
+ return False
  return True

  def chat(
@@ -432,3 +434,237 @@ class MLXChatModel(MLXModel, ChatModelMixin):
  if tools:
  return self._tool_calls_completion(self.model_family, self.model_uid, c)
  return self._to_chat_completion(c)
+
+
+ class MLXVisionModel(MLXModel, ChatModelMixin):
+ @classmethod
+ def match(
+ cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+ ) -> bool:
+ if llm_spec.model_format not in ["mlx"]:
+ return False
+ if sys.platform != "darwin" or platform.processor() != "arm":
+ # only work for Mac M chips
+ return False
+ if "vision" not in llm_family.model_ability:
+ return False
+ return True
+
+ def _load_model(self, **kwargs):
+ try:
+ from mlx_vlm import load
+ except ImportError:
+ error_message = "Failed to import module 'mlx_vlm'"
+ installation_guide = [
+ "Please make sure 'mlx_vlm' is installed. ",
+ "You can install it by `pip install mlx_vlm`\n",
+ ]
+
+ raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+ self._prompt_cache = PromptCache()
+
+ return load(self.model_path)
+
+ def load(self):
+ kwargs = {}
+ kwargs["revision"] = self._model_config.get(
+ "revision", self.model_spec.model_revision
+ )
+ kwargs["trust_remote_code"] = self._model_config.get("trust_remote_code")
+ kwargs["cache_limit_gb"] = self._model_config.pop("cache_limit_gb", None)
+
+ self._model, self._processor = self._load_model(**kwargs)
+ self._tokenizer = self._processor.tokenizer
+
+ def _generate_stream_inner_no_image(self, **kwargs):
+ import mlx.nn as nn
+ from mlx_lm.utils import make_sampler, stream_generate
+
+ # For mlx-lm, the model(inputs) will return logits,
+ # but the language model in mlx-vlm will return an object
+ # https://github.com/Blaizzy/mlx-vlm/blob/3f5e1620072440afb7496940f67ac1c7fc64056f/mlx_vlm/models/base.py#L260
+ # so we cannot pass the language model to stream_generate directly
+ # we wrap here to just let model(inputs) return logits to pass stream_generate
+ class ModelWrapper(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self._model = model.language_model
+
+ @property
+ def layers(self):
+ return self._model.layers
+
+ def __call__(self, *args, **kwargs):
+ return self._model(*args, **kwargs).logits
+
+ sampler = make_sampler(
+ temp=kwargs.pop("temperature"), top_p=kwargs.pop("top_p")
+ )
+ prompt_token_ids = kwargs.pop("prompt_token_ids")
+ yield from stream_generate(
+ ModelWrapper(self._model),
+ self._tokenizer,
+ prompt_token_ids,
+ sampler=sampler,
+ **kwargs,
+ )
+
+ def _generate_stream_inner(self, **kwargs):
+ import mlx.core as mx
+ from mlx_lm.utils import GenerationResponse
+ from mlx_vlm.utils import generate_step
+
+ inputs = kwargs["prompt_token_ids"]
+
+ if not isinstance(inputs, tuple):
+ # no images
+ yield from self._generate_stream_inner_no_image(**kwargs)
+ return
+
+ max_tokens = kwargs.pop("max_tokens")
+ input_ids, pixel_values, mask = inputs[:3]
+
+ kwargs = {
+ k: v
+ for k, v in zip(
+ [
+ "image_grid_thw",
+ "image_sizes",
+ "aspect_ratio_ids",
+ "aspect_ratio_mask",
+ "cross_attention_mask",
+ ],
+ inputs[3:],
+ )
+ }
+
+ tokenizer = self._processor.tokenizer
+ detokenizer = self._processor.detokenizer
+
+ detokenizer.reset()
+ tic = time.perf_counter()
+ for (token, logprobs), n in zip(
+ generate_step(input_ids, self._model, pixel_values, mask, **kwargs),
+ range(max_tokens),
+ ):
+ if n == 0:
+ prompt_time = time.perf_counter() - tic
+ prompt_tps = len(input_ids) / prompt_time
+ tic = time.perf_counter()
+ if token == tokenizer.eos_token_id:
+ break
+ detokenizer.add_token(token)
+
+ # Yield the last segment if streaming
+ yield GenerationResponse(
+ text=detokenizer.last_segment,
+ token=token,
+ logprobs=logprobs,
+ prompt_tokens=len(input_ids),
+ prompt_tps=prompt_tps,
+ generation_tokens=n + 1,
+ generation_tps=(n + 1) / (time.perf_counter() - tic),
+ peak_memory=mx.metal.get_peak_memory() / 1e9,
+ )
+
+ detokenizer.finalize()
+ yield GenerationResponse(
+ text=detokenizer.last_segment,
+ token=token,
+ logprobs=logprobs,
+ prompt_tokens=len(input_ids),
+ prompt_tps=prompt_tps,
+ generation_tokens=n + 1,
+ generation_tps=(n + 1) / (time.perf_counter() - tic),
+ peak_memory=mx.metal.get_peak_memory() / 1e9,
+ )
+
+ def _prepare_inputs(
+ self, prompt: Union[str, Dict[str, Any]], kwargs
+ ) -> Tuple[Any, int]:
+ from mlx_vlm import prepare_inputs
+
+ prompt_str = prompt.get("prompt") # type: ignore
+ images = prompt.get("multi_modal_data", {}).get("image") # type: ignore
+ if images and not isinstance(images, list):
+ images = [images]
+ if hasattr(self._model.config, "image_token_index"):
+ image_token_index = self._model.config.image_token_index
+ else:
+ image_token_index = None
+
+ if not images:
+ prompt = prompt["prompt"] # type: ignore
+ prompt_token_ids = self._tokenizer.encode(prompt)
+ prompt_token_ids = self._get_prompt_cache(
+ prompt_token_ids,
+ kwargs.get("lora_name"),
+ model=self._model.language_model,
+ )
+ return prompt_token_ids, len(prompt_token_ids)
+ else:
+ inputs = prepare_inputs(
+ None,
+ self._processor,
+ images,
+ prompt_str,
+ image_token_index,
+ kwargs.get("resize_shape"),
+ )
+ input_ids = inputs[0]
+ return inputs, len(input_ids)
+
+ def chat(
+ self,
+ messages: List[Dict],
+ generate_config: Optional[MLXGenerateConfig] = None,
+ ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+ messages = self._transform_messages(messages) # type: ignore
+ tools = generate_config.pop("tools", []) if generate_config else None
+
+ model_family = self.model_family.model_family or self.model_family.model_name
+
+ if "internvl2" not in model_family.lower():
+ from qwen_vl_utils import process_vision_info
+
+ full_context_kwargs = {}
+ if tools and model_family in QWEN_TOOL_CALL_FAMILY:
+ full_context_kwargs["tools"] = tools
+ assert self.model_family.chat_template is not None
+ prompt = self.get_full_context(
+ messages, self.model_family.chat_template, **full_context_kwargs
+ )
+ images, video_inputs = process_vision_info(messages)
+ if video_inputs:
+ raise ValueError("Not support video input now.")
+ else:
+ prompt, images = self.get_specific_prompt(model_family, messages) # type: ignore
+
+ if not images:
+ inputs = {
+ "prompt": prompt,
+ }
+ elif len(images) == 1:
+ inputs = {
+ "prompt": prompt,
+ "multi_modal_data": {"image": images[-1]}, # type: ignore
+ }
+ else:
+ inputs = {
+ "prompt": prompt,
+ "multi_modal_data": {"image": images}, # type: ignore
+ }
+ generate_config = self._sanitize_generate_config(generate_config)
+
+ stream = generate_config.get("stream", False)
+ if stream:
+ it = self.generate(inputs, generate_config)
+ assert isinstance(it, Iterator)
+ return self._to_chat_completion_chunks(it)
+ else:
+ c = self.generate(inputs, generate_config)
+ assert not isinstance(c, Iterator)
+ if tools:
+ return self._tool_calls_completion(self.model_family, self.model_uid, c)
+ return self._to_chat_completion(c)
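
Note on the new MLX vision path: MLXVisionModel routes chat requests through MLXModel.generate using a dict-style prompt that carries multi_modal_data, so vision-capable MLX models become reachable through the usual chat API on Apple Silicon. Below is a rough client-side sketch going through xinference's OpenAI-compatible endpoint; the server address, model name, and image URL are illustrative assumptions, not taken from this diff.

# Hypothetical usage sketch: assumes an MLX-format vision model (e.g. a Qwen2-VL MLX build)
# has already been launched on a local xinference server listening on port 9997.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-used")

response = client.chat.completions.create(
    model="qwen2-vl-instruct",  # assumed UID of the launched MLX vision model
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
    max_tokens=128,
)
print(response.choices[0].message.content)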
xinference/model/llm/sglang/core.py
@@ -75,6 +75,7 @@ SGLANG_SUPPORTED_CHAT_MODELS = [
  "llama-2-chat",
  "llama-3-instruct",
  "llama-3.1-instruct",
+ "llama-3.3-instruct",
  "qwen-chat",
  "qwen1.5-chat",
  "qwen2-instruct",
xinference/model/llm/transformers/chatglm.py
@@ -61,7 +61,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):

  def _load_model(self, **kwargs):
  try:
- from transformers import AutoModel, AutoTokenizer
+ from transformers import AutoModelForCausalLM, AutoTokenizer
  except ImportError:
  error_message = "Failed to import module 'transformers'"
  installation_guide = [
@@ -77,7 +77,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
  encode_special_tokens=True,
  revision=kwargs["revision"],
  )
- model = AutoModel.from_pretrained(
+ model = AutoModelForCausalLM.from_pretrained(
  self.model_path,
  **kwargs,
  )
@@ -232,9 +232,11 @@ class ChatglmPytorchChatModel(PytorchChatModel):
  content = {
  "name": function_name,
  "arguments": json.dumps(
- arguments_json
- if isinstance(arguments_json, dict)
- else arguments,
+ (
+ arguments_json
+ if isinstance(arguments_json, dict)
+ else arguments
+ ),
  ensure_ascii=False,
  ),
  }
@@ -331,6 +333,8 @@ class ChatglmPytorchChatModel(PytorchChatModel):
  max_new_tokens = generate_config.get("max_tokens")
  if max_new_tokens is not None:
  kwargs["max_new_tokens"] = int(max_new_tokens)
+ else:
+ kwargs["max_new_tokens"] = 1024
  do_sample = generate_config.get("do_sample")
  if do_sample is not None:
  kwargs["do_sample"] = bool(do_sample)
xinference/model/llm/transformers/core.py
@@ -69,6 +69,7 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
  "deepseek-v2.5",
  "deepseek-v2-chat-0628",
  "glm-edge-v",
+ "QvQ-72B-Preview",
  ]


xinference/model/llm/transformers/qwen2_vl.py
@@ -47,6 +47,8 @@ class Qwen2VLChatModel(PytorchChatModel):
  llm_family = model_family.model_family or model_family.model_name
  if "qwen2-vl-instruct".lower() in llm_family.lower():
  return True
+ if "qvq-72b-preview".lower() in llm_family.lower():
+ return True
  return False

  def load(self):
xinference/model/llm/transformers/utils.py
@@ -156,6 +156,7 @@ def _get_completion(
  finish_reason: Optional[str],
  model_uid: str,
  r: InferenceRequest,
+ completion_tokens: int,
  ):
  completion_choice = CompletionChoice(
  text=output, index=0, logprobs=None, finish_reason=finish_reason
@@ -170,8 +171,8 @@ def _get_completion(
  )
  completion_usage = CompletionUsage(
  prompt_tokens=len(r.prompt_tokens),
- completion_tokens=len(r.new_tokens),
- total_tokens=len(r.prompt_tokens) + len(r.new_tokens),
+ completion_tokens=completion_tokens,
+ total_tokens=len(r.prompt_tokens) + completion_tokens,
  )
  completion = Completion(
  id=completion_chunk["id"],
@@ -371,7 +372,7 @@ def _batch_inference_one_step_internal(
  r.stopped = stopped
  r.finish_reason = finish_reason

- if r.stopped and r not in stop_token_mapping and r not in output_mapping:
+ if r.stopped and r not in stop_token_mapping:
  stop_token_mapping[r] = _i + 1

  if r.stream:
@@ -446,12 +447,14 @@ def _batch_inference_one_step_internal(
  else:
  # last round, handle non-stream result
  if r.stopped and _i == decode_round - 1:
- invalid_token_num = decode_round - stop_token_mapping[r]
+ invalid_token_num = (
+ (decode_round - stop_token_mapping[r] + 1)
+ if r.finish_reason == "stop"
+ else (decode_round - stop_token_mapping[r])
+ )
  outputs = (
  tokenizer.decode(
- r.new_tokens[: -(invalid_token_num + 1)]
- if r.finish_reason == "stop"
- else r.new_tokens[:-invalid_token_num],
+ r.new_tokens[:-invalid_token_num],
  skip_special_tokens=True,
  spaces_between_special_tokens=False,
  clean_up_tokenization_spaces=True,
@@ -460,7 +463,12 @@ def _batch_inference_one_step_internal(
  else output_mapping[r]
  )
  completion = _get_completion(
- outputs, r.chunk_id, r.finish_reason, model_uid, r
+ outputs,
+ r.chunk_id,
+ r.finish_reason,
+ model_uid,
+ r,
+ len(r.new_tokens) - invalid_token_num,
  )
  r.completion = [completion]

xinference/model/llm/utils.py
@@ -52,6 +52,7 @@ QWEN_TOOL_CALL_FAMILY = [
  "qwen2-instruct",
  "qwen2-moe-instruct",
  "qwen2.5-instruct",
+ "qwen2.5-coder-instruct",
  ]

  GLM4_TOOL_CALL_FAMILY = [
@@ -324,7 +325,10 @@ class ChatModelMixin:
  """
  try:
  if isinstance(c, dict):
- return [(None, c["name"], c["arguments"])]
+ try:
+ return [(None, c["name"], json.loads(c["arguments"]))]
+ except Exception:
+ return [(None, c["name"], c["arguments"])]
  except KeyError:
  logger.error("Can't parse glm output: %s", c)
  return [(str(c), None, None)]
xinference/model/llm/vllm/core.py
@@ -70,6 +70,7 @@ class VLLMModelConfig(TypedDict, total=False):
  max_model_len: Optional[int]
  limit_mm_per_prompt: Optional[Dict[str, int]]
  guided_decoding_backend: Optional[str]
+ scheduling_policy: Optional[str]


  class VLLMGenerateConfig(TypedDict, total=False):
@@ -86,6 +87,7 @@ class VLLMGenerateConfig(TypedDict, total=False):
  stop: Optional[Union[str, List[str]]]
  stream: bool # non-sampling param, should not be passed to the engine.
  stream_options: Optional[Union[dict, None]]
+ skip_special_tokens: Optional[bool]
  response_format: Optional[dict]
  guided_json: Optional[Union[str, dict]]
  guided_regex: Optional[str]
@@ -181,14 +183,19 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.5.3":
  if VLLM_INSTALLED and vllm.__version__ > "0.5.3":
  VLLM_SUPPORTED_MODELS.append("llama-3.1")
  VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.1-instruct")
+ VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.3-instruct")

  if VLLM_INSTALLED and vllm.__version__ >= "0.6.1":
  VLLM_SUPPORTED_VISION_MODEL_LIST.append("internvl2")

+ if VLLM_INSTALLED and vllm.__version__ >= "0.6.2":
+ VLLM_SUPPORTED_CHAT_MODELS.append("minicpm3-4b")
+
  if VLLM_INSTALLED and vllm.__version__ >= "0.6.3":
  VLLM_SUPPORTED_MODELS.append("llama-3.2-vision")
  VLLM_SUPPORTED_VISION_MODEL_LIST.append("llama-3.2-vision-instruct")
  VLLM_SUPPORTED_VISION_MODEL_LIST.append("qwen2-vl-instruct")
+ VLLM_SUPPORTED_VISION_MODEL_LIST.append("QvQ-72B-Preview")


  class VLLMModel(LLM):
@@ -242,7 +249,6 @@ class VLLMModel(LLM):
  multiprocessing.set_start_method("fork", force=True)

  self._model_config = self._sanitize_model_config(self._model_config)
-
  if self.lora_modules is None:
  self.lora_requests = []
  else:
@@ -325,7 +331,9 @@ class VLLMModel(LLM):
  model_config.setdefault("quantization", None)
  model_config.setdefault("max_model_len", None)
  model_config.setdefault("guided_decoding_backend", "outlines")
-
+ # Add scheduling policy if vLLM version is 0.6.3 or higher
+ if vllm.__version__ >= "0.6.3":
+ model_config.setdefault("scheduling_policy", "fcfs")
  return model_config

  @staticmethod
@@ -373,6 +381,9 @@ class VLLMModel(LLM):
  sanitized.setdefault(
  "stream_options", generate_config.get("stream_options", None)
  )
+ sanitized.setdefault(
+ "skip_special_tokens", generate_config.get("skip_special_tokens", True)
+ )
  sanitized.setdefault(
  "guided_json", generate_config.get("guided_json", guided_json)
  )
@@ -854,6 +865,9 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
  "image": 2, # default 2 images all chat
  }
  )
+ # Add scheduling policy if vLLM version is 0.6.3 or higher
+ if vllm.__version__ >= "0.6.3":
+ model_config.setdefault("scheduling_policy", "fcfs")

  return model_config
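
Note on the new vLLM options: scheduling_policy is a model-level setting that is only applied when the installed vLLM is 0.6.3 or newer (defaulting to "fcfs"), while skip_special_tokens is a per-request generate option that defaults to True. A minimal sketch of passing both through the Python client follows; the model name, size, and server address are assumptions for illustration, not part of this diff.

# Hypothetical sketch; assumes a local xinference supervisor on port 9997 and a
# vLLM >= 0.6.3 installation. Extra launch kwargs flow into the vLLM model config,
# and generate_config entries map onto VLLMGenerateConfig.
from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model_uid = client.launch_model(
    model_name="qwen2.5-instruct",   # assumed model
    model_engine="vllm",
    model_size_in_billions=7,
    scheduling_policy="priority",    # forwarded to the vLLM engine (0.6.3+ only)
)
model = client.get_model(model_uid)
completion = model.generate(
    "Write a haiku about GPUs.",
    generate_config={"max_tokens": 64, "skip_special_tokens": False},
)
print(completion["choices"][0]["text"])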