xinference 0.11.3__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +143 -6
- xinference/client/restful/restful_client.py +144 -5
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +48 -28
- xinference/core/model.py +160 -19
- xinference/core/scheduler.py +446 -0
- xinference/core/supervisor.py +99 -24
- xinference/core/worker.py +68 -2
- xinference/deploy/cmdline.py +86 -2
- xinference/deploy/test/test_cmdline.py +19 -10
- xinference/isolation.py +9 -2
- xinference/model/audio/__init__.py +14 -1
- xinference/model/audio/chattts.py +84 -0
- xinference/model/audio/core.py +22 -4
- xinference/model/audio/custom.py +6 -4
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/audio/model_spec_modelscope.json +20 -0
- xinference/model/llm/__init__.py +38 -2
- xinference/model/llm/llm_family.json +509 -1
- xinference/model/llm/llm_family.py +86 -1
- xinference/model/llm/llm_family_csghub.json +66 -0
- xinference/model/llm/llm_family_modelscope.json +411 -2
- xinference/model/llm/pytorch/chatglm.py +20 -13
- xinference/model/llm/pytorch/cogvlm2.py +76 -17
- xinference/model/llm/pytorch/core.py +141 -6
- xinference/model/llm/pytorch/glm4v.py +268 -0
- xinference/model/llm/pytorch/minicpmv25.py +232 -0
- xinference/model/llm/pytorch/qwen_vl.py +1 -1
- xinference/model/llm/pytorch/utils.py +405 -8
- xinference/model/llm/utils.py +14 -13
- xinference/model/llm/vllm/core.py +16 -4
- xinference/model/utils.py +8 -2
- xinference/thirdparty/ChatTTS/__init__.py +1 -0
- xinference/thirdparty/ChatTTS/core.py +200 -0
- xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
- xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/infer/api.py +125 -0
- xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
- xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
- xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
- xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
- xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
- xinference/types.py +3 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
- xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
- xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
- xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/METADATA +26 -9
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/RECORD +65 -47
- xinference/web/ui/build/static/css/main.54bca460.css +0 -2
- xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
- xinference/web/ui/build/static/js/main.551aa479.js +0 -3
- xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
- /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/utils.py
CHANGED

@@ -14,13 +14,15 @@

 import gc
 import logging
+import os
 import time
 import uuid
 from threading import Thread
-from typing import Iterable, Iterator, Tuple
+from typing import Dict, Iterable, Iterator, List, Optional, Tuple

 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
+from transformers.cache_utils import DynamicCache
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
@@ -29,8 +31,10 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )

+from ....core.scheduler import InferenceRequest
 from ....device_utils import empty_cache
 from ....types import (
+    Completion,
     CompletionChoice,
     CompletionChunk,
     CompletionUsage,
@@ -54,7 +58,7 @@ def is_partial_stop(output: str, stop_str: str):
     return False


-def get_context_length(config):
+def get_context_length(config) -> int:
     """Get the context length of a model from a huggingface model config."""
     if (
         hasattr(config, "max_sequence_length")
@@ -122,6 +126,7 @@ def generate_stream(
     stop_str = generate_config.get("stop", None)
     stop_token_ids = generate_config.get("stop_token_ids", None) or []
     stop_token_ids.append(tokenizer.eos_token_id)
+    chunk_id = str(uuid.uuid4())

     logits_processor = prepare_logits_processor(
         temperature, repetition_penalty, top_p, top_k
@@ -285,7 +290,7 @@ def generate_stream(
                 text=output, index=0, logprobs=None, finish_reason=None
             )
             completion_chunk = CompletionChunk(
-                id=
+                id=chunk_id,
                 object="text_completion",
                 created=int(time.time()),
                 model=model_uid,
@@ -323,7 +328,7 @@ def generate_stream(
             )

             completion_chunk = CompletionChunk(
-                id=
+                id=chunk_id,
                 object="text_completion",
                 created=int(time.time()),
                 model=model_uid,
@@ -339,7 +344,7 @@ def generate_stream(

     if include_usage:
         completion_chunk = CompletionChunk(
-            id=
+            id=chunk_id,
             object="text_completion",
             created=int(time.time()),
             model=model_uid,
@@ -386,6 +391,7 @@ def generate_stream_falcon(
     stop_str = generate_config.get("stop", None)
     stop_token_ids = generate_config.get("stop_token_ids", None) or []
     stop_token_ids.append(tokenizer.eos_token_id)
+    chunk_id = str(uuid.uuid4())

     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     input_ids = inputs["input_ids"]
@@ -469,7 +475,7 @@ def generate_stream_falcon(
                 text=output, index=0, logprobs=None, finish_reason=None
             )
             completion_chunk = CompletionChunk(
-                id=
+                id=chunk_id,
                 object="text_completion",
                 created=int(time.time()),
                 model=model_uid,
@@ -496,7 +502,7 @@ def generate_stream_falcon(
         text=output, index=0, logprobs=None, finish_reason=finish_reason
     )
     completion_chunk = CompletionChunk(
-        id=
+        id=chunk_id,
         object="text_completion",
         created=int(time.time()),
         model=model_uid,
@@ -512,7 +518,7 @@ def generate_stream_falcon(

     if include_usage:
         completion_chunk = CompletionChunk(
-            id=
+            id=chunk_id,
             object="text_completion",
             created=int(time.time()),
             model=model_uid,
@@ -528,3 +534,394 @@ def generate_stream_falcon(
     # clean
     gc.collect()
     empty_cache()
+
+
+def _get_token_from_logits(
+    req: InferenceRequest, i: int, logits, temperature, repetition_penalty, top_p, top_k
+):
+    logits_processor = prepare_logits_processor(
+        temperature, repetition_penalty, top_p, top_k
+    )
+
+    if logits_processor:
+        if repetition_penalty > 1.0:
+            tmp_output_ids = torch.as_tensor(
+                [req.prompt_tokens + req.new_tokens], device=logits.device
+            )
+        else:
+            tmp_output_ids = None
+        last_token_logits = logits_processor(tmp_output_ids, logits[i : i + 1, -1, :])[
+            0
+        ]
+    else:
+        last_token_logits = logits[i : i + 1, -1, :]
+
+    if temperature < 1e-5 or top_p < 1e-8:  # greedy
+        _, indices = torch.topk(last_token_logits, 2)
+    else:
+        probs = torch.softmax(last_token_logits, dim=-1)
+        indices = torch.multinomial(probs, num_samples=2)
+    token = indices[0].int().item()
+    return token
+
+
+def _pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
+    assert len(x) <= max_len
+    return [pad] * (max_len - len(x)) + x
+
+
+def _pad_seqs_inplace(seqs: List[List[int]], pad: int):
+    max_len = max(len(seq) for seq in seqs)
+    n = len(seqs)
+    i = 0
+    while i < n:
+        seqs[i] = _pad_to_max_length(seqs[i], max_len, pad)
+        i += 1
+
+
+def get_max_src_len(context_len: int, r: InferenceRequest) -> int:
+    max_new_tokens = int(
+        r.sanitized_generate_config.get("max_tokens", max_tokens_field.default)
+    )
+    return context_len - max_new_tokens - 8
+
+
+def _get_completion_chunk(
+    output: str,
+    chunk_id: str,
+    finish_reason: Optional[str],
+    model_uid: str,
+    r: InferenceRequest,
+    just_usage: bool,
+):
+    completion_choice = (
+        [
+            CompletionChoice(
+                text=output, index=0, logprobs=None, finish_reason=finish_reason
+            )
+        ]
+        if not just_usage
+        else []
+    )
+    completion_chunk = CompletionChunk(
+        id=chunk_id,
+        object="text_completion",
+        created=int(time.time()),
+        model=model_uid,
+        choices=completion_choice,
+    )
+    completion_usage = CompletionUsage(
+        prompt_tokens=len(r.prompt_tokens),
+        completion_tokens=len(r.new_tokens),
+        total_tokens=len(r.prompt_tokens) + len(r.new_tokens),
+    )
+    completion_chunk["usage"] = completion_usage
+    return completion_chunk
+
+
+def _get_completion(
+    output: str,
+    chunk_id: str,
+    finish_reason: Optional[str],
+    model_uid: str,
+    r: InferenceRequest,
+):
+    completion_choice = CompletionChoice(
+        text=output, index=0, logprobs=None, finish_reason=finish_reason
+    )
+
+    completion_chunk = CompletionChunk(
+        id=chunk_id,
+        object="text_completion",
+        created=int(time.time()),
+        model=model_uid,
+        choices=[completion_choice],
+    )
+    completion_usage = CompletionUsage(
+        prompt_tokens=len(r.prompt_tokens),
+        completion_tokens=len(r.new_tokens),
+        total_tokens=len(r.prompt_tokens) + len(r.new_tokens),
+    )
+    completion = Completion(
+        id=completion_chunk["id"],
+        object=completion_chunk["object"],
+        created=completion_chunk["created"],
+        model=completion_chunk["model"],
+        choices=completion_chunk["choices"],
+        usage=completion_usage,
+    )
+    return completion
+
+
+def _merge_kv_cache(
+    past_kv: Tuple[Tuple[torch.Tensor]], new_kv: Tuple[Tuple[torch.Tensor]]
+):
+    from torch.nn.functional import pad
+
+    past_cache = DynamicCache.from_legacy_cache(past_kv)
+    new_cache = DynamicCache.from_legacy_cache(new_kv)
+    past_seq_len = past_cache.get_seq_length()
+    new_seq_len = new_cache.get_seq_length()
+    if past_seq_len != new_seq_len:
+        padding_target = new_cache if past_seq_len > new_seq_len else past_cache
+        padding_len = abs(past_seq_len - new_seq_len)
+        for idx in range(len(padding_target)):
+            k = padding_target.key_cache[idx]
+            v = padding_target.value_cache[idx]
+            _k = pad(k, (0, 0, padding_len, 0))
+            _v = pad(v, (0, 0, padding_len, 0))
+            padding_target.key_cache[idx] = _k
+            padding_target.value_cache[idx] = _v
+
+    ret_kv = DynamicCache()
+    for idx in range(len(past_cache)):
+        k1, k2 = new_cache.key_cache[idx], past_cache.key_cache[idx]
+        v1, v2 = new_cache.value_cache[idx], past_cache.value_cache[idx]
+        ret_kv.update(torch.cat((k1, k2), 0), torch.cat((v1, v2), 0), idx)
+    return ret_kv.to_legacy_cache()
+
+
+@torch.inference_mode()
+def _batch_inference_one_step_internal(
+    req_list: List[InferenceRequest],
+    model_uid,
+    model,
+    tokenizer,
+    device,
+    context_len: int,
+    decode_round: int = 16,
+    bos_flag: str = "<bos_stream>",
+    eos_flag: str = "<eos_stream>",
+):
+    # need to judge stopped here,
+    # since some requests state may change to stopped due to invalid parameters, e.g. max_src_len
+    valid_req_list = [r for r in req_list if not r.stopped]
+    if not valid_req_list:
+        return
+    generate_config_mapping: Dict[InferenceRequest, Tuple] = {
+        r: r.get_generate_configs(tokenizer.eos_token_id) for r in valid_req_list
+    }
+    s_time = time.time()
+
+    prefill_reqs = []
+    prompts = []
+    decode_reqs = []
+    for r in valid_req_list:
+        if r.is_prefill:
+            prompts.append(r.full_prompt if r.full_prompt is not None else r.prompt)
+            prefill_reqs.append(r)
+        else:
+            decode_reqs.append(r)
+
+    if prompts:  # prefill first
+        input_ids: List[List[int]] = tokenizer(prompts, padding=False).input_ids
+        prompt_tokens = []
+        for i, input_id in enumerate(input_ids):
+            req = valid_req_list[i]
+            max_src_len = get_max_src_len(context_len, req)
+            req.prompt_tokens = input_id[-max_src_len:]
+            prompt_tokens.append(req.prompt_tokens)
+        _pad_seqs_inplace(prompt_tokens, 0)
+        out = model(torch.as_tensor(prompt_tokens, device=device), use_cache=True)
+
+        logits = out.logits
+        past_key_values = out.past_key_values
+
+        for i, r in enumerate(prefill_reqs):
+            (
+                max_new_tokens,
+                stream_interval,
+                include_usage,
+                stop_str,
+                stop_token_ids,
+                temperature,
+                repetition_penalty,
+                top_p,
+                top_k,
+            ) = generate_config_mapping[r]
+
+            token = _get_token_from_logits(
+                r, i, logits, temperature, repetition_penalty, top_p, top_k
+            )
+            r.is_prefill = False
+            r.append_new_token(token)
+
+        if decode_reqs:
+            decode_kv = decode_reqs[0].kv_cache
+            # prefill and decode kv cache need to be merged at `batch_size` and `seq_len` dimensions.
+            merged_kv_cache = _merge_kv_cache(decode_kv, past_key_values)
+            for r in valid_req_list:
+                r.kv_cache = merged_kv_cache
+            empty_cache()
+        else:
+            for r in valid_req_list:
+                r.kv_cache = past_key_values

+    past_key_values = valid_req_list[0].kv_cache
+    stop_token_mapping: Dict[InferenceRequest, int] = {}
+    output_mapping: Dict[InferenceRequest, str] = {}
+    # here, only decode phase, just run some rounds
+    for _i in range(decode_round):
+        decode_tokens: List[List[int]] = [[r.new_tokens[-1]] for r in valid_req_list]
+        out = model(
+            input_ids=torch.as_tensor(decode_tokens, device=device),
+            use_cache=True,
+            past_key_values=past_key_values,
+        )
+        logits = out.logits
+        past_key_values = out.past_key_values
+
+        for i, r in enumerate(valid_req_list):
+            (
+                max_new_tokens,
+                stream_interval,
+                include_usage,
+                stop_str,
+                stop_token_ids,
+                temperature,
+                repetition_penalty,
+                top_p,
+                top_k,
+            ) = generate_config_mapping[r]
+
+            token = _get_token_from_logits(
+                r, i, logits, temperature, repetition_penalty, top_p, top_k
+            )
+            r.kv_cache = past_key_values
+            r.append_new_token(token)
+
+            output = None
+            if not r.stopped:
+                stopped = token in stop_token_ids
+
+                if stopped:
+                    finish_reason = "stop"
+                elif len(r.new_tokens) == max_new_tokens:
+                    finish_reason = "length"
+                    stopped = True
+                else:
+                    finish_reason = None
+
+                # handle stop str
+                if stop_str and r not in output_mapping:
+                    output = tokenizer.decode(
+                        r.new_tokens,
+                        skip_special_tokens=True,
+                        spaces_between_special_tokens=False,
+                        clean_up_tokenization_spaces=True,
+                    )
+                    if isinstance(stop_str, str):
+                        stop_str = [stop_str]
+                    for stop in stop_str:
+                        pos = output.rfind(stop)
+                        if pos != -1:
+                            output = output[:pos]
+                            output_mapping[r] = output
+                            stopped = True
+                            finish_reason = "stop"
+                            break
+
+                r.stopped = stopped
+                r.finish_reason = finish_reason
+
+            if r.stopped and r not in stop_token_mapping and r not in output_mapping:
+                stop_token_mapping[r] = _i + 1
+
+            if r.stream:
+                """
+                Note that you can't just decode based on the newest r.new_tokens here,
+                which may destroy the integrity of the parsed characters,
+                and at the same time is not good at handling some special characters.
+                So the implementation here is to decode all the tokens that have been generated each time,
+                and then take the slice.
+                """
+                if r.stopped or len(r.new_tokens) % stream_interval == 0:
+                    if output is None:
+                        output = tokenizer.decode(
+                            r.new_tokens,
+                            skip_special_tokens=True,
+                            spaces_between_special_tokens=False,
+                            clean_up_tokenization_spaces=True,
+                        )
+
+                    if r.last_output_length == 0:
+                        r.completion.append(bos_flag)
+
+                    # this special character is mainly for qwen
+                    output = output.strip("�")
+                    output = output[r.last_output_length :]
+                    r.last_output_length += len(output)
+
+                    completion_chunk = _get_completion_chunk(
+                        output, r.chunk_id, r.finish_reason, model_uid, r, False
+                    )
+                    r.completion.append(completion_chunk)
+                    if r.stopped:
+                        r.completion.append(eos_flag)
+
+                # last round, handle stream result
+                # append usage information when enable `include_usage` for OPENAI API compatibility
+                # The reason for counting the usage in the last round of the iteration is that,
+                # these tokens are real generated and should be counted.
+                if r.stopped and _i == decode_round - 1 and include_usage:
+                    r.completion.append(
+                        _get_completion_chunk(
+                            "", r.chunk_id, r.finish_reason, model_uid, r, True
+                        )
+                    )
+            else:
+                # last round, handle non-stream result
+                if r.stopped and _i == decode_round - 1:
+                    invalid_token_num = decode_round - stop_token_mapping[r]
+                    outputs = (
+                        tokenizer.decode(
+                            r.new_tokens[: -(invalid_token_num + 1)]
+                            if r.finish_reason == "stop"
+                            else r.new_tokens[:-invalid_token_num],
+                            skip_special_tokens=True,
+                            spaces_between_special_tokens=False,
+                            clean_up_tokenization_spaces=True,
+                        )
+                        if r not in output_mapping
+                        else output_mapping[r]
+                    )
+                    completion = _get_completion(
+                        outputs, r.chunk_id, r.finish_reason, model_uid, r
+                    )
+                    r.completion = [completion]
+
+    e_time = time.time()
+    logger.debug(
+        f"Average throughput for a step: {(len(valid_req_list) * decode_round + len(prompts)) / (e_time - s_time)} token/s."
+    )
+
+
+def batch_inference_one_step(
+    req_list: List[InferenceRequest],
+    model_uid,
+    model,
+    tokenizer,
+    device,
+    context_len: int,
+):
+    from ....core.model import OutOfMemoryError
+
+    try:
+        _batch_inference_one_step_internal(
+            req_list, model_uid, model, tokenizer, device, context_len
+        )
+    except OutOfMemoryError:
+        logger.exception(
+            f"Batch inference out of memory. "
+            f"Xinference will restart the model: {model_uid}. "
+            f"Please be patient for a few moments."
+        )
+        # Just kill the process and let xinference auto-recover the model
+        os._exit(1)
+    except Exception as e:
+        logger.exception(f"Internal error for batch inference: {e}.")
+        # If internal error happens, just skip all the requests in this batch.
+        # If not handle here, the client will hang.
+        for r in req_list:
+            r.stopped = True
+            r.error_msg = str(e)
xinference/model/llm/utils.py
CHANGED

@@ -607,7 +607,7 @@ Begin!"""
         return arguments, None, None

     @staticmethod
-    def
+    def _eval_glm_chat_arguments(c, tools):
         if isinstance(c[0], str):
             return c[0], None, None
         return None, c[0]["name"], c[0]["parameters"]
@@ -659,9 +659,9 @@ Begin!"""
         family = model_family.model_family or model_family.model_name
         if family in ["gorilla-openfunctions-v1", "gorilla-openfunctions-v2"]:
             content, func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
-        elif "chatglm3"
-            content, func, args = cls.
-        elif family in ["qwen-chat", "qwen1.5-chat"]:
+        elif family in ["chatglm3", "glm4-chat"]:
+            content, func, args = cls._eval_glm_chat_arguments(c, tools)
+        elif family in ["qwen-chat", "qwen1.5-chat", "qwen2-instruct"]:
             content, func, args = cls._eval_qwen_chat_arguments(c, tools)
         else:
             raise Exception(
@@ -676,28 +676,29 @@ Begin!"""
         Generates a filter function for Qwen series models to retain outputs after "\nFinal Answer:".

         Returns:
-            A function that takes tokens (string output by the model so far) as input
-            returns
+            A function that takes tokens (string output by the model so far) and delta (new tokens added) as input,
+            returns the part after "\nFinal Answer:" if found, else returns delta.
         """
         family = model_family.model_family or model_family.model_name
         if family in ["qwen-chat", "qwen1.5-chat"]:
             # Encapsulating function to reset 'found' after each call
             found = False

-            def
+            def process_tokens(tokens: str, delta: str):
                 nonlocal found
                 # Once "Final Answer:" is found, future tokens are allowed.
                 if found:
-                    return
+                    return delta
                 # Check if the token ends with "\nFinal Answer:" and update `found`.
-
+                final_answer_idx = tokens.lower().rfind("\nfinal answer:")
+                if final_answer_idx != -1:
                     found = True
-
+                    return tokens[final_answer_idx + len("\nfinal answer:") :]
+                return ""

-            return
+            return process_tokens
         else:
-
-            return lambda tokens: True
+            return lambda tokens, delta: delta

     @classmethod
     def _tool_calls_completion(cls, model_family, model_uid, c, tools):
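
The change above switches the Qwen tools-token filter to a two-argument contract: it receives the full text generated so far plus the newest delta, and only passes text through once "\nFinal Answer:" has appeared. A standalone sketch of that contract (logic copied from the diff; the make_filter wrapper name is just for illustration):

def make_filter():
    found = False

    def process_tokens(tokens: str, delta: str):
        nonlocal found
        if found:
            return delta
        idx = tokens.lower().rfind("\nfinal answer:")
        if idx != -1:
            found = True
            return tokens[idx + len("\nfinal answer:") :]
        return ""

    return process_tokens


f = make_filter()
print(repr(f("Thought: I should call a tool", " tool")))           # '' - still reasoning
print(repr(f("Thought: done\nFinal Answer: 42", "42")))             # ' 42' - answer starts here
print(repr(f("Thought: done\nFinal Answer: 42 is it", " is it")))   # ' is it' - later deltas pass through

For families without this behaviour the default filter is now lambda tokens, delta: delta, i.e. every delta is forwarded unchanged.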
xinference/model/llm/vllm/core.py
CHANGED

@@ -93,6 +93,7 @@ VLLM_SUPPORTED_MODELS = [
     "baichuan",
     "internlm-16k",
     "mistral-v0.1",
+    "codestral-v0.1",
     "Yi",
     "Yi-1.5",
     "code-llama",
@@ -118,11 +119,14 @@ VLLM_SUPPORTED_CHAT_MODELS = [
     "code-llama-instruct",
     "mistral-instruct-v0.1",
     "mistral-instruct-v0.2",
+    "mistral-instruct-v0.3",
     "mixtral-instruct-v0.1",
     "mixtral-8x22B-instruct-v0.1",
     "chatglm3",
     "chatglm3-32k",
     "chatglm3-128k",
+    "glm4-chat",
+    "glm4-chat-1m",
     "deepseek-chat",
     "deepseek-coder-instruct",
 ]
@@ -130,6 +134,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat")
     VLLM_SUPPORTED_MODELS.append("codeqwen1.5")
     VLLM_SUPPORTED_CHAT_MODELS.append("codeqwen1.5-chat")
+    VLLM_SUPPORTED_CHAT_MODELS.append("qwen2-instruct")

 if VLLM_INSTALLED and vllm.__version__ >= "0.3.2":
     VLLM_SUPPORTED_CHAT_MODELS.append("gemma-it")
@@ -140,6 +145,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.3":

 if VLLM_INSTALLED and vllm.__version__ >= "0.4.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-moe-chat")
+    VLLM_SUPPORTED_CHAT_MODELS.append("qwen2-moe-instruct")
     VLLM_SUPPORTED_CHAT_MODELS.append("c4ai-command-r-v01")


@@ -438,7 +444,9 @@ class VLLMModel(LLM):
                 _content, func, args = ChatModelMixin._eval_tool_arguments(
                     self.model_family, chunk, tools
                 )
-                choice["text"] =
+                choice["text"] = tools_token_filter(
+                    tokens=previous_texts[0], delta=choice_delta
+                )
                 if func is not None:
                     choice["text"] = None
                     choice["finish_reason"] = "tool_calls"
@@ -452,9 +460,13 @@ class VLLMModel(LLM):
                             ),
                         )
                     ]
-
-
-
+                else:
+                    # use a filter function to skip Qwen's react thought process
+                    choice["text"] = tools_token_filter(
+                        tokens=previous_texts[0], delta=choice["text"]
+                    )
+                    if not choice["text"]:
+                        continue
             prompt_tokens = len(_request_output.prompt_token_ids)
             completion_tokens = sum(
                 len(output.token_ids) for output in _request_output.outputs
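
The list updates above follow the module's existing pattern of gating newer model families on the installed vLLM version. A minimal, self-contained sketch of the same idea (the register_for helper is purely illustrative, and packaging.version is used here for the comparison, whereas the real module compares vllm.__version__ strings directly):

from packaging import version

VLLM_SUPPORTED_CHAT_MODELS = ["qwen1.5-chat"]


def register_for(installed: str) -> list:
    supported = list(VLLM_SUPPORTED_CHAT_MODELS)
    if version.parse(installed) >= version.parse("0.3.0"):
        supported.append("qwen2-instruct")
    if version.parse(installed) >= version.parse("0.4.0"):
        supported.append("qwen2-moe-instruct")
    return supported


print(register_for("0.4.2"))  # ['qwen1.5-chat', 'qwen2-instruct', 'qwen2-moe-instruct']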
xinference/model/utils.py
CHANGED

@@ -42,14 +42,20 @@ def is_locale_chinese_simplified() -> bool:


 def download_from_modelscope() -> bool:
-    if os.environ.get(XINFERENCE_ENV_MODEL_SRC)
-        return
+    if os.environ.get(XINFERENCE_ENV_MODEL_SRC):
+        return os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "modelscope"
     elif is_locale_chinese_simplified():
         return True
     else:
         return False


+def download_from_csghub() -> bool:
+    if os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "csghub":
+        return True
+    return False
+
+
 def symlink_local_file(path: str, local_dir: str, relpath: str) -> str:
     from huggingface_hub.file_download import _create_symlink

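
A standalone sketch of the selection behaviour after this change: an explicitly set model source always wins, ModelScope otherwise falls back to a simplified-Chinese locale heuristic, and CSGHub is only ever chosen explicitly. XINFERENCE_ENV_MODEL_SRC is defined in xinference/constants.py; the "XINFERENCE_MODEL_SRC" value and the stubbed locale check below are assumptions made to keep the example self-contained.

import os

XINFERENCE_ENV_MODEL_SRC = "XINFERENCE_MODEL_SRC"  # assumed value of the constant


def is_locale_chinese_simplified() -> bool:  # stub for the real locale check
    return False


def download_from_modelscope() -> bool:
    if os.environ.get(XINFERENCE_ENV_MODEL_SRC):
        return os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "modelscope"
    elif is_locale_chinese_simplified():
        return True
    else:
        return False


def download_from_csghub() -> bool:
    return os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "csghub"


os.environ[XINFERENCE_ENV_MODEL_SRC] = "csghub"
print(download_from_modelscope(), download_from_csghub())  # False True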
xinference/thirdparty/ChatTTS/__init__.py
ADDED

@@ -0,0 +1 @@
+from .core import Chat
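
This release vendors ChatTTS under xinference/thirdparty and registers a "ChatTTS" audio model (see the audio model_spec.json entries in the file list). A hedged sketch of how it might be exercised through the Python client; launch_model and get_model are existing client calls, while the speech() method and its return type are assumptions based on the OpenAI-style audio endpoint added to restful_api.py, not a documented signature.

from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model_uid = client.launch_model(model_name="ChatTTS", model_type="audio")
model = client.get_model(model_uid)

# Assumed TTS entry point returning raw audio bytes.
audio = model.speech("Hello from ChatTTS!")
with open("hello.wav", "wb") as f:
    f.write(audio)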