xinference 0.11.3__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +69 -0
- xinference/client/restful/restful_client.py +70 -0
- xinference/constants.py +4 -0
- xinference/core/model.py +141 -12
- xinference/core/scheduler.py +428 -0
- xinference/core/supervisor.py +26 -0
- xinference/isolation.py +9 -2
- xinference/model/audio/chattts.py +84 -0
- xinference/model/audio/core.py +10 -3
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/llm/__init__.py +4 -0
- xinference/model/llm/llm_family.json +507 -1
- xinference/model/llm/llm_family_modelscope.json +409 -2
- xinference/model/llm/pytorch/chatglm.py +2 -1
- xinference/model/llm/pytorch/cogvlm2.py +76 -17
- xinference/model/llm/pytorch/core.py +91 -6
- xinference/model/llm/pytorch/glm4v.py +258 -0
- xinference/model/llm/pytorch/minicpmv25.py +232 -0
- xinference/model/llm/pytorch/utils.py +386 -2
- xinference/model/llm/vllm/core.py +6 -0
- xinference/thirdparty/ChatTTS/__init__.py +1 -0
- xinference/thirdparty/ChatTTS/core.py +200 -0
- xinference/types.py +3 -0
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/METADATA +26 -9
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/RECORD +30 -24
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/LICENSE +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/WHEEL +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/utils.py CHANGED

@@ -14,13 +14,15 @@
 
 import gc
 import logging
+import os
 import time
 import uuid
 from threading import Thread
-from typing import Iterable, Iterator, Tuple
+from typing import Dict, Iterable, Iterator, List, Optional, Tuple
 
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
+from transformers.cache_utils import DynamicCache
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
@@ -29,8 +31,10 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )
 
+from ....core.scheduler import InferenceRequest
 from ....device_utils import empty_cache
 from ....types import (
+    Completion,
     CompletionChoice,
     CompletionChunk,
     CompletionUsage,
@@ -54,7 +58,7 @@ def is_partial_stop(output: str, stop_str: str):
     return False
 
 
-def get_context_length(config):
+def get_context_length(config) -> int:
     """Get the context length of a model from a huggingface model config."""
     if (
         hasattr(config, "max_sequence_length")
@@ -528,3 +532,383 @@ def generate_stream_falcon(
     # clean
     gc.collect()
     empty_cache()
+
+
+def _get_token_from_logits(
+    req: InferenceRequest, i: int, logits, temperature, repetition_penalty, top_p, top_k
+):
+    logits_processor = prepare_logits_processor(
+        temperature, repetition_penalty, top_p, top_k
+    )
+
+    if logits_processor:
+        if repetition_penalty > 1.0:
+            tmp_output_ids = torch.as_tensor(
+                [req.prompt_tokens + req.new_tokens], device=logits.device
+            )
+        else:
+            tmp_output_ids = None
+        last_token_logits = logits_processor(tmp_output_ids, logits[i : i + 1, -1, :])[
+            0
+        ]
+    else:
+        last_token_logits = logits[i : i + 1, -1, :]
+
+    if temperature < 1e-5 or top_p < 1e-8:  # greedy
+        _, indices = torch.topk(last_token_logits, 2)
+    else:
+        probs = torch.softmax(last_token_logits, dim=-1)
+        indices = torch.multinomial(probs, num_samples=2)
+    token = indices[0].int().item()
+    return token
+
+
+def _pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
+    assert len(x) <= max_len
+    return [pad] * (max_len - len(x)) + x
+
+
+def _pad_seqs_inplace(seqs: List[List[int]], pad: int):
+    max_len = max(len(seq) for seq in seqs)
+    n = len(seqs)
+    i = 0
+    while i < n:
+        seqs[i] = _pad_to_max_length(seqs[i], max_len, pad)
+        i += 1
+
+
+def get_max_src_len(context_len: int, r: InferenceRequest) -> int:
+    max_new_tokens = int(
+        r.sanitized_generate_config.get("max_tokens", max_tokens_field.default)
+    )
+    return context_len - max_new_tokens - 8
+
+
+def _get_completion_chunk(
+    output: str,
+    finish_reason: Optional[str],
+    model_uid: str,
+    r: InferenceRequest,
+    just_usage: bool,
+):
+    completion_choice = (
+        [
+            CompletionChoice(
+                text=output, index=0, logprobs=None, finish_reason=finish_reason
+            )
+        ]
+        if not just_usage
+        else []
+    )
+    completion_chunk = CompletionChunk(
+        id=str(uuid.uuid1()),
+        object="text_completion",
+        created=int(time.time()),
+        model=model_uid,
+        choices=completion_choice,
+    )
+    completion_usage = CompletionUsage(
+        prompt_tokens=len(r.prompt_tokens),
+        completion_tokens=len(r.new_tokens),
+        total_tokens=len(r.prompt_tokens) + len(r.new_tokens),
+    )
+    completion_chunk["usage"] = completion_usage
+    return completion_chunk
+
+
+def _get_completion(
+    output: str, finish_reason: Optional[str], model_uid: str, r: InferenceRequest
+):
+    completion_choice = CompletionChoice(
+        text=output, index=0, logprobs=None, finish_reason=finish_reason
+    )
+
+    completion_chunk = CompletionChunk(
+        id=str(uuid.uuid1()),
+        object="text_completion",
+        created=int(time.time()),
+        model=model_uid,
+        choices=[completion_choice],
+    )
+    completion_usage = CompletionUsage(
+        prompt_tokens=len(r.prompt_tokens),
+        completion_tokens=len(r.new_tokens),
+        total_tokens=len(r.prompt_tokens) + len(r.new_tokens),
+    )
+    completion = Completion(
+        id=completion_chunk["id"],
+        object=completion_chunk["object"],
+        created=completion_chunk["created"],
+        model=completion_chunk["model"],
+        choices=completion_chunk["choices"],
+        usage=completion_usage,
+    )
+    return completion
+
+
+def _merge_kv_cache(
+    past_kv: Tuple[Tuple[torch.Tensor]], new_kv: Tuple[Tuple[torch.Tensor]]
+):
+    from torch.nn.functional import pad
+
+    past_cache = DynamicCache.from_legacy_cache(past_kv)
+    new_cache = DynamicCache.from_legacy_cache(new_kv)
+    past_seq_len = past_cache.get_seq_length()
+    new_seq_len = new_cache.get_seq_length()
+    if past_seq_len != new_seq_len:
+        padding_target = new_cache if past_seq_len > new_seq_len else past_cache
+        padding_len = abs(past_seq_len - new_seq_len)
+        for idx in range(len(padding_target)):
+            k = padding_target.key_cache[idx]
+            v = padding_target.value_cache[idx]
+            _k = pad(k, (0, 0, padding_len, 0))
+            _v = pad(v, (0, 0, padding_len, 0))
+            padding_target.key_cache[idx] = _k
+            padding_target.value_cache[idx] = _v
+
+    ret_kv = DynamicCache()
+    for idx in range(len(past_cache)):
+        k1, k2 = new_cache.key_cache[idx], past_cache.key_cache[idx]
+        v1, v2 = new_cache.value_cache[idx], past_cache.value_cache[idx]
+        ret_kv.update(torch.cat((k1, k2), 0), torch.cat((v1, v2), 0), idx)
+    return ret_kv.to_legacy_cache()
+
+
+@torch.inference_mode()
+def _batch_inference_one_step_internal(
+    req_list: List[InferenceRequest],
+    model_uid,
+    model,
+    tokenizer,
+    device,
+    context_len: int,
+    decode_round: int = 16,
+    bos_flag: str = "<bos_stream>",
+    eos_flag: str = "<eos_stream>",
+):
+    # need to judge stopped here,
+    # since some requests state may change to stopped due to invalid parameters, e.g. max_src_len
+    valid_req_list = [r for r in req_list if not r.stopped]
+    if not valid_req_list:
+        return
+    generate_config_mapping: Dict[InferenceRequest, Tuple] = {
+        r: r.get_generate_configs(tokenizer.eos_token_id) for r in valid_req_list
+    }
+    s_time = time.time()
+
+    prefill_reqs = []
+    prompts = []
+    decode_reqs = []
+    for r in valid_req_list:
+        if r.is_prefill:
+            prompts.append(r.full_prompt)
+            prefill_reqs.append(r)
+        else:
+            decode_reqs.append(r)
+
+    if prompts:  # prefill first
+        input_ids: List[List[int]] = tokenizer(prompts, padding=False).input_ids
+        prompt_tokens = []
+        for i, input_id in enumerate(input_ids):
+            req = valid_req_list[i]
+            max_src_len = get_max_src_len(context_len, req)
+            req.prompt_tokens = input_id[-max_src_len:]
+            prompt_tokens.append(req.prompt_tokens)
+        _pad_seqs_inplace(prompt_tokens, 0)
+        out = model(torch.as_tensor(prompt_tokens, device=device), use_cache=True)
+
+        logits = out.logits
+        past_key_values = out.past_key_values
+
+        for i, r in enumerate(prefill_reqs):
+            (
+                max_new_tokens,
+                stream_interval,
+                include_usage,
+                stop_str,
+                stop_token_ids,
+                temperature,
+                repetition_penalty,
+                top_p,
+                top_k,
+            ) = generate_config_mapping[r]
+
+            token = _get_token_from_logits(
+                r, i, logits, temperature, repetition_penalty, top_p, top_k
+            )
+            r.is_prefill = False
+            r.append_new_token(token)
+
+        if decode_reqs:
+            decode_kv = decode_reqs[0].kv_cache
+            # prefill and decode kv cache need to be merged at `batch_size` and `seq_len` dimensions.
+            merged_kv_cache = _merge_kv_cache(decode_kv, past_key_values)
+            for r in valid_req_list:
+                r.kv_cache = merged_kv_cache
+            empty_cache()
+        else:
+            for r in valid_req_list:
+                r.kv_cache = past_key_values
+
+    past_key_values = valid_req_list[0].kv_cache
+    stop_token_mapping: Dict[InferenceRequest, int] = {}
+    output_mapping: Dict[InferenceRequest, str] = {}
+    # here, only decode phase, just run some rounds
+    for _i in range(decode_round):
+        decode_tokens: List[List[int]] = [[r.new_tokens[-1]] for r in valid_req_list]
+        out = model(
+            input_ids=torch.as_tensor(decode_tokens, device=device),
+            use_cache=True,
+            past_key_values=past_key_values,
+        )
+        logits = out.logits
+        past_key_values = out.past_key_values
+
+        for i, r in enumerate(valid_req_list):
+            (
+                max_new_tokens,
+                stream_interval,
+                include_usage,
+                stop_str,
+                stop_token_ids,
+                temperature,
+                repetition_penalty,
+                top_p,
+                top_k,
+            ) = generate_config_mapping[r]
+
+            token = _get_token_from_logits(
+                r, i, logits, temperature, repetition_penalty, top_p, top_k
+            )
+            r.kv_cache = past_key_values
+            r.append_new_token(token)
+
+            output = None
+            if not r.stopped:
+                stopped = token in stop_token_ids
+
+                if stopped:
+                    finish_reason = "stop"
+                elif len(r.new_tokens) == max_new_tokens:
+                    finish_reason = "length"
+                    stopped = True
+                else:
+                    finish_reason = None
+
+                # handle stop str
+                if stop_str and r not in output_mapping:
+                    output = tokenizer.decode(
+                        r.new_tokens,
+                        skip_special_tokens=True,
+                        spaces_between_special_tokens=False,
+                        clean_up_tokenization_spaces=True,
+                    )
+                    if isinstance(stop_str, str):
+                        stop_str = [stop_str]
+                    for stop in stop_str:
+                        pos = output.rfind(stop)
+                        if pos != -1:
+                            output = output[:pos]
+                            output_mapping[r] = output
+                            stopped = True
+                            finish_reason = "stop"
+                            break
+
+                r.stopped = stopped
+                r.finish_reason = finish_reason
+
+            if r.stopped and r not in stop_token_mapping and r not in output_mapping:
+                stop_token_mapping[r] = _i + 1
+
+            if r.stream:
+                """
+                Note that you can't just decode based on the newest r.new_tokens here,
+                which may destroy the integrity of the parsed characters,
+                and at the same time is not good at handling some special characters.
+                So the implementation here is to decode all the tokens that have been generated each time,
+                and then take the slice.
+                """
+                if r.stopped or len(r.new_tokens) % stream_interval == 0:
+                    if output is None:
+                        output = tokenizer.decode(
+                            r.new_tokens,
+                            skip_special_tokens=True,
+                            spaces_between_special_tokens=False,
+                            clean_up_tokenization_spaces=True,
+                        )
+
+                    if r.last_output_length == 0:
+                        r.completion.append(bos_flag)
+
+                    # this special character is mainly for qwen
+                    output = output.strip("�")
+                    output = output[r.last_output_length :]
+                    r.last_output_length += len(output)
+
+                    completion_chunk = _get_completion_chunk(
+                        output, r.finish_reason, model_uid, r, False
+                    )
+                    r.completion.append(completion_chunk)
+                    if r.stopped:
+                        r.completion.append(eos_flag)
+
+                # last round, handle stream result
+                # append usage information when enable `include_usage` for OPENAI API compatibility
+                # The reason for counting the usage in the last round of the iteration is that,
+                # these tokens are real generated and should be counted.
+                if r.stopped and _i == decode_round - 1 and include_usage:
+                    r.completion.append(
+                        _get_completion_chunk(
+                            "", r.finish_reason, model_uid, r, True
+                        )
+                    )
+            else:
+                # last round, handle non-stream result
+                if r.stopped and _i == decode_round - 1:
+                    invalid_token_num = decode_round - stop_token_mapping[r]
+                    outputs = (
+                        tokenizer.decode(
+                            r.new_tokens[: -(invalid_token_num + 1)]
+                            if r.finish_reason == "stop"
+                            else r.new_tokens[:-invalid_token_num],
+                            skip_special_tokens=True,
+                            spaces_between_special_tokens=False,
+                            clean_up_tokenization_spaces=True,
+                        )
+                        if r not in output_mapping
+                        else output_mapping[r]
+                    )
+                    completion = _get_completion(outputs, r.finish_reason, model_uid, r)
+                    r.completion = [completion]
+
+    e_time = time.time()
+    logger.debug(
+        f"Average throughput for a step: {(len(valid_req_list) * decode_round + len(prompts)) / (e_time - s_time)} token/s."
+    )
+
+
+def batch_inference_one_step(
+    req_list: List[InferenceRequest],
+    model_uid,
+    model,
+    tokenizer,
+    device,
+    context_len: int,
+):
+    from ....core.model import OutOfMemoryError
+
+    try:
+        _batch_inference_one_step_internal(
+            req_list, model_uid, model, tokenizer, device, context_len
+        )
+    except OutOfMemoryError:
+        logger.exception(
+            f"Batch inference out of memory. "
+            f"Xinference will restart the model: {model_uid}. "
+            f"Please be patient for a few moments."
+        )
+        # Just kill the process and let xinference auto-recover the model
+        os._exit(1)
+    except Exception as e:
+        logger.exception(f"Internal error for batch inference: {e}.")
+        # TODO: handle this
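The prefill path above batches prompts of unequal length by left-padding the token id lists to a common length before building the batched input tensor, so the most recent tokens stay right-aligned for decoding. A minimal standalone sketch of that padding step, mirroring `_pad_to_max_length` and `_pad_seqs_inplace` from the hunk; the pad id `0` and the example token ids are illustrative only:

```python
from typing import List


def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
    # Left-pad so the newest tokens stay right-aligned, matching the diff's helper.
    assert len(x) <= max_len
    return [pad] * (max_len - len(x)) + x


def pad_seqs_inplace(seqs: List[List[int]], pad: int) -> None:
    # Pad every sequence in the batch to the length of the longest one.
    max_len = max(len(seq) for seq in seqs)
    for i in range(len(seqs)):
        seqs[i] = pad_to_max_length(seqs[i], max_len, pad)


# Illustrative token ids only; real ids come from the tokenizer.
batch = [[101, 7592, 2088, 102], [101, 102]]
pad_seqs_inplace(batch, pad=0)
print(batch)  # [[101, 7592, 2088, 102], [0, 0, 101, 102]]
```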
xinference/model/llm/vllm/core.py CHANGED

@@ -93,6 +93,7 @@ VLLM_SUPPORTED_MODELS = [
     "baichuan",
     "internlm-16k",
     "mistral-v0.1",
+    "codestral-v0.1",
     "Yi",
     "Yi-1.5",
     "code-llama",
@@ -118,11 +119,14 @@ VLLM_SUPPORTED_CHAT_MODELS = [
     "code-llama-instruct",
     "mistral-instruct-v0.1",
     "mistral-instruct-v0.2",
+    "mistral-instruct-v0.3",
     "mixtral-instruct-v0.1",
     "mixtral-8x22B-instruct-v0.1",
     "chatglm3",
     "chatglm3-32k",
     "chatglm3-128k",
+    "glm4-chat",
+    "glm4-chat-1m",
     "deepseek-chat",
     "deepseek-coder-instruct",
 ]
@@ -130,6 +134,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat")
     VLLM_SUPPORTED_MODELS.append("codeqwen1.5")
     VLLM_SUPPORTED_CHAT_MODELS.append("codeqwen1.5-chat")
+    VLLM_SUPPORTED_CHAT_MODELS.append("qwen2-instruct")
 
 if VLLM_INSTALLED and vllm.__version__ >= "0.3.2":
     VLLM_SUPPORTED_CHAT_MODELS.append("gemma-it")
@@ -140,6 +145,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.3":
 
 if VLLM_INSTALLED and vllm.__version__ >= "0.4.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-moe-chat")
+    VLLM_SUPPORTED_CHAT_MODELS.append("qwen2-moe-instruct")
     VLLM_SUPPORTED_CHAT_MODELS.append("c4ai-command-r-v01")
 
 
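Because the additions above are gated on the installed vllm version, whether a given model name actually lands in the registry is a runtime property. A small sketch for checking it; the import path is taken from the diff'd module (xinference/model/llm/vllm/core.py) and is not a documented public API:

```python
# Hypothetical check script; assumes xinference 0.12.0 is installed.
from xinference.model.llm.vllm.core import (
    VLLM_INSTALLED,
    VLLM_SUPPORTED_CHAT_MODELS,
    VLLM_SUPPORTED_MODELS,
)

print("vLLM installed:", VLLM_INSTALLED)
for name in ("codestral-v0.1", "glm4-chat", "glm4-chat-1m", "qwen2-instruct", "qwen2-moe-instruct"):
    supported = name in VLLM_SUPPORTED_MODELS or name in VLLM_SUPPORTED_CHAT_MODELS
    # Names missing here are typically gated behind a newer vllm release.
    print(f"{name}: {'registered for the vLLM backend' if supported else 'not registered'}")
```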
xinference/thirdparty/ChatTTS/__init__.py ADDED

@@ -0,0 +1 @@
+from .core import Chat
xinference/thirdparty/ChatTTS/core.py ADDED

@@ -0,0 +1,200 @@
+
+import os
+import logging
+from functools import partial
+from omegaconf import OmegaConf
+
+import torch
+from vocos import Vocos
+from .model.dvae import DVAE
+from .model.gpt import GPT_warpper
+from .utils.gpu_utils import select_device
+from .utils.infer_utils import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map
+from .utils.io_utils import get_latest_modified_file
+from .infer.api import refine_text, infer_code
+
+from huggingface_hub import snapshot_download
+
+logging.basicConfig(level = logging.INFO)
+
+
+class Chat:
+    def __init__(self, ):
+        self.pretrain_models = {}
+        self.normalizer = {}
+        self.logger = logging.getLogger(__name__)
+
+    def check_model(self, level = logging.INFO, use_decoder = False):
+        not_finish = False
+        check_list = ['vocos', 'gpt', 'tokenizer']
+
+        if use_decoder:
+            check_list.append('decoder')
+        else:
+            check_list.append('dvae')
+
+        for module in check_list:
+            if module not in self.pretrain_models:
+                self.logger.log(logging.WARNING, f'{module} not initialized.')
+                not_finish = True
+
+        if not not_finish:
+            self.logger.log(level, f'All initialized.')
+
+        return not not_finish
+
+    def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>', **kwargs):
+        if source == 'huggingface':
+            hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
+            try:
+                download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
+            except:
+                download_path = None
+            if download_path is None or force_redownload:
+                self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
+                download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
+            else:
+                self.logger.log(logging.INFO, f'Load from cache: {download_path}')
+        elif source == 'local':
+            self.logger.log(logging.INFO, f'Load from local: {local_path}')
+            download_path = local_path
+
+        self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs)
+
+    def _load(
+        self,
+        vocos_config_path: str = None,
+        vocos_ckpt_path: str = None,
+        dvae_config_path: str = None,
+        dvae_ckpt_path: str = None,
+        gpt_config_path: str = None,
+        gpt_ckpt_path: str = None,
+        decoder_config_path: str = None,
+        decoder_ckpt_path: str = None,
+        tokenizer_path: str = None,
+        device: str = None,
+        compile: bool = True,
+    ):
+        if not device:
+            device = select_device(4096)
+            self.logger.log(logging.INFO, f'use {device}')
+
+        if vocos_config_path:
+            vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
+            assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
+            vocos.load_state_dict(torch.load(vocos_ckpt_path))
+            self.pretrain_models['vocos'] = vocos
+            self.logger.log(logging.INFO, 'vocos loaded.')
+
+        if dvae_config_path:
+            cfg = OmegaConf.load(dvae_config_path)
+            dvae = DVAE(**cfg).to(device).eval()
+            assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
+            dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
+            self.pretrain_models['dvae'] = dvae
+            self.logger.log(logging.INFO, 'dvae loaded.')
+
+        if gpt_config_path:
+            cfg = OmegaConf.load(gpt_config_path)
+            gpt = GPT_warpper(**cfg).to(device).eval()
+            assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
+            gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
+            if compile and 'cuda' in str(device):
+                gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
+            self.pretrain_models['gpt'] = gpt
+            spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
+            assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
+            self.pretrain_models['spk_stat'] = torch.load(spk_stat_path).to(device)
+            self.logger.log(logging.INFO, 'gpt loaded.')
+
+        if decoder_config_path:
+            cfg = OmegaConf.load(decoder_config_path)
+            decoder = DVAE(**cfg).to(device).eval()
+            assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
+            decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
+            self.pretrain_models['decoder'] = decoder
+            self.logger.log(logging.INFO, 'decoder loaded.')
+
+        if tokenizer_path:
+            tokenizer = torch.load(tokenizer_path, map_location='cpu')
+            tokenizer.padding_side = 'left'
+            self.pretrain_models['tokenizer'] = tokenizer
+            self.logger.log(logging.INFO, 'tokenizer loaded.')
+
+        self.check_model()
+
+    def infer(
+        self,
+        text,
+        skip_refine_text=False,
+        refine_text_only=False,
+        params_refine_text={},
+        params_infer_code={'prompt':'[speed_5]'},
+        use_decoder=True,
+        do_text_normalization=True,
+        lang=None,
+    ):
+
+        assert self.check_model(use_decoder=use_decoder)
+
+        if not isinstance(text, list):
+            text = [text]
+
+        if do_text_normalization:
+            for i, t in enumerate(text):
+                _lang = detect_language(t) if lang is None else lang
+                self.init_normalizer(_lang)
+                text[i] = self.normalizer[_lang](t)
+                if _lang == 'zh':
+                    text[i] = apply_half2full_map(text[i])
+
+        for i, t in enumerate(text):
+            invalid_characters = count_invalid_characters(t)
+            if len(invalid_characters):
+                self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
+                text[i] = apply_character_map(t)
+
+        if not skip_refine_text:
+            text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
+            text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
+            text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
+            if refine_text_only:
+                return text
+
+        text = [params_infer_code.get('prompt', '') + i for i in text]
+        params_infer_code.pop('prompt', '')
+        result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
+
+        if use_decoder:
+            mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
+        else:
+            mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
+
+        wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
+
+        return wav
+
+    def sample_random_speaker(self, ):
+
+        dim = self.pretrain_models['gpt'].gpt.layers[0].mlp.gate_proj.in_features
+        std, mean = self.pretrain_models['spk_stat'].chunk(2)
+        return torch.randn(dim, device=std.device) * std + mean
+
+    def init_normalizer(self, lang):
+
+        if lang not in self.normalizer:
+            if lang == 'zh':
+                try:
+                    from tn.chinese.normalizer import Normalizer
+                except:
+                    self.logger.log(logging.WARNING, f'Package WeTextProcessing not found! \
+                        Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing')
+                self.normalizer[lang] = Normalizer().normalize
+            else:
+                try:
+                    from nemo_text_processing.text_normalization.normalize import Normalizer
+                except:
+                    self.logger.log(logging.WARNING, f'Package nemo_text_processing not found! \
+                        Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing')
+                self.normalizer[lang] = partial(Normalizer(input_case='cased', lang=lang).normalize, verbose=False, punct_post_process=True)
+
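For reference, a minimal usage sketch of the vendored `Chat` wrapper above; xinference itself drives it through the new audio model in xinference/model/audio/chattts.py. The soundfile dependency, the 24 kHz sample rate, and the (1, num_samples) output shape are assumptions based on upstream ChatTTS, not on this diff:

```python
import soundfile as sf  # assumed available for writing the waveform

from xinference.thirdparty.ChatTTS import Chat

chat = Chat()
chat.load_models()  # fetches the 2Noise/ChatTTS checkpoints from Hugging Face on first use

# Skip the optional text-normalization dependencies (WeTextProcessing / nemo_text_processing).
wavs = chat.infer(
    ["Hello, this is a quick ChatTTS smoke test."],
    do_text_normalization=False,
)

# Each element is a numpy array; assumed shape (1, num_samples) at 24 kHz.
sf.write("chattts_demo.wav", wavs[0][0], 24000)
```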
xinference/types.py CHANGED

@@ -284,6 +284,7 @@ class PytorchGenerateConfig(TypedDict, total=False):
     tools: Optional[List[Dict]]
     lora_name: Optional[str]
     stream_options: Optional[Union[dict, None]]
+    request_id: Optional[str]
 
 
 class PytorchModelConfig(TypedDict, total=False):
@@ -297,6 +298,7 @@ class PytorchModelConfig(TypedDict, total=False):
     gptq_groupsize: int
     gptq_act_order: bool
     trust_remote_code: bool
+    max_num_seqs: int
 
 
 def get_pydantic_model_from_method(
@@ -361,6 +363,7 @@ class CreateCompletionTorch(BaseModel):
     top_p: float = top_p_field
     top_k: int = top_k_field
     lora_name: Optional[str]
+    request_id: Optional[str]
 
 
 CreateCompletionLlamaCpp: BaseModel
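These new fields line up with the batching and request-tracking work elsewhere in this release (core/scheduler.py and the REST additions). A sketch of how they appear in the typed configs; treating `max_num_seqs` as a cap on concurrently batched requests and `request_id` as a caller-supplied identifier is an assumption about their semantics:

```python
from xinference.types import PytorchGenerateConfig, PytorchModelConfig

# Both TypedDicts are declared with total=False, so partial dicts are valid.
model_config: PytorchModelConfig = {
    "trust_remote_code": True,
    "max_num_seqs": 16,  # new in 0.12.0; assumed: upper bound on batched requests
}

generate_config: PytorchGenerateConfig = {
    "max_tokens": 256,
    "temperature": 0.7,
    "stream": True,
    "request_id": "req-123e4567",  # new in 0.12.0; placeholder id chosen by the caller
}
```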
|