xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +28 -6
- xinference/core/utils.py +10 -6
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/core.py +10 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +200 -0
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +36 -111
- xinference/model/audio/model_spec.json +27 -3
- xinference/model/audio/model_spec_modelscope.json +18 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/embedding/core.py +203 -142
- xinference/model/embedding/model_spec.json +7 -0
- xinference/model/embedding/model_spec_modelscope.json +8 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +127 -4
- xinference/model/image/model_spec_modelscope.json +130 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/__init__.py +2 -2
- xinference/model/llm/llm_family.json +219 -53
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +167 -20
- xinference/model/llm/mlx/core.py +287 -51
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/qwen2_vl.py +2 -0
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +5 -1
- xinference/model/llm/vllm/core.py +16 -2
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/mlx/core.py
CHANGED
@@ -168,9 +168,14 @@ class MLXModel(LLM):
             return False
         if "generate" not in llm_family.model_ability:
             return False
+        if "chat" in llm_family.model_ability or "vision" in llm_family.model_ability:
+            # do not process chat or vision
+            return False
         return True

-    def _get_prompt_cache(
+    def _get_prompt_cache(
+        self, prompt, lora_name: Optional[str] = None, model: Any = None
+    ):
         from mlx_lm.models.cache import make_prompt_cache

         assert self._prompt_cache is not None
@@ -182,7 +187,9 @@ class MLXModel(LLM):
             or self._prompt_cache.tokens != prompt[:cache_len]
         ):
             self._prompt_cache.model_key = model_key
-            self._prompt_cache.cache = make_prompt_cache(
+            self._prompt_cache.cache = make_prompt_cache(
+                model or self._model, self._max_kv_size
+            )
             self._prompt_cache.tokens = []
             logger.debug("Making new prompt cache for %s", self.model_uid)
         else:
@@ -191,18 +198,35 @@ class MLXModel(LLM):
         self._prompt_cache.tokens.extend(prompt)
         return prompt

-    def
-
-        from mlx_lm.utils import generate_step
+    def _generate_stream_inner(self, **kwargs):
+        from mlx_lm.utils import make_sampler, stream_generate

-
+        sampler = make_sampler(
+            temp=kwargs.pop("temperature"), top_p=kwargs.pop("top_p")
+        )
+        prompt_token_ids = kwargs.pop("prompt_token_ids")
+        yield from stream_generate(
+            self._model, self._tokenizer, prompt_token_ids, sampler=sampler, **kwargs
+        )
+
+    def _prepare_inputs(
+        self, prompt: Union[str, Dict[str, Any]], kwargs
+    ) -> Tuple[Any, int]:
+        prompt_token_ids = self._tokenizer.encode(prompt)
+        prompt_token_ids = self._get_prompt_cache(
+            prompt_token_ids, kwargs.get("lora_name")
+        )
+        return prompt_token_ids, len(prompt_token_ids)
+
+    def _generate_stream(
+        self, prompt: Union[str, Dict[str, Any]], kwargs: MLXGenerateConfig
+    ):
         model_uid = self.model_uid
         tokenizer = self._tokenizer
         max_tokens = kwargs["max_tokens"]
         chunk_id = str(uuid.uuid4())
         stop_token_ids = kwargs.get("stop_token_ids", [])
         stream = kwargs.get("stream", False)
-        lora_name = kwargs.get("lora_name")
         stream_options = kwargs.pop("stream_options", None)
         include_usage = (
             stream_options["include_usage"]
@@ -210,39 +234,28 @@ class MLXModel(LLM):
             else False
         )

-        prompt_token_ids =
-        prompt_token_ids = self._get_prompt_cache(prompt_token_ids, lora_name)
-        prompt_tokens = mx.array(prompt_token_ids)
-        input_echo_len = len(prompt_tokens)
+        prompt_token_ids, input_echo_len = self._prepare_inputs(prompt, kwargs)

         i = 0
         start = time.time()
         output = ""
         tokens = []
-        for
-
-
-
-
+        for chunk_resp, i in zip(
+            self._generate_stream_inner(
+                prompt_token_ids=prompt_token_ids,
+                max_tokens=max_tokens,
+                temperature=kwargs["temperature"],
+                top_p=kwargs["top_p"],
                 repetition_penalty=kwargs["repetition_penalty"],
                 repetition_context_size=kwargs["repetition_context_size"],
-
-                prompt_cache=self._prompt_cache.cache,  # type: ignore
+                prompt_cache=self._prompt_cache.cache if self._prompt_cache else None,  # type: ignore
             ),
             range(max_tokens),
         ):
+            token = chunk_resp.token
             tokens.append(token)
-            if token == tokenizer.eos_token_id or token in stop_token_ids:  # type: ignore
-                break
-
-            # Yield the last segment if streaming
-            out = tokenizer.decode(
-                token,
-                skip_special_tokens=True,
-                spaces_between_special_tokens=False,
-                clean_up_tokenization_spaces=True,
-            )

+            out = chunk_resp.text
             if stream:
                 # this special character is mainly for qwen
                 out = out.strip("�")
@@ -266,11 +279,15 @@ class MLXModel(LLM):
                     total_tokens=(input_echo_len + i),
                 ), completion_usage

+            if token == tokenizer.eos_token_id or token in stop_token_ids:  # type: ignore
+                break
+
         logger.info(
             f"Average generation speed: {i / (time.time() - start):.2f} tokens/s."
         )

-        self._prompt_cache
+        if self._prompt_cache:
+            self._prompt_cache.tokens.extend(tokens)  # type: ignore

         if i == max_tokens - 1:
             finish_reason = "length"
@@ -314,10 +331,12 @@ class MLXModel(LLM):
             yield completion_chunk, completion_usage

     def generate(
-        self,
+        self,
+        prompt: Union[str, Dict[str, Any]],
+        generate_config: Optional[MLXGenerateConfig] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         def generator_wrapper(
-            prompt: str, generate_config: MLXGenerateConfig
+            prompt: Union[str, Dict[str, Any]], generate_config: MLXGenerateConfig
         ) -> Iterator[CompletionChunk]:
             for completion_chunk, completion_usage in self._generate_stream(
                 prompt,
@@ -356,26 +375,6 @@ class MLXModel(LLM):


 class MLXChatModel(MLXModel, ChatModelMixin):
-    def __init__(
-        self,
-        model_uid: str,
-        model_family: "LLMFamilyV1",
-        model_spec: "LLMSpecV1",
-        quantization: str,
-        model_path: str,
-        model_config: Optional[MLXModelConfig] = None,
-        peft_model: Optional[List[LoRA]] = None,
-    ):
-        super().__init__(
-            model_uid,
-            model_family,
-            model_spec,
-            quantization,
-            model_path,
-            model_config,
-            peft_model,
-        )
-
     def _sanitize_generate_config(
         self,
         generate_config: Optional[MLXGenerateConfig],
@@ -402,6 +401,9 @@ class MLXChatModel(MLXModel, ChatModelMixin):
             return False
         if "chat" not in llm_family.model_ability:
             return False
+        if "vision" in llm_family.model_ability:
+            # do not process vision
+            return False
         return True

     def chat(
@@ -432,3 +434,237 @@ class MLXChatModel(MLXModel, ChatModelMixin):
         if tools:
             return self._tool_calls_completion(self.model_family, self.model_uid, c)
         return self._to_chat_completion(c)
+
+
+class MLXVisionModel(MLXModel, ChatModelMixin):
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format not in ["mlx"]:
+            return False
+        if sys.platform != "darwin" or platform.processor() != "arm":
+            # only work for Mac M chips
+            return False
+        if "vision" not in llm_family.model_ability:
+            return False
+        return True
+
+    def _load_model(self, **kwargs):
+        try:
+            from mlx_vlm import load
+        except ImportError:
+            error_message = "Failed to import module 'mlx_vlm'"
+            installation_guide = [
+                "Please make sure 'mlx_vlm' is installed. ",
+                "You can install it by `pip install mlx_vlm`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        self._prompt_cache = PromptCache()
+
+        return load(self.model_path)
+
+    def load(self):
+        kwargs = {}
+        kwargs["revision"] = self._model_config.get(
+            "revision", self.model_spec.model_revision
+        )
+        kwargs["trust_remote_code"] = self._model_config.get("trust_remote_code")
+        kwargs["cache_limit_gb"] = self._model_config.pop("cache_limit_gb", None)
+
+        self._model, self._processor = self._load_model(**kwargs)
+        self._tokenizer = self._processor.tokenizer
+
+    def _generate_stream_inner_no_image(self, **kwargs):
+        import mlx.nn as nn
+        from mlx_lm.utils import make_sampler, stream_generate
+
+        # For mlx-lm, the model(inputs) will return logits,
+        # but the language model in mlx-vlm will return an object
+        # https://github.com/Blaizzy/mlx-vlm/blob/3f5e1620072440afb7496940f67ac1c7fc64056f/mlx_vlm/models/base.py#L260
+        # so we cannot pass the language model to stream_generate directly
+        # we wrap here to just let model(inputs) return logits to pass stream_generate
+        class ModelWrapper(nn.Module):
+            def __init__(self, model):
+                super().__init__()
+                self._model = model.language_model
+
+            @property
+            def layers(self):
+                return self._model.layers
+
+            def __call__(self, *args, **kwargs):
+                return self._model(*args, **kwargs).logits
+
+        sampler = make_sampler(
+            temp=kwargs.pop("temperature"), top_p=kwargs.pop("top_p")
+        )
+        prompt_token_ids = kwargs.pop("prompt_token_ids")
+        yield from stream_generate(
+            ModelWrapper(self._model),
+            self._tokenizer,
+            prompt_token_ids,
+            sampler=sampler,
+            **kwargs,
+        )
+
+    def _generate_stream_inner(self, **kwargs):
+        import mlx.core as mx
+        from mlx_lm.utils import GenerationResponse
+        from mlx_vlm.utils import generate_step
+
+        inputs = kwargs["prompt_token_ids"]
+
+        if not isinstance(inputs, tuple):
+            # no images
+            yield from self._generate_stream_inner_no_image(**kwargs)
+            return
+
+        max_tokens = kwargs.pop("max_tokens")
+        input_ids, pixel_values, mask = inputs[:3]
+
+        kwargs = {
+            k: v
+            for k, v in zip(
+                [
+                    "image_grid_thw",
+                    "image_sizes",
+                    "aspect_ratio_ids",
+                    "aspect_ratio_mask",
+                    "cross_attention_mask",
+                ],
+                inputs[3:],
+            )
+        }
+
+        tokenizer = self._processor.tokenizer
+        detokenizer = self._processor.detokenizer
+
+        detokenizer.reset()
+        tic = time.perf_counter()
+        for (token, logprobs), n in zip(
+            generate_step(input_ids, self._model, pixel_values, mask, **kwargs),
+            range(max_tokens),
+        ):
+            if n == 0:
+                prompt_time = time.perf_counter() - tic
+                prompt_tps = len(input_ids) / prompt_time
+                tic = time.perf_counter()
+            if token == tokenizer.eos_token_id:
+                break
+            detokenizer.add_token(token)
+
+            # Yield the last segment if streaming
+            yield GenerationResponse(
+                text=detokenizer.last_segment,
+                token=token,
+                logprobs=logprobs,
+                prompt_tokens=len(input_ids),
+                prompt_tps=prompt_tps,
+                generation_tokens=n + 1,
+                generation_tps=(n + 1) / (time.perf_counter() - tic),
+                peak_memory=mx.metal.get_peak_memory() / 1e9,
+            )
+
+        detokenizer.finalize()
+        yield GenerationResponse(
+            text=detokenizer.last_segment,
+            token=token,
+            logprobs=logprobs,
+            prompt_tokens=len(input_ids),
+            prompt_tps=prompt_tps,
+            generation_tokens=n + 1,
+            generation_tps=(n + 1) / (time.perf_counter() - tic),
+            peak_memory=mx.metal.get_peak_memory() / 1e9,
+        )
+
+    def _prepare_inputs(
+        self, prompt: Union[str, Dict[str, Any]], kwargs
+    ) -> Tuple[Any, int]:
+        from mlx_vlm import prepare_inputs
+
+        prompt_str = prompt.get("prompt")  # type: ignore
+        images = prompt.get("multi_modal_data", {}).get("image")  # type: ignore
+        if images and not isinstance(images, list):
+            images = [images]
+        if hasattr(self._model.config, "image_token_index"):
+            image_token_index = self._model.config.image_token_index
+        else:
+            image_token_index = None
+
+        if not images:
+            prompt = prompt["prompt"]  # type: ignore
+            prompt_token_ids = self._tokenizer.encode(prompt)
+            prompt_token_ids = self._get_prompt_cache(
+                prompt_token_ids,
+                kwargs.get("lora_name"),
+                model=self._model.language_model,
+            )
+            return prompt_token_ids, len(prompt_token_ids)
+        else:
+            inputs = prepare_inputs(
+                None,
+                self._processor,
+                images,
+                prompt_str,
+                image_token_index,
+                kwargs.get("resize_shape"),
+            )
+            input_ids = inputs[0]
+            return inputs, len(input_ids)
+
+    def chat(
+        self,
+        messages: List[Dict],
+        generate_config: Optional[MLXGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        messages = self._transform_messages(messages)  # type: ignore
+        tools = generate_config.pop("tools", []) if generate_config else None
+
+        model_family = self.model_family.model_family or self.model_family.model_name
+
+        if "internvl2" not in model_family.lower():
+            from qwen_vl_utils import process_vision_info
+
+            full_context_kwargs = {}
+            if tools and model_family in QWEN_TOOL_CALL_FAMILY:
+                full_context_kwargs["tools"] = tools
+            assert self.model_family.chat_template is not None
+            prompt = self.get_full_context(
+                messages, self.model_family.chat_template, **full_context_kwargs
+            )
+            images, video_inputs = process_vision_info(messages)
+            if video_inputs:
+                raise ValueError("Not support video input now.")
+        else:
+            prompt, images = self.get_specific_prompt(model_family, messages)  # type: ignore
+
+        if not images:
+            inputs = {
+                "prompt": prompt,
+            }
+        elif len(images) == 1:
+            inputs = {
+                "prompt": prompt,
+                "multi_modal_data": {"image": images[-1]},  # type: ignore
+            }
+        else:
+            inputs = {
+                "prompt": prompt,
+                "multi_modal_data": {"image": images},  # type: ignore
+            }
+        generate_config = self._sanitize_generate_config(generate_config)
+
+        stream = generate_config.get("stream", False)
+        if stream:
+            it = self.generate(inputs, generate_config)
+            assert isinstance(it, Iterator)
+            return self._to_chat_completion_chunks(it)
+        else:
+            c = self.generate(inputs, generate_config)
+            assert not isinstance(c, Iterator)
+            if tools:
+                return self._tool_calls_completion(self.model_family, self.model_uid, c)
+            return self._to_chat_completion(c)
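
The core of this refactor: instead of hand-rolling a token loop around mlx_lm's low-level generate_step and decoding each token with tokenizer.decode, MLXModel now delegates sampling and incremental detokenization to mlx_lm's make_sampler/stream_generate, and the new MLXVisionModel reuses the same inner generator. Below is a minimal standalone sketch of that pattern, assuming a recent mlx-lm that exposes make_sampler and stream_generate from mlx_lm.utils exactly as imported in the hunk above; the model id and generation settings are illustrative, not taken from the diff.

# Minimal sketch of the mlx-lm streaming pattern this refactor adopts
# (assumes an Apple-silicon Mac with a recent mlx-lm installed).
from mlx_lm import load
from mlx_lm.utils import make_sampler, stream_generate

model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # illustrative model id
prompt_token_ids = tokenizer.encode("Hello, how are you?")

# make_sampler bundles temperature/top-p sampling; stream_generate yields
# GenerationResponse objects that carry both the decoded text and the token id,
# which is why the xinference loop above can drop its manual tokenizer.decode.
sampler = make_sampler(temp=0.7, top_p=0.9)
for resp in stream_generate(
    model, tokenizer, prompt_token_ids, sampler=sampler, max_tokens=128
):
    print(resp.text, end="", flush=True)

Because each GenerationResponse exposes both .text and .token, the caller can check stop tokens and extend the prompt cache without touching the tokenizer, which is what the new _generate_stream loop does.
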
xinference/model/llm/transformers/chatglm.py
CHANGED

@@ -61,7 +61,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):

     def _load_model(self, **kwargs):
         try:
-            from transformers import
+            from transformers import AutoModelForCausalLM, AutoTokenizer
         except ImportError:
             error_message = "Failed to import module 'transformers'"
             installation_guide = [
@@ -77,7 +77,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             encode_special_tokens=True,
             revision=kwargs["revision"],
         )
-        model =
+        model = AutoModelForCausalLM.from_pretrained(
             self.model_path,
             **kwargs,
         )
@@ -232,9 +232,11 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                 content = {
                     "name": function_name,
                     "arguments": json.dumps(
-
-
-
+                        (
+                            arguments_json
+                            if isinstance(arguments_json, dict)
+                            else arguments
+                        ),
                         ensure_ascii=False,
                     ),
                 }
@@ -331,6 +333,8 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         max_new_tokens = generate_config.get("max_tokens")
         if max_new_tokens is not None:
             kwargs["max_new_tokens"] = int(max_new_tokens)
+        else:
+            kwargs["max_new_tokens"] = 1024
         do_sample = generate_config.get("do_sample")
         if do_sample is not None:
             kwargs["do_sample"] = bool(do_sample)

xinference/model/llm/transformers/qwen2_vl.py
CHANGED

@@ -47,6 +47,8 @@ class Qwen2VLChatModel(PytorchChatModel):
         llm_family = model_family.model_family or model_family.model_name
         if "qwen2-vl-instruct".lower() in llm_family.lower():
             return True
+        if "qvq-72b-preview".lower() in llm_family.lower():
+            return True
         return False

     def load(self):

xinference/model/llm/transformers/utils.py
CHANGED

@@ -156,6 +156,7 @@ def _get_completion(
     finish_reason: Optional[str],
     model_uid: str,
     r: InferenceRequest,
+    completion_tokens: int,
 ):
     completion_choice = CompletionChoice(
         text=output, index=0, logprobs=None, finish_reason=finish_reason
@@ -170,8 +171,8 @@ def _get_completion(
     )
     completion_usage = CompletionUsage(
         prompt_tokens=len(r.prompt_tokens),
-        completion_tokens=
-        total_tokens=len(r.prompt_tokens) +
+        completion_tokens=completion_tokens,
+        total_tokens=len(r.prompt_tokens) + completion_tokens,
     )
     completion = Completion(
         id=completion_chunk["id"],
@@ -371,7 +372,7 @@ def _batch_inference_one_step_internal(
             r.stopped = stopped
             r.finish_reason = finish_reason

-        if r.stopped and r not in stop_token_mapping
+        if r.stopped and r not in stop_token_mapping:
            stop_token_mapping[r] = _i + 1

         if r.stream:
@@ -446,12 +447,14 @@ def _batch_inference_one_step_internal(
            else:
                # last round, handle non-stream result
                if r.stopped and _i == decode_round - 1:
-                    invalid_token_num =
+                    invalid_token_num = (
+                        (decode_round - stop_token_mapping[r] + 1)
+                        if r.finish_reason == "stop"
+                        else (decode_round - stop_token_mapping[r])
+                    )
                    outputs = (
                        tokenizer.decode(
-                            r.new_tokens[
-                            if r.finish_reason == "stop"
-                            else r.new_tokens[:-invalid_token_num],
+                            r.new_tokens[:-invalid_token_num],
                            skip_special_tokens=True,
                            spaces_between_special_tokens=False,
                            clean_up_tokenization_spaces=True,
@@ -460,7 +463,12 @@ def _batch_inference_one_step_internal(
                        else output_mapping[r]
                    )
                    completion = _get_completion(
-                        outputs,
+                        outputs,
+                        r.chunk_id,
+                        r.finish_reason,
+                        model_uid,
+                        r,
+                        len(r.new_tokens) - invalid_token_num,
                    )
                    r.completion = [completion]

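
The _get_completion change threads an explicit completion_tokens count into CompletionUsage instead of leaving it underspecified, and the batched-decoding path now computes the number of invalid trailing tokens once and reuses it for both the decoded text and the usage numbers. A small worked example with made-up numbers (not from the diff) showing how the arithmetic plays out:

# Illustrative numbers only: how the new invalid_token_num / completion_tokens
# accounting behaves for one request r in a batched decode round.
decode_round = 16   # decoding steps executed in this round
stopped_at = 10     # stop_token_mapping[r]: step at which r stopped
new_tokens = 60     # len(r.new_tokens) accumulated so far
prompt_tokens = 20  # len(r.prompt_tokens), assumed

# Stop-token case trims one extra token (presumably the stop token itself):
invalid_token_num = decode_round - stopped_at + 1   # 7 when finish_reason == "stop"
# Length case would be decode_round - stopped_at    # 6 otherwise

completion_tokens = new_tokens - invalid_token_num  # 53, reported in CompletionUsage
total_tokens = prompt_tokens + completion_tokens    # 73
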
xinference/model/llm/utils.py
CHANGED

@@ -52,6 +52,7 @@ QWEN_TOOL_CALL_FAMILY = [
     "qwen2-instruct",
     "qwen2-moe-instruct",
     "qwen2.5-instruct",
+    "qwen2.5-coder-instruct",
 ]

 GLM4_TOOL_CALL_FAMILY = [
@@ -324,7 +325,10 @@ class ChatModelMixin:
         """
         try:
             if isinstance(c, dict):
-
+                try:
+                    return [(None, c["name"], json.loads(c["arguments"]))]
+                except Exception:
+                    return [(None, c["name"], c["arguments"])]
         except KeyError:
             logger.error("Can't parse glm output: %s", c)
             return [(str(c), None, None)]

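
The ChatModelMixin change makes GLM-style tool-call parsing tolerant of both shapes of the "arguments" field: a JSON-encoded string or an already-parsed dict. A self-contained sketch of the same fallback, using hypothetical payloads for illustration:

import json

# Hypothetical GLM-style tool-call payloads (illustrative, not from the diff):
# "arguments" may arrive as a JSON string or as an already-parsed dict.
as_string = {"name": "get_weather", "arguments": '{"city": "Beijing"}'}
as_dict = {"name": "get_weather", "arguments": {"city": "Beijing"}}

def parse_arguments(c: dict):
    # Mirrors the fallback added above: prefer json.loads, otherwise
    # pass the raw value through unchanged.
    try:
        return (None, c["name"], json.loads(c["arguments"]))
    except Exception:
        return (None, c["name"], c["arguments"])

print(parse_arguments(as_string))  # (None, 'get_weather', {'city': 'Beijing'})
print(parse_arguments(as_dict))    # (None, 'get_weather', {'city': 'Beijing'})
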
xinference/model/llm/vllm/core.py
CHANGED

@@ -70,6 +70,7 @@ class VLLMModelConfig(TypedDict, total=False):
     max_model_len: Optional[int]
     limit_mm_per_prompt: Optional[Dict[str, int]]
     guided_decoding_backend: Optional[str]
+    scheduling_policy: Optional[str]


 class VLLMGenerateConfig(TypedDict, total=False):
@@ -86,6 +87,7 @@ class VLLMGenerateConfig(TypedDict, total=False):
     stop: Optional[Union[str, List[str]]]
     stream: bool  # non-sampling param, should not be passed to the engine.
     stream_options: Optional[Union[dict, None]]
+    skip_special_tokens: Optional[bool]
     response_format: Optional[dict]
     guided_json: Optional[Union[str, dict]]
     guided_regex: Optional[str]
@@ -181,14 +183,19 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.5.3":
 if VLLM_INSTALLED and vllm.__version__ > "0.5.3":
     VLLM_SUPPORTED_MODELS.append("llama-3.1")
     VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.1-instruct")
+    VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.3-instruct")

 if VLLM_INSTALLED and vllm.__version__ >= "0.6.1":
     VLLM_SUPPORTED_VISION_MODEL_LIST.append("internvl2")

+if VLLM_INSTALLED and vllm.__version__ >= "0.6.2":
+    VLLM_SUPPORTED_CHAT_MODELS.append("minicpm3-4b")
+
 if VLLM_INSTALLED and vllm.__version__ >= "0.6.3":
     VLLM_SUPPORTED_MODELS.append("llama-3.2-vision")
     VLLM_SUPPORTED_VISION_MODEL_LIST.append("llama-3.2-vision-instruct")
     VLLM_SUPPORTED_VISION_MODEL_LIST.append("qwen2-vl-instruct")
+    VLLM_SUPPORTED_VISION_MODEL_LIST.append("QvQ-72B-Preview")


 class VLLMModel(LLM):
@@ -242,7 +249,6 @@ class VLLMModel(LLM):
         multiprocessing.set_start_method("fork", force=True)

         self._model_config = self._sanitize_model_config(self._model_config)
-
         if self.lora_modules is None:
             self.lora_requests = []
         else:
@@ -325,7 +331,9 @@ class VLLMModel(LLM):
         model_config.setdefault("quantization", None)
         model_config.setdefault("max_model_len", None)
         model_config.setdefault("guided_decoding_backend", "outlines")
-
+        # Add scheduling policy if vLLM version is 0.6.3 or higher
+        if vllm.__version__ >= "0.6.3":
+            model_config.setdefault("scheduling_policy", "fcfs")
         return model_config

     @staticmethod
@@ -373,6 +381,9 @@ class VLLMModel(LLM):
         sanitized.setdefault(
             "stream_options", generate_config.get("stream_options", None)
         )
+        sanitized.setdefault(
+            "skip_special_tokens", generate_config.get("skip_special_tokens", True)
+        )
         sanitized.setdefault(
             "guided_json", generate_config.get("guided_json", guided_json)
         )
@@ -854,6 +865,9 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
                 "image": 2,  # default 2 images all chat
             }
         )
+        # Add scheduling policy if vLLM version is 0.6.3 or higher
+        if vllm.__version__ >= "0.6.3":
+            model_config.setdefault("scheduling_policy", "fcfs")

         return model_config

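
For context, both new vLLM knobs map onto existing vLLM-side options: scheduling_policy is an engine argument introduced around vLLM 0.6.3 (values "fcfs" or "priority"), and skip_special_tokens is a standard SamplingParams field. A rough sketch of the equivalent direct vLLM usage, with an illustrative model name and values:

# Rough sketch only (assumes vllm >= 0.6.3); shows where the two options land.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # illustrative model
    scheduling_policy="fcfs",          # default xinference now sets; "priority" is the alternative
)
params = SamplingParams(
    temperature=0.7,
    skip_special_tokens=False,         # now forwardable via VLLMGenerateConfig (defaults to True)
)
outputs = llm.generate(["Hello"], params)
print(outputs[0].outputs[0].text)

In xinference terms, the former lives in VLLMModelConfig (set when the model is launched) and the latter in VLLMGenerateConfig (set per request), as the hunks above show.
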