xinference 0.13.2__py3-none-any.whl → 0.13.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/__init__.py +0 -1
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +30 -5
- xinference/client/restful/restful_client.py +18 -3
- xinference/constants.py +0 -4
- xinference/core/chat_interface.py +2 -2
- xinference/core/image_interface.py +6 -3
- xinference/core/model.py +9 -4
- xinference/core/scheduler.py +4 -4
- xinference/core/supervisor.py +2 -0
- xinference/core/worker.py +7 -0
- xinference/deploy/utils.py +6 -0
- xinference/model/audio/core.py +9 -4
- xinference/model/audio/cosyvoice.py +136 -0
- xinference/model/audio/model_spec.json +24 -0
- xinference/model/audio/model_spec_modelscope.json +27 -0
- xinference/model/core.py +25 -4
- xinference/model/embedding/core.py +88 -13
- xinference/model/embedding/model_spec.json +8 -0
- xinference/model/embedding/model_spec_modelscope.json +8 -0
- xinference/model/flexible/core.py +8 -2
- xinference/model/flexible/launchers/__init__.py +1 -0
- xinference/model/flexible/launchers/image_process_launcher.py +70 -0
- xinference/model/image/core.py +8 -5
- xinference/model/image/model_spec.json +36 -5
- xinference/model/image/model_spec_modelscope.json +21 -3
- xinference/model/image/stable_diffusion/core.py +36 -28
- xinference/model/llm/core.py +6 -4
- xinference/model/llm/ggml/llamacpp.py +7 -5
- xinference/model/llm/llm_family.json +802 -82
- xinference/model/llm/llm_family.py +6 -6
- xinference/model/llm/llm_family_csghub.json +39 -0
- xinference/model/llm/llm_family_modelscope.json +295 -47
- xinference/model/llm/mlx/core.py +7 -0
- xinference/model/llm/pytorch/chatglm.py +246 -5
- xinference/model/llm/pytorch/cogvlm2.py +1 -1
- xinference/model/llm/pytorch/deepseek_vl.py +2 -1
- xinference/model/llm/pytorch/falcon.py +2 -1
- xinference/model/llm/pytorch/llama_2.py +4 -2
- xinference/model/llm/pytorch/omnilmm.py +2 -1
- xinference/model/llm/pytorch/qwen_vl.py +2 -1
- xinference/model/llm/pytorch/vicuna.py +2 -1
- xinference/model/llm/pytorch/yi_vl.py +2 -1
- xinference/model/llm/sglang/core.py +12 -6
- xinference/model/llm/utils.py +78 -1
- xinference/model/llm/vllm/core.py +9 -5
- xinference/model/rerank/core.py +4 -3
- xinference/thirdparty/cosyvoice/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
- xinference/thirdparty/cosyvoice/bin/train.py +136 -0
- xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
- xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
- xinference/thirdparty/cosyvoice/cli/model.py +60 -0
- xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
- xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
- xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
- xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
- xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
- xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
- xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
- xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
- xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
- xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
- xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
- xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
- xinference/thirdparty/cosyvoice/utils/common.py +103 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.95c1d652.js → main.af906659.js} +3 -3
- xinference/web/ui/build/static/js/main.af906659.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2cd5e4279ad7e13a1f41d486e9fca7756295bfad5bd77d90992f4ac3e10b496d.json +1 -0
- {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/METADATA +39 -11
- {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/RECORD +101 -57
- xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
- /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.af906659.js.LICENSE.txt} +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/LICENSE +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/WHEEL +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/chatglm.py
CHANGED

@@ -11,10 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
+import json
+import threading
 import time
 import uuid
 from typing import Any, Dict, Iterator, List, Optional, Union
 
+import torch
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList
+
 from ....core.scheduler import InferenceRequest
 from ....types import (
     SPECIAL_TOOL_PROMPT,

@@ -33,6 +40,16 @@ from ..utils import GLM4_TOOL_CALL_FAMILY
 from .core import PytorchChatModel, PytorchModelConfig
 
 
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+    ) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., 198] = 5e4
+        return scores
+
+
 class ChatglmPytorchChatModel(PytorchChatModel):
     def __init__(
         self,
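A custom LogitsProcessor like the one added above plugs into Hugging Face generation through a LogitsProcessorList. A minimal standalone sketch of that wiring (the model id, prompt, and class name are placeholders, not taken from this diff):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList

class NanInfGuard(LogitsProcessor):
    # Same idea as InvalidScoreLogitsProcessor above: if scores degenerate to
    # NaN/Inf, zero them and force one safe token so sampling cannot crash.
    def __call__(self, input_ids, scores):
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 198] = 5e4
        return scores

tokenizer = AutoTokenizer.from_pretrained("some-causal-lm")      # placeholder model id
model = AutoModelForCausalLM.from_pretrained("some-causal-lm")   # placeholder model id
inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(
    **inputs,
    logits_processor=LogitsProcessorList([NanInfGuard()]),
    max_new_tokens=32,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))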
@@ -103,9 +120,11 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         tools = generate_config.pop("tools", None)
         if tools is None:
             return False
+        # Convert a iterable to a list
+        tools = list(tools)
         tool_choice = generate_config.pop("tool_choice", "none")
         if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
-            chat_history[:] = self.
+            chat_history[:] = self._process_messages(
                 chat_history, tools=tools, tool_choice=tool_choice
             )
             return True

@@ -124,7 +143,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         return True
 
     @staticmethod
-    def
+    def _process_messages(messages, tools=None, tool_choice="none"):
         # This method is adapted from https://github.com/THUDM/GLM-4/blob/main/basic_demo/openai_api_server.py
         _messages = messages
         processed_messages = []

@@ -210,6 +229,212 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                 break
         return processed_messages
 
+    @staticmethod
+    def _process_response(output, history, tools, end=False):
+        # Copy from https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/modeling_chatglm.py
+        content = ""
+        history = copy.deepcopy(history)
+        if not tools and end:
+            return None, None
+        for response in output.split("<|assistant|>"):
+            if "\n" in response:
+                metadata, content = response.split("\n", maxsplit=1)
+            else:
+                metadata, content = "", response
+            if not metadata.strip():
+                if tools and any(t.startswith(response) for t in tools) and not end:
+                    # Waiting for tool call complete.
+                    return None, None
+                content = content.strip()
+                history.append(
+                    {"role": "assistant", "metadata": metadata, "content": content}
+                )
+                content = content.replace("[[训练时间]]", "2023年")
+            else:
+                if tools and metadata in tools and not end:
+                    return None, None
+                history.append(
+                    {"role": "assistant", "metadata": metadata, "content": content}
+                )
+                metadata = metadata.strip()
+                if tools and metadata in tools and end:
+                    try:
+                        parameters = json.loads(content)
+                        content = {"name": metadata.strip(), "parameters": parameters}
+                    except json.JSONDecodeError:
+                        content = {"name": metadata.strip(), "content": content}
+                else:
+                    content = {"name": metadata.strip(), "content": content}
+        return content, history
+
+    def _get_generate_args(
+        self,
+        tokenizer,
+        query: str,
+        history: Optional[List[Dict]] = None,
+        role: str = "user",
+        past_key_values=None,
+        max_length: int = 8192,
+        do_sample=True,
+        top_p=0.8,
+        temperature=0.8,
+        logits_processor=None,
+        **kwargs,
+    ):
+        # Copy from https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/modeling_chatglm.py
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        eos_token_id = [
+            tokenizer.eos_token_id,
+            tokenizer.convert_tokens_to_ids("<|user|>"),
+            tokenizer.convert_tokens_to_ids("<|observation|>"),
+        ]
+        gen_kwargs = {
+            "max_length": max_length,
+            "do_sample": do_sample,
+            "top_p": top_p,
+            "temperature": temperature,
+            "logits_processor": logits_processor,
+            **kwargs,
+        }
+        if past_key_values is None:
+            inputs = tokenizer.apply_chat_template(
+                history + [{"role": role, "content": query}],
+                add_generation_prompt=True,
+                tokenize=True,
+                return_tensors="pt",
+                return_dict=True,
+            )
+        else:
+            inputs = tokenizer.apply_chat_template(
+                [{"role": role, "content": query}],
+                add_special_tokens=False,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_tensors="pt",
+                return_dict=True,
+            )
+        inputs = inputs.to(self._model.device)
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+            inputs.position_ids += past_length
+            attention_mask = inputs.attention_mask
+            attention_mask = torch.cat(
+                (attention_mask.new_ones(1, past_length), attention_mask), dim=1
+            )
+            inputs["attention_mask"] = attention_mask
+        history.append({"role": role, "content": query})
+        tools = history[0]["role"] == "system" and history[0].get("tools")
+        tools = (
+            [
+                t.get("function", {}).get("name", "")
+                for t in tools
+                if isinstance(t, dict)
+            ]
+            if tools
+            else []
+        )
+        kwargs = dict(inputs)
+        kwargs["past_key_values"] = past_key_values
+        kwargs["eos_token_id"] = eos_token_id
+        kwargs.update(gen_kwargs)
+        return kwargs, tools
+
+    @torch.inference_mode()
+    def stream_chat(
+        self,
+        tokenizer,
+        query: str,
+        history: Optional[List[Dict]] = None,
+        role: str = "user",
+        past_key_values=None,
+        max_length: int = 8192,
+        do_sample=True,
+        top_p=0.8,
+        temperature=0.8,
+        logits_processor=None,
+        **kwargs,
+    ):
+        from transformers import TextIteratorStreamer
+
+        kwargs, tools = self._get_generate_args(
+            tokenizer=tokenizer,
+            query=query,
+            history=history,
+            role=role,
+            past_key_values=past_key_values,
+            max_length=max_length,
+            do_sample=do_sample,
+            top_p=top_p,
+            temperature=temperature,
+            logits_processor=logits_processor,
+            **kwargs,
+        )
+
+        streamer = TextIteratorStreamer(
+            tokenizer, skip_prompt=True, skip_special_tokens=True
+        )
+        kwargs["streamer"] = streamer
+        thread = threading.Thread(target=self._model.generate, kwargs=kwargs)
+        thread.start()
+
+        response = ""
+        for token in streamer:
+            response += token
+            if response and response[-1] != "�":
+                new_response, new_history = self._process_response(
+                    response, history, tools, end=False
+                )
+                if new_response is None:
+                    continue
+                yield new_response, new_history
+        if tools:
+            new_response, new_history = self._process_response(
+                response, history, tools, end=True
+            )
+            if new_response:
+                yield new_response, new_history
+
+    @torch.inference_mode()
+    def non_stream_chat(
+        self,
+        tokenizer,
+        query: str,
+        history: Optional[List[Dict]] = None,
+        role: str = "user",
+        past_key_values=None,
+        max_length: int = 8192,
+        do_sample=True,
+        top_p=0.8,
+        temperature=0.8,
+        logits_processor=None,
+        **kwargs,
+    ):
+        kwargs, tools = self._get_generate_args(
+            tokenizer=tokenizer,
+            query=query,
+            history=history,
+            role=role,
+            past_key_values=past_key_values,
+            max_length=max_length,
+            do_sample=do_sample,
+            top_p=top_p,
+            temperature=temperature,
+            logits_processor=logits_processor,
+            **kwargs,
+        )
+
+        outputs = self._model.generate(**kwargs)
+        outputs = outputs[:, kwargs["input_ids"].shape[1] :]
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        if tools:
+            return self._process_response(response, history, tools, end=True)
+        else:
+            return self._process_response(response, history, tools)
+
     def chat(
         self,
         prompt: str,
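For orientation only (not part of the diff): _process_response expects GLM-4 to emit a completed tool call as the function name on its own line followed by JSON arguments, and turns that into a name/parameters dict. A tiny illustration of that parsing, with a hypothetical tool name:

import json

raw = 'get_weather\n{"location": "Beijing"}'   # hypothetical GLM-4 tool-call output
name, args_text = raw.split("\n", maxsplit=1)
tool_call = {"name": name.strip(), "parameters": json.loads(args_text)}
# tool_call == {"name": "get_weather", "parameters": {"location": "Beijing"}}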
@@ -247,7 +472,13 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             if isinstance(stream_options, dict)
             else False
         )
-        if stream and
+        if stream and (
+            not tools or self.model_family.model_name in GLM4_TOOL_CALL_FAMILY
+        ):
+            if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
+                stream_chat = self.stream_chat
+            else:
+                stream_chat = self._model.stream_chat
 
             def _stream_generator():
                 last_chunk_text_length = 0

@@ -256,9 +487,14 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                 inputs = self._tokenizer([prompt], return_tensors="pt")
                 inputs = inputs.to(self._model.device)
                 prompt_tokens = len(inputs["input_ids"][0])
-                for chunk_text, _ in
+                for chunk_text, _ in stream_chat(
                     self._tokenizer, prompt, chat_history, **kwargs
                 ):
+                    if tools and isinstance(chunk_text, dict):
+                        yield self._tool_calls_completion_chunk(
+                            self.model_family, self.model_uid, [chunk_text, _], tools
+                        )
+                        return
                     completion_tokens = completion_tokens + 1
                     total_tokens = prompt_tokens + completion_tokens
                     chunk_text = chunk_text[last_chunk_text_length:]

@@ -312,7 +548,12 @@ class ChatglmPytorchChatModel(PytorchChatModel):
 
             return self._to_chat_completion_chunks(_stream_generator())
         else:
-
+            if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
+                chat = self.non_stream_chat
+            else:
+                chat = self._model.chat
+
+            response = chat(self._tokenizer, prompt, chat_history, **kwargs)
             if tools:
                 return self._tool_calls_completion(
                     self.model_family, self.model_uid, response, tools
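The stream_chat added above follows the usual transformers streaming recipe: run generate() on a background thread and read decoded text from a TextIteratorStreamer. A self-contained sketch of that pattern, assuming a generic causal LM (the model id is a placeholder):

import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("some-causal-lm")   # placeholder model id
model = AutoModelForCausalLM.from_pretrained("some-causal-lm")

inputs = tokenizer("Hello", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = dict(inputs, streamer=streamer, max_new_tokens=64)

# generate() blocks, so it runs on a worker thread while the caller iterates the streamer.
thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
for text_piece in streamer:
    print(text_piece, end="", flush=True)
thread.join()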
xinference/model/llm/pytorch/cogvlm2.py
CHANGED

@@ -387,7 +387,7 @@ class CogVLM2Model(PytorchChatModel):
             prompt, system_prompt=system_prompt, chat_history=chat_history
         )
 
-        input_by_model: dict = self._model.build_conversation_input_ids(
+        input_by_model: dict = self._model.build_conversation_input_ids(  # type: ignore
             self._tokenizer,
             query=query,
             history=history,
xinference/model/llm/pytorch/deepseek_vl.py
CHANGED

@@ -52,7 +52,8 @@ class DeepSeekVLChatModel(PytorchChatModel):
     def match(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-
+        llm_family = model_family.model_family or model_family.model_name
+        if "deepseek-vl" in llm_family:
             return True
         return False
 
xinference/model/llm/pytorch/falcon.py
CHANGED

@@ -71,7 +71,8 @@ class FalconPytorchModel(PytorchModel):
     ) -> bool:
         if llm_spec.model_format != "pytorch":
             return False
-
+        model_family = llm_family.model_family or llm_family.model_name
+        if "falcon" not in model_family:
             return False
         if "generate" not in llm_family.model_ability:
             return False
xinference/model/llm/pytorch/llama_2.py
CHANGED

@@ -55,7 +55,8 @@ class LlamaPytorchModel(PytorchModel):
     ) -> bool:
         if llm_spec.model_format != "pytorch":
             return False
-
+        model_family = llm_family.model_family or llm_family.model_name
+        if "llama-2" not in model_family:
             return False
         if "generate" not in llm_family.model_ability:
             return False

@@ -99,7 +100,8 @@ class LlamaPytorchChatModel(PytorchChatModel):
     ) -> bool:
         if llm_spec.model_format != "pytorch":
             return False
-
+        model_family = llm_family.model_family or llm_family.model_name
+        if "llama-2" not in model_family:
             return False
         if "chat" not in llm_family.model_ability:
             return False
xinference/model/llm/pytorch/omnilmm.py
CHANGED

@@ -44,7 +44,8 @@ class OmniLMMModel(PytorchChatModel):
     def match(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-
+        llm_family = model_family.model_family or model_family.model_name
+        if "OmniLMM" in llm_family:
             return True
         return False
 
xinference/model/llm/pytorch/qwen_vl.py
CHANGED

@@ -52,7 +52,8 @@ class QwenVLChatModel(PytorchChatModel):
     def match(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-
+        llm_family = model_family.model_family or model_family.model_name
+        if "qwen" in llm_family and "vision" in model_family.model_ability:
             return True
         return False
 
xinference/model/llm/pytorch/vicuna.py
CHANGED

@@ -61,7 +61,8 @@ class VicunaPytorchChatModel(PytorchChatModel):
     ) -> bool:
         if llm_spec.model_format != "pytorch":
             return False
-
+        model_family = llm_family.model_family or llm_family.model_name
+        if "vicuna" not in model_family:
             return False
         if "chat" not in llm_family.model_ability:
             return False
xinference/model/llm/pytorch/yi_vl.py
CHANGED

@@ -51,7 +51,8 @@ class YiVLChatModel(PytorchChatModel):
     def match(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-
+        llm_family = model_family.model_family or model_family.model_name
+        if "yi-vl" in llm_family:
             return True
         return False
 
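The match() changes above all apply the same fallback: use model_family when it is set (for example on fine-tuned or user-registered models) and fall back to model_name otherwise. Reduced to a sketch with a hypothetical helper, not code from the package:

def matches_family(llm_family, keyword: str) -> bool:
    # model_family can be empty for built-in models, so fall back to the model name.
    family = llm_family.model_family or llm_family.model_name
    return keyword in family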
xinference/model/llm/sglang/core.py
CHANGED

@@ -17,7 +17,6 @@ import time
 import uuid
 from typing import AsyncGenerator, Dict, List, Optional, TypedDict, Union
 
-from ....constants import XINFERENCE_ENABLE_SGLANG
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,

@@ -63,15 +62,26 @@ try:
 except ImportError:
     SGLANG_INSTALLED = False
 
-SGLANG_SUPPORTED_MODELS = [
+SGLANG_SUPPORTED_MODELS = [
+    "llama-2",
+    "llama-3",
+    "llama-3.1",
+    "mistral-v0.1",
+    "mixtral-v0.1",
+]
 SGLANG_SUPPORTED_CHAT_MODELS = [
     "llama-2-chat",
+    "llama-3-instruct",
+    "llama-3.1-instruct",
     "qwen-chat",
     "qwen1.5-chat",
+    "qwen2-instruct",
+    "qwen2-moe-instruct",
     "mistral-instruct-v0.1",
     "mistral-instruct-v0.2",
     "mixtral-instruct-v0.1",
     "gemma-it",
+    "gemma-2-it",
 ]
 
 

@@ -168,8 +178,6 @@ class SGLANGModel(LLM):
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if not XINFERENCE_ENABLE_SGLANG:
-            return False
         if not cls._has_cuda_device():
             return False
         if not cls._is_linux():

@@ -332,8 +340,6 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if not XINFERENCE_ENABLE_SGLANG:
-            return False
         if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
             return False
         if llm_spec.model_format == "pytorch":
xinference/model/llm/utils.py
CHANGED
@@ -483,11 +483,40 @@ Begin!"""
             else:
                 ret += role
             return ret
+        elif prompt_style.style_name == "mistral-nemo":
+            seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
+            ret = "<s>"
+            for i, message in enumerate(chat_history):
+                role = get_role(message["role"])
+                content = message["content"]
+                if content:
+                    if i == len(chat_history) - 2 and prompt_style.system_prompt:
+                        ret += (
+                            role
+                            + " "
+                            + prompt_style.system_prompt
+                            + "\n\n"
+                            + content
+                            + seps[i % 2]
+                        )
+                    else:
+                        ret += role + " " + content + seps[i % 2]
+                else:
+                    ret += role
+            return ret
         else:
             raise ValueError(f"Invalid prompt style: {prompt_style.style_name}")
 
     @classmethod
     def _to_chat_completion_chunk(cls, chunk: CompletionChunk) -> ChatCompletionChunk:
+        choices = chunk.get("choices")
+        if (
+            chunk.get("object") == "chat.completion.chunk"
+            and choices
+            and "delta" in choices[0]
+        ):
+            # Already a ChatCompletionChunk, we don't need to convert chunk.
+            return cast(ChatCompletionChunk, chunk)
         chat_chunk = {
             "id": "chat" + chunk["id"],
             "model": chunk["model"],

@@ -497,7 +526,7 @@ Begin!"""
                 {
                     "index": i,
                     "delta": {
-                        "content": choice
+                        "content": choice.get("text"),
                         **(
                             {"tool_calls": choice["tool_calls"]}
                             if "tool_calls" in choice

@@ -718,6 +747,54 @@ Begin!"""
         else:
             return lambda tokens, delta: delta
 
+    @classmethod
+    def _tool_calls_completion_chunk(cls, model_family, model_uid, c, tools):
+        _id = str(uuid.uuid4())
+        content, func, args = cls._eval_tool_arguments(model_family, c, tools)
+        if func:
+            d = {
+                "role": "assistant",
+                "content": content,
+                "tool_calls": [
+                    {
+                        "id": f"call_{_id}",
+                        "type": "function",
+                        "function": {
+                            "name": func,
+                            "arguments": json.dumps(args),
+                        },
+                    }
+                ],
+            }
+            finish_reason = "tool_calls"
+        else:
+            d = {"role": "assistant", "content": content, "tool_calls": []}
+            finish_reason = "stop"
+        try:
+            usage = c.get("usage")
+            assert "prompt_tokens" in usage
+        except Exception:
+            usage = {
+                "prompt_tokens": -1,
+                "completion_tokens": -1,
+                "total_tokens": -1,
+            }
+        return {
+            "id": "chat" + f"cmpl-{_id}",
+            "model": model_uid,
+            "object": "chat.completion.chunk",
+            "created": int(time.time()),
+            "choices": [
+                {
+                    "index": 0,
+                    "delta": d,
+                    "logprobs": None,
+                    "finish_reason": finish_reason,
+                }
+            ],
+            "usage": usage,
+        }
+
     @classmethod
     def _tool_calls_completion(cls, model_family, model_uid, c, tools):
         _id = str(uuid.uuid4())
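The new _tool_calls_completion_chunk builds an OpenAI-style chat.completion.chunk whose delta may carry tool_calls. For illustration only (not from the diff), a consumer could separate text deltas from tool calls like this:

def handle_chunk(chunk: dict) -> None:
    # chunk follows the chat.completion.chunk shape assembled above
    delta = chunk["choices"][0]["delta"]
    if delta.get("tool_calls"):
        for call in delta["tool_calls"]:
            fn = call["function"]
            print("tool call:", fn["name"], fn["arguments"])
    elif delta.get("content"):
        print(delta["content"], end="")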
xinference/model/llm/vllm/core.py
CHANGED

@@ -28,7 +28,6 @@ from typing import (
     Union,
 )
 
-from ....constants import XINFERENCE_DISABLE_VLLM
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,

@@ -151,6 +150,15 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.4.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen2-moe-instruct")
     VLLM_SUPPORTED_CHAT_MODELS.append("c4ai-command-r-v01")
 
+if VLLM_INSTALLED and vllm.__version__ >= "0.5.3":
+    VLLM_SUPPORTED_CHAT_MODELS.append("gemma-2-it")
+    VLLM_SUPPORTED_CHAT_MODELS.append("mistral-nemo-instruct")
+    VLLM_SUPPORTED_CHAT_MODELS.append("mistral-large-instruct")
+
+if VLLM_INSTALLED and vllm.__version__ > "0.5.3":
+    VLLM_SUPPORTED_MODELS.append("llama-3.1")
+    VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.1-instruct")
+
 
 class VLLMModel(LLM):
     def __init__(

@@ -288,8 +296,6 @@ class VLLMModel(LLM):
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if XINFERENCE_DISABLE_VLLM:
-            return False
         if not cls._has_cuda_device():
             return False
         if not cls._is_linux():

@@ -514,8 +520,6 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if XINFERENCE_DISABLE_VLLM:
-            return False
         if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
             return False
         if llm_spec.model_format == "pytorch":
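With these version gates, models such as llama-3.1-instruct become selectable on the vLLM backend when vLLM is newer than 0.5.3. A hedged usage sketch with the RESTful client; the endpoint, size, and quantization values are placeholders, and the launch_model keyword arguments are assumed to match this release:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")        # placeholder endpoint
model_uid = client.launch_model(
    model_name="llama-3.1-instruct",
    model_engine="vllm",                        # assumed engine selector
    model_format="pytorch",
    model_size_in_billions=8,                   # placeholder size
    quantization="none",
)
model = client.get_model(model_uid)
print(model.chat("Hello"))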
xinference/model/rerank/core.py
CHANGED
@@ -107,7 +107,7 @@ class RerankModel:
         self,
         model_spec: RerankModelSpec,
         model_uid: str,
-        model_path: str,
+        model_path: Optional[str] = None,
         device: Optional[str] = None,
         use_fp16: bool = False,
         model_config: Optional[Dict] = None,

@@ -290,6 +290,7 @@ def create_rerank_model_instance(
     model_uid: str,
     model_name: str,
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[RerankModel, RerankModelDescription]:
     from ..utils import download_from_modelscope

@@ -321,8 +322,8 @@ def create_rerank_model_instance(
             f"Huggingface: {BUILTIN_RERANK_MODELS.keys()}"
             f"ModelScope: {MODELSCOPE_RERANK_MODELS.keys()}"
         )
-
-
+    if not model_path:
+        model_path = cache(model_spec)
     use_fp16 = kwargs.pop("use_fp16", False)
     model = RerankModel(
         model_spec, model_uid, model_path, use_fp16=use_fp16, model_config=kwargs
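The rerank changes make model_path optional on RerankModel and let create_rerank_model_instance receive a pre-downloaded path instead of always calling cache(). A hedged sketch of launching a rerank model from a local directory through the client, assuming extra keyword arguments such as model_path are forwarded to the rerank loader as the new signature suggests:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")                 # placeholder endpoint
uid = client.launch_model(
    model_name="bge-reranker-base",
    model_type="rerank",
    model_path="/models/bge-reranker-base",              # assumed pass-through to create_rerank_model_instance
)
rerank = client.get_model(uid)
print(rerank.rerank(["first document", "second document"], "the query text"))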