xinference 0.12.0__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +74 -6
- xinference/client/restful/restful_client.py +74 -5
- xinference/constants.py +1 -0
- xinference/core/cache_tracker.py +48 -28
- xinference/core/model.py +54 -42
- xinference/core/scheduler.py +34 -16
- xinference/core/supervisor.py +73 -24
- xinference/core/worker.py +68 -2
- xinference/deploy/cmdline.py +86 -2
- xinference/deploy/test/test_cmdline.py +19 -10
- xinference/model/audio/__init__.py +14 -1
- xinference/model/audio/core.py +12 -1
- xinference/model/audio/custom.py +6 -4
- xinference/model/audio/model_spec_modelscope.json +20 -0
- xinference/model/llm/__init__.py +34 -2
- xinference/model/llm/llm_family.json +2 -0
- xinference/model/llm/llm_family.py +86 -1
- xinference/model/llm/llm_family_csghub.json +66 -0
- xinference/model/llm/llm_family_modelscope.json +2 -0
- xinference/model/llm/pytorch/chatglm.py +18 -12
- xinference/model/llm/pytorch/core.py +92 -42
- xinference/model/llm/pytorch/glm4v.py +13 -3
- xinference/model/llm/pytorch/qwen_vl.py +1 -1
- xinference/model/llm/pytorch/utils.py +27 -14
- xinference/model/llm/utils.py +14 -13
- xinference/model/llm/vllm/core.py +10 -4
- xinference/model/utils.py +8 -2
- xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
- xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/infer/api.py +125 -0
- xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
- xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
- xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
- xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
- xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
- xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
- xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
- xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/METADATA +1 -1
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/RECORD +57 -45
- xinference/web/ui/build/static/css/main.54bca460.css +0 -2
- xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
- xinference/web/ui/build/static/js/main.551aa479.js +0 -3
- xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
- /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/core.py
CHANGED

@@ -15,6 +15,7 @@
 import json
 import logging
 import os
+from functools import lru_cache
 from typing import Iterable, Iterator, List, Optional, Union

 from ....core.scheduler import InferenceRequest
@@ -28,6 +29,7 @@ from ....types import (
     ChatCompletionChunk,
     ChatCompletionMessage,
     Completion,
+    CompletionChoice,
     CompletionChunk,
     CreateCompletionTorch,
     Embedding,
@@ -366,6 +368,90 @@ class PytorchModel(LLM):
         else:
             return generator_wrapper(prompt, generate_config)

+    @lru_cache
+    def get_context_len(self):
+        return get_context_length(self._model.config)
+
+    def get_max_num_seqs(self) -> int:
+        return self._pytorch_model_config.get("max_num_seqs")  # type: ignore
+
+    def prepare_batch_inference(self, req_list: List[InferenceRequest]):
+        # check some parameters
+        for r in req_list:
+            if r.sanitized_generate_config is None:
+                r.sanitized_generate_config = self._sanitize_generate_config(
+                    r.generate_config
+                )
+            if r.is_prefill:
+                # check some generate params
+                max_src_len = get_max_src_len(self.get_context_len(), r)  # type: ignore
+                if max_src_len < 0:
+                    r.stopped = True
+                    r.error_msg = "Max tokens exceeds model's max length"
+                    continue
+                if r.stream_interval <= 0:
+                    r.stopped = True
+                    r.error_msg = "`stream_interval` must be greater than 0"
+                    continue
+                stop_str = r.sanitized_generate_config.get("stop", None)
+                if stop_str and (
+                    not (isinstance(stop_str, str) or isinstance(stop_str, Iterable))
+                ):
+                    r.stopped = True
+                    r.error_msg = "Invalid `stop` field type"
+                    continue
+
+    def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
+        for req in req_list:
+            if req.error_msg is None:
+                # nothing need handle for non-stream case
+                if req.stream:
+                    results = []
+                    for i, c in enumerate(req.completion):
+                        if c == "<bos_stream>":
+                            chunk = req.completion[i + 1]
+                            results.append(
+                                CompletionChunk(
+                                    id=chunk["id"],
+                                    object=chunk["object"],
+                                    created=chunk["created"],
+                                    model=chunk["model"],
+                                    choices=[
+                                        CompletionChoice(
+                                            text="",
+                                            index=0,
+                                            logprobs=None,
+                                            finish_reason=None,
+                                        )
+                                    ],
+                                )
+                            )
+                            continue
+                        elif c == "<eos_stream>":
+                            break
+                        else:
+                            results.append(c)
+
+                    if req.stopped and req.include_usage:
+                        results.append(req.completion[-1])
+                    req.completion = results
+
+    def batch_inference(self, req_list: List[InferenceRequest]):
+        from .utils import batch_inference_one_step
+
+        self.prepare_batch_inference(req_list)
+        context_len = self.get_context_len()
+        assert isinstance(context_len, int)
+        batch_inference_one_step(
+            req_list,
+            self.model_uid,
+            self._model,
+            self._tokenizer,
+            self._device,
+            context_len,
+        )
+        self.handle_batch_inference_results(req_list)
+
     def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
         try:
             import torch
@@ -464,7 +550,6 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
             pytorch_model_config,
             peft_model,
         )
-        self._context_len = None

     def _sanitize_generate_config(
         self,
@@ -540,7 +625,6 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):

     def load(self):
         super().load()
-        self._context_len = get_context_length(self._model.config)

     def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
         assert self.model_family.prompt_style is not None
@@ -553,48 +637,14 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         )
         return full_prompt

-    def
-
-
-    def batch_inference(self, req_list: List[InferenceRequest]):
-        from .utils import batch_inference_one_step
-
+    def prepare_batch_inference(self, req_list: List[InferenceRequest]):
+        super().prepare_batch_inference(req_list)
         for r in req_list:
-
-            r.
-
-            )
-            if r.is_prefill:
-                # check some generate params
-                max_src_len = get_max_src_len(self._context_len, r)  # type: ignore
-                if max_src_len < 0:
-                    r.stopped = True
-                    r.error_msg = "Max tokens exceeds model's max length"
-                    continue
-                if r.stream_interval <= 0:
-                    r.stopped = True
-                    r.error_msg = "`stream_interval` must be greater than 0"
-                    continue
-                stop_str = r.sanitized_generate_config.get("stop", None)
-                if stop_str and (
-                    not (isinstance(stop_str, str) or isinstance(stop_str, Iterable))
-                ):
-                    r.stopped = True
-                    r.error_msg = "Invalid `stop` field type"
-                    continue
-                r.full_prompt = self._get_full_prompt(
-                    r.prompt, r.system_prompt, r.chat_history, None
-                )
+            r.full_prompt = self._get_full_prompt(
+                r.prompt, r.system_prompt, r.chat_history, None
+            )

-
-        batch_inference_one_step(
-            req_list,
-            self.model_uid,
-            self._model,
-            self._tokenizer,
-            self._device,
-            self._context_len,
-        )
+    def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
         for req in req_list:
             if req.stream and req.error_msg is None:
                 if req.completion:
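A side note on the new get_context_len above: decorating a zero-argument method with functools.lru_cache caches the result per instance, because self becomes the cache key. A minimal standalone sketch of that behavior (toy class, not the real PytorchModel):

    from functools import lru_cache

    class ToyModel:
        def __init__(self, context_length: int):
            self._context_length = context_length

        @lru_cache  # `self` is the cache key, so the value is cached per instance
        def get_context_len(self) -> int:
            print("computed once")
            return self._context_length

    m = ToyModel(4096)
    m.get_context_len()  # prints "computed once"
    m.get_context_len()  # served from the cache, nothing printed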
xinference/model/llm/pytorch/glm4v.py
CHANGED

@@ -56,19 +56,29 @@ class Glm4VModel(PytorchChatModel):
             return True
         return False

-    def load(self
+    def load(self):
         from transformers import AutoModelForCausalLM, AutoTokenizer

         device = self._pytorch_model_config.get("device", "auto")
         self._device = select_device(device)
-
+
+        kwargs = {"device_map": self._device}
+        quantization = self.quantization
+        if quantization != "none":
+            if self._device == "cuda" and self._is_linux():
+                kwargs["device_map"] = "auto"
+                self._device = "auto"
+            if quantization == "4-bit":
+                kwargs["load_in_4bit"] = True
+            elif quantization == "8-bit":
+                kwargs["load_in_8bit"] = True

         model = AutoModelForCausalLM.from_pretrained(
             self.model_path,
             low_cpu_mem_usage=True,
             trust_remote_code=True,
             torch_dtype=torch.float16,
-
+            **kwargs,
         )
         self._model = model.eval()

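The new branch in Glm4VModel.load above builds the from_pretrained keyword arguments from the requested quantization. A minimal sketch of that selection logic as a standalone function (the helper name and the plain is_linux flag are illustrative, not part of the diff):

    def build_load_kwargs(device: str, quantization: str, is_linux: bool) -> dict:
        # Mirrors the added branch: 4-bit/8-bit quantization switches to
        # device_map="auto" on Linux + CUDA and sets the bitsandbytes load flags.
        kwargs = {"device_map": device}
        if quantization != "none":
            if device == "cuda" and is_linux:
                kwargs["device_map"] = "auto"
            if quantization == "4-bit":
                kwargs["load_in_4bit"] = True
            elif quantization == "8-bit":
                kwargs["load_in_8bit"] = True
        return kwargs

    assert build_load_kwargs("cuda", "4-bit", True) == {"device_map": "auto", "load_in_4bit": True}
    assert build_load_kwargs("cpu", "none", False) == {"device_map": "cpu"}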
xinference/model/llm/pytorch/qwen_vl.py
CHANGED

@@ -45,7 +45,7 @@ class QwenVLChatModel(PytorchChatModel):
     def match(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if "qwen" in model_family.model_name:
+        if "qwen" in model_family.model_name and "vision" in model_family.model_ability:
             return True
         return False

xinference/model/llm/pytorch/utils.py
CHANGED

@@ -126,6 +126,7 @@ def generate_stream(
     stop_str = generate_config.get("stop", None)
     stop_token_ids = generate_config.get("stop_token_ids", None) or []
     stop_token_ids.append(tokenizer.eos_token_id)
+    chunk_id = str(uuid.uuid4())

     logits_processor = prepare_logits_processor(
         temperature, repetition_penalty, top_p, top_k
@@ -289,7 +290,7 @@
                 text=output, index=0, logprobs=None, finish_reason=None
             )
             completion_chunk = CompletionChunk(
-                id=
+                id=chunk_id,
                 object="text_completion",
                 created=int(time.time()),
                 model=model_uid,
@@ -327,7 +328,7 @@
     )

     completion_chunk = CompletionChunk(
-        id=
+        id=chunk_id,
         object="text_completion",
         created=int(time.time()),
         model=model_uid,
@@ -343,7 +344,7 @@

     if include_usage:
         completion_chunk = CompletionChunk(
-            id=
+            id=chunk_id,
             object="text_completion",
             created=int(time.time()),
             model=model_uid,
@@ -390,6 +391,7 @@ def generate_stream_falcon(
     stop_str = generate_config.get("stop", None)
     stop_token_ids = generate_config.get("stop_token_ids", None) or []
     stop_token_ids.append(tokenizer.eos_token_id)
+    chunk_id = str(uuid.uuid4())

     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     input_ids = inputs["input_ids"]
@@ -473,7 +475,7 @@
                 text=output, index=0, logprobs=None, finish_reason=None
             )
             completion_chunk = CompletionChunk(
-                id=
+                id=chunk_id,
                 object="text_completion",
                 created=int(time.time()),
                 model=model_uid,
@@ -500,7 +502,7 @@
         text=output, index=0, logprobs=None, finish_reason=finish_reason
     )
     completion_chunk = CompletionChunk(
-        id=
+        id=chunk_id,
         object="text_completion",
         created=int(time.time()),
         model=model_uid,
@@ -516,7 +518,7 @@

     if include_usage:
         completion_chunk = CompletionChunk(
-            id=
+            id=chunk_id,
             object="text_completion",
             created=int(time.time()),
             model=model_uid,
@@ -586,6 +588,7 @@ def get_max_src_len(context_len: int, r: InferenceRequest) -> int:

 def _get_completion_chunk(
     output: str,
+    chunk_id: str,
     finish_reason: Optional[str],
     model_uid: str,
     r: InferenceRequest,
@@ -601,7 +604,7 @@ def _get_completion_chunk(
         else []
     )
     completion_chunk = CompletionChunk(
-        id=
+        id=chunk_id,
         object="text_completion",
         created=int(time.time()),
         model=model_uid,
@@ -617,14 +620,18 @@


 def _get_completion(
-    output: str,
+    output: str,
+    chunk_id: str,
+    finish_reason: Optional[str],
+    model_uid: str,
+    r: InferenceRequest,
 ):
     completion_choice = CompletionChoice(
         text=output, index=0, logprobs=None, finish_reason=finish_reason
     )

     completion_chunk = CompletionChunk(
-        id=
+        id=chunk_id,
         object="text_completion",
         created=int(time.time()),
         model=model_uid,
@@ -701,7 +708,7 @@ def _batch_inference_one_step_internal(
     decode_reqs = []
     for r in valid_req_list:
         if r.is_prefill:
-            prompts.append(r.full_prompt)
+            prompts.append(r.full_prompt if r.full_prompt is not None else r.prompt)
             prefill_reqs.append(r)
         else:
             decode_reqs.append(r)
@@ -846,7 +853,7 @@
                     r.last_output_length += len(output)

                     completion_chunk = _get_completion_chunk(
-                        output, r.finish_reason, model_uid, r, False
+                        output, r.chunk_id, r.finish_reason, model_uid, r, False
                     )
                     r.completion.append(completion_chunk)
                     if r.stopped:
@@ -859,7 +866,7 @@
                 if r.stopped and _i == decode_round - 1 and include_usage:
                     r.completion.append(
                         _get_completion_chunk(
-                            "", r.finish_reason, model_uid, r, True
+                            "", r.chunk_id, r.finish_reason, model_uid, r, True
                         )
                     )
             else:
@@ -878,7 +885,9 @@
                     if r not in output_mapping
                     else output_mapping[r]
                 )
-                completion = _get_completion(
+                completion = _get_completion(
+                    outputs, r.chunk_id, r.finish_reason, model_uid, r
+                )
                 r.completion = [completion]

         e_time = time.time()
@@ -911,4 +920,8 @@ def batch_inference_one_step(
         os._exit(1)
     except Exception as e:
         logger.exception(f"Internal error for batch inference: {e}.")
-        #
+        # If internal error happens, just skip all the requests in this batch.
+        # If not handle here, the client will hang.
+        for r in req_list:
+            r.stopped = True
+            r.error_msg = str(e)
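The recurring change in this file is that every chunk of a single stream now carries one chunk_id, generated once with uuid.uuid4(), instead of a fresh id per chunk. A minimal standalone sketch of that contract (plain dicts instead of the real CompletionChunk type):

    import time
    import uuid
    from typing import Dict, Iterator, List

    def toy_stream(model_uid: str, pieces: List[str]) -> Iterator[Dict]:
        chunk_id = str(uuid.uuid4())  # generated once, reused for every chunk
        for text in pieces:
            yield {
                "id": chunk_id,
                "object": "text_completion",
                "created": int(time.time()),
                "model": model_uid,
                "choices": [{"text": text, "index": 0, "logprobs": None, "finish_reason": None}],
            }

    ids = {chunk["id"] for chunk in toy_stream("my-model", ["Hel", "lo"])}
    assert len(ids) == 1  # all chunks of one stream share the same id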
xinference/model/llm/utils.py
CHANGED

@@ -607,7 +607,7 @@ Begin!"""
         return arguments, None, None

     @staticmethod
-    def
+    def _eval_glm_chat_arguments(c, tools):
         if isinstance(c[0], str):
             return c[0], None, None
         return None, c[0]["name"], c[0]["parameters"]
@@ -659,9 +659,9 @@
         family = model_family.model_family or model_family.model_name
         if family in ["gorilla-openfunctions-v1", "gorilla-openfunctions-v2"]:
             content, func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
-        elif "chatglm3"
-            content, func, args = cls.
-        elif family in ["qwen-chat", "qwen1.5-chat"]:
+        elif family in ["chatglm3", "glm4-chat"]:
+            content, func, args = cls._eval_glm_chat_arguments(c, tools)
+        elif family in ["qwen-chat", "qwen1.5-chat", "qwen2-instruct"]:
             content, func, args = cls._eval_qwen_chat_arguments(c, tools)
         else:
             raise Exception(
@@ -676,28 +676,29 @@
         Generates a filter function for Qwen series models to retain outputs after "\nFinal Answer:".

         Returns:
-            A function that takes tokens (string output by the model so far) as input
-            returns
+            A function that takes tokens (string output by the model so far) and delta (new tokens added) as input,
+            returns the part after "\nFinal Answer:" if found, else returns delta.
         """
         family = model_family.model_family or model_family.model_name
         if family in ["qwen-chat", "qwen1.5-chat"]:
             # Encapsulating function to reset 'found' after each call
             found = False

-            def
+            def process_tokens(tokens: str, delta: str):
                 nonlocal found
                 # Once "Final Answer:" is found, future tokens are allowed.
                 if found:
-                    return
+                    return delta
                 # Check if the token ends with "\nFinal Answer:" and update `found`.
-
+                final_answer_idx = tokens.lower().rfind("\nfinal answer:")
+                if final_answer_idx != -1:
                     found = True
-
+                    return tokens[final_answer_idx + len("\nfinal answer:") :]
+                return ""

-            return
+            return process_tokens
         else:
-
-            return lambda tokens: True
+            return lambda tokens, delta: delta

     @classmethod
     def _tool_calls_completion(cls, model_family, model_uid, c, tools):
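The process_tokens closure above now takes (tokens, delta) and suppresses Qwen's ReAct scaffolding until "\nFinal Answer:" appears in the accumulated text. A standalone re-implementation of that closure, for illustration only, with a few example calls:

    def make_final_answer_filter():
        found = False

        def process_tokens(tokens: str, delta: str) -> str:
            nonlocal found
            if found:
                return delta  # after the marker, pass deltas through unchanged
            idx = tokens.lower().rfind("\nfinal answer:")
            if idx != -1:
                found = True
                return tokens[idx + len("\nfinal answer:"):]
            return ""  # still inside the thought/action scaffolding

        return process_tokens

    f = make_final_answer_filter()
    assert f("Thought: I need a tool", " tool") == ""
    assert f("Thought: ...\nFinal Answer: 42", " 42") == " 42"
    assert f("Thought: ...\nFinal Answer: 42 is", " is") == " is"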
xinference/model/llm/vllm/core.py
CHANGED

@@ -444,7 +444,9 @@ class VLLMModel(LLM):
                 _content, func, args = ChatModelMixin._eval_tool_arguments(
                     self.model_family, chunk, tools
                 )
-                choice["text"] =
+                choice["text"] = tools_token_filter(
+                    tokens=previous_texts[0], delta=choice_delta
+                )
                 if func is not None:
                     choice["text"] = None
                     choice["finish_reason"] = "tool_calls"
@@ -458,9 +460,13 @@
                         ),
                     )
                 ]
-
-
-
+            else:
+                # use a filter function to skip Qwen's react thought process
+                choice["text"] = tools_token_filter(
+                    tokens=previous_texts[0], delta=choice["text"]
+                )
+                if not choice["text"]:
+                    continue
             prompt_tokens = len(_request_output.prompt_token_ids)
             completion_tokens = sum(
                 len(output.token_ids) for output in _request_output.outputs
xinference/model/utils.py
CHANGED

@@ -42,14 +42,20 @@ def is_locale_chinese_simplified() -> bool:


 def download_from_modelscope() -> bool:
-    if os.environ.get(XINFERENCE_ENV_MODEL_SRC)
-        return
+    if os.environ.get(XINFERENCE_ENV_MODEL_SRC):
+        return os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "modelscope"
     elif is_locale_chinese_simplified():
         return True
     else:
         return False


+def download_from_csghub() -> bool:
+    if os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "csghub":
+        return True
+    return False
+
+
 def symlink_local_file(path: str, local_dir: str, relpath: str) -> str:
     from huggingface_hub.file_download import _create_symlink

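With download_from_csghub added above, the XINFERENCE_ENV_MODEL_SRC environment variable now selects the model download source explicitly instead of relying on the locale-based default. A small sketch of the expected behavior, assuming the constant is exposed from xinference.constants:

    import os
    from xinference.constants import XINFERENCE_ENV_MODEL_SRC
    from xinference.model.utils import download_from_csghub, download_from_modelscope

    os.environ[XINFERENCE_ENV_MODEL_SRC] = "csghub"
    assert download_from_csghub() is True
    assert download_from_modelscope() is False

    os.environ[XINFERENCE_ENV_MODEL_SRC] = "modelscope"
    assert download_from_csghub() is False
    assert download_from_modelscope() is True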
xinference/thirdparty/ChatTTS/experimental/__init__.py
File without changes
xinference/thirdparty/ChatTTS/experimental/llm.py
ADDED

@@ -0,0 +1,40 @@
+
+from openai import OpenAI
+
+prompt_dict = {
+    'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"},
+              {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
+              {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
+    'deepseek': [
+        {"role": "system", "content": "You are a helpful assistant"},
+        {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
+        {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
+    'deepseek_TN': [
+        {"role": "system", "content": "You are a helpful assistant"},
+        {"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"},
+        {"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"},
+        {"role": "user", "content": "We paid $123 for this desk."},
+        {"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."},
+        {"role": "user", "content": "详询请拨打010-724654"},
+        {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"},
+        {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"},
+        {"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"},
+    ],
+}
+
+class llm_api:
+    def __init__(self, api_key, base_url, model):
+        self.client = OpenAI(
+            api_key = api_key,
+            base_url = base_url,
+        )
+        self.model = model
+    def call(self, user_question, temperature = 0.3, prompt_version='kimi', **kwargs):
+
+        completion = self.client.chat.completions.create(
+            model = self.model,
+            messages = prompt_dict[prompt_version]+[{"role": "user", "content": user_question},],
+            temperature = temperature,
+            **kwargs
+        )
+        return completion.choices[0].message.content
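A hypothetical usage sketch of the vendored llm_api helper above; the key, URL and model name are placeholders, and the endpoint only needs to be OpenAI-compatible:

    api = llm_api(
        api_key="sk-...",                   # placeholder key
        base_url="https://example.com/v1",  # any OpenAI-compatible endpoint
        model="some-chat-model",            # placeholder model name
    )
    reply = api.call("Please introduce yourself in one sentence.", prompt_version="deepseek")
    print(reply)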
xinference/thirdparty/ChatTTS/infer/__init__.py
File without changes
xinference/thirdparty/ChatTTS/infer/api.py
ADDED

@@ -0,0 +1,125 @@
+
+import torch
+import torch.nn.functional as F
+from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
+from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
+
+def infer_code(
+    models,
+    text,
+    spk_emb = None,
+    top_P = 0.7,
+    top_K = 20,
+    temperature = 0.3,
+    repetition_penalty = 1.05,
+    max_new_token = 2048,
+    **kwargs
+):
+
+    device = next(models['gpt'].parameters()).device
+
+    if not isinstance(text, list):
+        text = [text]
+
+    if not isinstance(temperature, list):
+        temperature = [temperature] * models['gpt'].num_vq
+
+    if spk_emb is not None:
+        text = [f'[Stts][spk_emb]{i}[Ptts]' for i in text]
+    else:
+        text = [f'[Stts][empty_spk]{i}[Ptts]' for i in text]
+
+    text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
+    input_ids = text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq)
+    text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
+
+    inputs = {
+        'input_ids': input_ids,
+        'text_mask': text_mask,
+        'attention_mask': text_token['attention_mask'],
+    }
+
+    emb = models['gpt'].get_emb(**inputs)
+    if spk_emb is not None:
+        emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = \
+            F.normalize(spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12)
+
+    num_code = models['gpt'].emb_code[0].num_embeddings - 1
+
+    LogitsWarpers = []
+    if top_P is not None:
+        LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
+    if top_K is not None:
+        LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
+
+    LogitsProcessors = []
+    if repetition_penalty is not None and repetition_penalty != 1:
+        LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(\
+            repetition_penalty, num_code, 16))
+
+    result = models['gpt'].generate(
+        emb, inputs['input_ids'],
+        temperature = torch.tensor(temperature, device=device),
+        attention_mask = inputs['attention_mask'],
+        LogitsWarpers = LogitsWarpers,
+        LogitsProcessors = LogitsProcessors,
+        eos_token = num_code,
+        max_new_token = max_new_token,
+        infer_text = False,
+        **kwargs
+    )
+
+    return result
+
+
+def refine_text(
+    models,
+    text,
+    top_P = 0.7,
+    top_K = 20,
+    temperature = 0.7,
+    repetition_penalty = 1.0,
+    max_new_token = 384,
+    prompt = '',
+    **kwargs
+):
+
+    device = next(models['gpt'].parameters()).device
+
+    if not isinstance(text, list):
+        text = [text]
+
+    assert len(text), 'text should not be empty'
+
+    text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
+    text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
+    text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
+
+    inputs = {
+        'input_ids': text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq),
+        'text_mask': text_mask,
+        'attention_mask': text_token['attention_mask'],
+    }
+
+    LogitsWarpers = []
+    if top_P is not None:
+        LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
+    if top_K is not None:
+        LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
+
+    LogitsProcessors = []
+    if repetition_penalty is not None and repetition_penalty != 1:
+        LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16))
+
+    result = models['gpt'].generate(
+        models['gpt'].get_emb(**inputs), inputs['input_ids'],
+        temperature = torch.tensor([temperature,], device=device),
+        attention_mask = inputs['attention_mask'],
+        LogitsWarpers = LogitsWarpers,
+        LogitsProcessors = LogitsProcessors,
+        eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
+        max_new_token = max_new_token,
+        infer_text = True,
+        **kwargs
+    )
+    return result
xinference/thirdparty/ChatTTS/model/__init__.py
File without changes