xinference 0.11.3__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.

Potentially problematic release: this version of xinference might be problematic.

@@ -14,13 +14,15 @@
 
 import gc
 import logging
+import os
 import time
 import uuid
 from threading import Thread
-from typing import Iterable, Iterator, Tuple
+from typing import Dict, Iterable, Iterator, List, Optional, Tuple
 
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
+from transformers.cache_utils import DynamicCache
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
@@ -29,8 +31,10 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )
 
+from ....core.scheduler import InferenceRequest
 from ....device_utils import empty_cache
 from ....types import (
+    Completion,
     CompletionChoice,
     CompletionChunk,
     CompletionUsage,
@@ -54,7 +58,7 @@ def is_partial_stop(output: str, stop_str: str):
     return False
 
 
-def get_context_length(config):
+def get_context_length(config) -> int:
     """Get the context length of a model from a huggingface model config."""
     if (
         hasattr(config, "max_sequence_length")
@@ -528,3 +532,383 @@ def generate_stream_falcon(
     # clean
     gc.collect()
     empty_cache()
+
+
+def _get_token_from_logits(
+    req: InferenceRequest, i: int, logits, temperature, repetition_penalty, top_p, top_k
+):
+    logits_processor = prepare_logits_processor(
+        temperature, repetition_penalty, top_p, top_k
+    )
+
+    if logits_processor:
+        if repetition_penalty > 1.0:
+            tmp_output_ids = torch.as_tensor(
+                [req.prompt_tokens + req.new_tokens], device=logits.device
+            )
+        else:
+            tmp_output_ids = None
+        last_token_logits = logits_processor(tmp_output_ids, logits[i : i + 1, -1, :])[
+            0
+        ]
+    else:
+        last_token_logits = logits[i : i + 1, -1, :]
+
+    if temperature < 1e-5 or top_p < 1e-8:  # greedy
+        _, indices = torch.topk(last_token_logits, 2)
+    else:
+        probs = torch.softmax(last_token_logits, dim=-1)
+        indices = torch.multinomial(probs, num_samples=2)
+    token = indices[0].int().item()
+    return token
+
+
+def _pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
+    assert len(x) <= max_len
+    return [pad] * (max_len - len(x)) + x
+
+
+def _pad_seqs_inplace(seqs: List[List[int]], pad: int):
+    max_len = max(len(seq) for seq in seqs)
+    n = len(seqs)
+    i = 0
+    while i < n:
+        seqs[i] = _pad_to_max_length(seqs[i], max_len, pad)
+        i += 1
+
+
+def get_max_src_len(context_len: int, r: InferenceRequest) -> int:
+    max_new_tokens = int(
+        r.sanitized_generate_config.get("max_tokens", max_tokens_field.default)
+    )
+    return context_len - max_new_tokens - 8
+
+
+def _get_completion_chunk(
+    output: str,
+    finish_reason: Optional[str],
+    model_uid: str,
+    r: InferenceRequest,
+    just_usage: bool,
+):
+    completion_choice = (
+        [
+            CompletionChoice(
+                text=output, index=0, logprobs=None, finish_reason=finish_reason
+            )
+        ]
+        if not just_usage
+        else []
+    )
+    completion_chunk = CompletionChunk(
+        id=str(uuid.uuid1()),
+        object="text_completion",
+        created=int(time.time()),
+        model=model_uid,
+        choices=completion_choice,
+    )
+    completion_usage = CompletionUsage(
+        prompt_tokens=len(r.prompt_tokens),
+        completion_tokens=len(r.new_tokens),
+        total_tokens=len(r.prompt_tokens) + len(r.new_tokens),
+    )
+    completion_chunk["usage"] = completion_usage
+    return completion_chunk
+
+
+def _get_completion(
+    output: str, finish_reason: Optional[str], model_uid: str, r: InferenceRequest
+):
+    completion_choice = CompletionChoice(
+        text=output, index=0, logprobs=None, finish_reason=finish_reason
+    )
+
+    completion_chunk = CompletionChunk(
+        id=str(uuid.uuid1()),
+        object="text_completion",
+        created=int(time.time()),
+        model=model_uid,
+        choices=[completion_choice],
+    )
+    completion_usage = CompletionUsage(
+        prompt_tokens=len(r.prompt_tokens),
+        completion_tokens=len(r.new_tokens),
+        total_tokens=len(r.prompt_tokens) + len(r.new_tokens),
+    )
+    completion = Completion(
+        id=completion_chunk["id"],
+        object=completion_chunk["object"],
+        created=completion_chunk["created"],
+        model=completion_chunk["model"],
+        choices=completion_chunk["choices"],
+        usage=completion_usage,
+    )
+    return completion
+
+
+def _merge_kv_cache(
+    past_kv: Tuple[Tuple[torch.Tensor]], new_kv: Tuple[Tuple[torch.Tensor]]
+):
+    from torch.nn.functional import pad
+
+    past_cache = DynamicCache.from_legacy_cache(past_kv)
+    new_cache = DynamicCache.from_legacy_cache(new_kv)
+    past_seq_len = past_cache.get_seq_length()
+    new_seq_len = new_cache.get_seq_length()
+    if past_seq_len != new_seq_len:
+        padding_target = new_cache if past_seq_len > new_seq_len else past_cache
+        padding_len = abs(past_seq_len - new_seq_len)
+        for idx in range(len(padding_target)):
+            k = padding_target.key_cache[idx]
+            v = padding_target.value_cache[idx]
+            _k = pad(k, (0, 0, padding_len, 0))
+            _v = pad(v, (0, 0, padding_len, 0))
+            padding_target.key_cache[idx] = _k
+            padding_target.value_cache[idx] = _v
+
+    ret_kv = DynamicCache()
+    for idx in range(len(past_cache)):
+        k1, k2 = new_cache.key_cache[idx], past_cache.key_cache[idx]
+        v1, v2 = new_cache.value_cache[idx], past_cache.value_cache[idx]
+        ret_kv.update(torch.cat((k1, k2), 0), torch.cat((v1, v2), 0), idx)
+    return ret_kv.to_legacy_cache()
+
+
+@torch.inference_mode()
+def _batch_inference_one_step_internal(
+    req_list: List[InferenceRequest],
+    model_uid,
+    model,
+    tokenizer,
+    device,
+    context_len: int,
+    decode_round: int = 16,
+    bos_flag: str = "<bos_stream>",
+    eos_flag: str = "<eos_stream>",
+):
+    # need to judge stopped here,
+    # since some requests state may change to stopped due to invalid parameters, e.g. max_src_len
+    valid_req_list = [r for r in req_list if not r.stopped]
+    if not valid_req_list:
+        return
+    generate_config_mapping: Dict[InferenceRequest, Tuple] = {
+        r: r.get_generate_configs(tokenizer.eos_token_id) for r in valid_req_list
+    }
+    s_time = time.time()
+
+    prefill_reqs = []
+    prompts = []
+    decode_reqs = []
+    for r in valid_req_list:
+        if r.is_prefill:
+            prompts.append(r.full_prompt)
+            prefill_reqs.append(r)
+        else:
+            decode_reqs.append(r)
+
+    if prompts:  # prefill first
+        input_ids: List[List[int]] = tokenizer(prompts, padding=False).input_ids
+        prompt_tokens = []
+        for i, input_id in enumerate(input_ids):
+            req = valid_req_list[i]
+            max_src_len = get_max_src_len(context_len, req)
+            req.prompt_tokens = input_id[-max_src_len:]
+            prompt_tokens.append(req.prompt_tokens)
+        _pad_seqs_inplace(prompt_tokens, 0)
+        out = model(torch.as_tensor(prompt_tokens, device=device), use_cache=True)
+
+        logits = out.logits
+        past_key_values = out.past_key_values
+
+        for i, r in enumerate(prefill_reqs):
+            (
+                max_new_tokens,
+                stream_interval,
+                include_usage,
+                stop_str,
+                stop_token_ids,
+                temperature,
+                repetition_penalty,
+                top_p,
+                top_k,
+            ) = generate_config_mapping[r]
+
+            token = _get_token_from_logits(
+                r, i, logits, temperature, repetition_penalty, top_p, top_k
+            )
+            r.is_prefill = False
+            r.append_new_token(token)
+
+        if decode_reqs:
+            decode_kv = decode_reqs[0].kv_cache
+            # prefill and decode kv cache need to be merged at `batch_size` and `seq_len` dimensions.
+            merged_kv_cache = _merge_kv_cache(decode_kv, past_key_values)
+            for r in valid_req_list:
+                r.kv_cache = merged_kv_cache
+            empty_cache()
+        else:
+            for r in valid_req_list:
+                r.kv_cache = past_key_values
+
+    past_key_values = valid_req_list[0].kv_cache
+    stop_token_mapping: Dict[InferenceRequest, int] = {}
+    output_mapping: Dict[InferenceRequest, str] = {}
+    # here, only decode phase, just run some rounds
+    for _i in range(decode_round):
+        decode_tokens: List[List[int]] = [[r.new_tokens[-1]] for r in valid_req_list]
+        out = model(
+            input_ids=torch.as_tensor(decode_tokens, device=device),
+            use_cache=True,
+            past_key_values=past_key_values,
+        )
+        logits = out.logits
+        past_key_values = out.past_key_values
+
+        for i, r in enumerate(valid_req_list):
+            (
+                max_new_tokens,
+                stream_interval,
+                include_usage,
+                stop_str,
+                stop_token_ids,
+                temperature,
+                repetition_penalty,
+                top_p,
+                top_k,
+            ) = generate_config_mapping[r]
+
+            token = _get_token_from_logits(
+                r, i, logits, temperature, repetition_penalty, top_p, top_k
+            )
+            r.kv_cache = past_key_values
+            r.append_new_token(token)
+
+            output = None
+            if not r.stopped:
+                stopped = token in stop_token_ids
+
+                if stopped:
+                    finish_reason = "stop"
+                elif len(r.new_tokens) == max_new_tokens:
+                    finish_reason = "length"
+                    stopped = True
+                else:
+                    finish_reason = None
+
+                # handle stop str
+                if stop_str and r not in output_mapping:
+                    output = tokenizer.decode(
+                        r.new_tokens,
+                        skip_special_tokens=True,
+                        spaces_between_special_tokens=False,
+                        clean_up_tokenization_spaces=True,
+                    )
+                    if isinstance(stop_str, str):
+                        stop_str = [stop_str]
+                    for stop in stop_str:
+                        pos = output.rfind(stop)
+                        if pos != -1:
+                            output = output[:pos]
+                            output_mapping[r] = output
+                            stopped = True
+                            finish_reason = "stop"
+                            break
+
+                r.stopped = stopped
+                r.finish_reason = finish_reason
+
+            if r.stopped and r not in stop_token_mapping and r not in output_mapping:
+                stop_token_mapping[r] = _i + 1
+
+            if r.stream:
+                """
+                Note that you can't just decode based on the newest r.new_tokens here,
+                which may destroy the integrity of the parsed characters,
+                and at the same time is not good at handling some special characters.
+                So the implementation here is to decode all the tokens that have been generated each time,
+                and then take the slice.
+                """
+                if r.stopped or len(r.new_tokens) % stream_interval == 0:
+                    if output is None:
+                        output = tokenizer.decode(
+                            r.new_tokens,
+                            skip_special_tokens=True,
+                            spaces_between_special_tokens=False,
+                            clean_up_tokenization_spaces=True,
+                        )
+
+                    if r.last_output_length == 0:
+                        r.completion.append(bos_flag)
+
+                    # this special character is mainly for qwen
+                    output = output.strip("�")
+                    output = output[r.last_output_length :]
+                    r.last_output_length += len(output)
+
+                    completion_chunk = _get_completion_chunk(
+                        output, r.finish_reason, model_uid, r, False
+                    )
+                    r.completion.append(completion_chunk)
+                    if r.stopped:
+                        r.completion.append(eos_flag)
+
+                # last round, handle stream result
+                # append usage information when enable `include_usage` for OPENAI API compatibility
+                # The reason for counting the usage in the last round of the iteration is that,
+                # these tokens are real generated and should be counted.
+                if r.stopped and _i == decode_round - 1 and include_usage:
+                    r.completion.append(
+                        _get_completion_chunk(
+                            "", r.finish_reason, model_uid, r, True
+                        )
+                    )
+            else:
+                # last round, handle non-stream result
+                if r.stopped and _i == decode_round - 1:
+                    invalid_token_num = decode_round - stop_token_mapping[r]
+                    outputs = (
+                        tokenizer.decode(
+                            r.new_tokens[: -(invalid_token_num + 1)]
+                            if r.finish_reason == "stop"
+                            else r.new_tokens[:-invalid_token_num],
+                            skip_special_tokens=True,
+                            spaces_between_special_tokens=False,
+                            clean_up_tokenization_spaces=True,
+                        )
+                        if r not in output_mapping
+                        else output_mapping[r]
+                    )
+                    completion = _get_completion(outputs, r.finish_reason, model_uid, r)
+                    r.completion = [completion]
+
+    e_time = time.time()
+    logger.debug(
+        f"Average throughput for a step: {(len(valid_req_list) * decode_round + len(prompts)) / (e_time - s_time)} token/s."
+    )
+
+
+def batch_inference_one_step(
+    req_list: List[InferenceRequest],
+    model_uid,
+    model,
+    tokenizer,
+    device,
+    context_len: int,
+):
+    from ....core.model import OutOfMemoryError
+
+    try:
+        _batch_inference_one_step_internal(
+            req_list, model_uid, model, tokenizer, device, context_len
+        )
+    except OutOfMemoryError:
+        logger.exception(
+            f"Batch inference out of memory. "
+            f"Xinference will restart the model: {model_uid}. "
+            f"Please be patient for a few moments."
+        )
+        # Just kill the process and let xinference auto-recover the model
+        os._exit(1)
+    except Exception as e:
+        logger.exception(f"Internal error for batch inference: {e}.")
+        # TODO: handle this
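Note on the hunk above: `_merge_kv_cache` aligns two legacy KV caches by left-padding the shorter one along the sequence dimension, then concatenates them layer by layer along the batch dimension. A minimal, self-contained sketch of that idea (not part of the package; the single layer and tiny tensor shapes below are made up, and it assumes a transformers release that ships `DynamicCache`):

import torch
from torch.nn.functional import pad
from transformers.cache_utils import DynamicCache

batch, heads, head_dim = 1, 2, 4
decode_kv = ((torch.zeros(batch, heads, 5, head_dim),
              torch.zeros(batch, heads, 5, head_dim)),)   # 1 layer, seq_len 5
prefill_kv = ((torch.ones(batch, heads, 3, head_dim),
               torch.ones(batch, heads, 3, head_dim)),)   # 1 layer, seq_len 3

past_cache = DynamicCache.from_legacy_cache(decode_kv)
new_cache = DynamicCache.from_legacy_cache(prefill_kv)

# left-pad the shorter cache along the seq_len dimension so both align
pad_len = past_cache.get_seq_length() - new_cache.get_seq_length()      # 2
new_cache.key_cache[0] = pad(new_cache.key_cache[0], (0, 0, pad_len, 0))
new_cache.value_cache[0] = pad(new_cache.value_cache[0], (0, 0, pad_len, 0))

# concatenate along the batch dimension, layer by layer
merged = DynamicCache()
merged.update(
    torch.cat((new_cache.key_cache[0], past_cache.key_cache[0]), 0),
    torch.cat((new_cache.value_cache[0], past_cache.value_cache[0]), 0),
    0,
)
print(merged.to_legacy_cache()[0][0].shape)   # torch.Size([2, 2, 5, 4])

Left-padding keeps the most recent positions aligned at the end of the sequence axis for every request, so the subsequent decode rounds can append one token per request against a single merged cache.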
@@ -93,6 +93,7 @@ VLLM_SUPPORTED_MODELS = [
     "baichuan",
     "internlm-16k",
     "mistral-v0.1",
+    "codestral-v0.1",
     "Yi",
     "Yi-1.5",
     "code-llama",
@@ -118,11 +119,14 @@ VLLM_SUPPORTED_CHAT_MODELS = [
     "code-llama-instruct",
     "mistral-instruct-v0.1",
     "mistral-instruct-v0.2",
+    "mistral-instruct-v0.3",
     "mixtral-instruct-v0.1",
     "mixtral-8x22B-instruct-v0.1",
     "chatglm3",
     "chatglm3-32k",
     "chatglm3-128k",
+    "glm4-chat",
+    "glm4-chat-1m",
     "deepseek-chat",
     "deepseek-coder-instruct",
 ]
@@ -130,6 +134,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat")
     VLLM_SUPPORTED_MODELS.append("codeqwen1.5")
     VLLM_SUPPORTED_CHAT_MODELS.append("codeqwen1.5-chat")
+    VLLM_SUPPORTED_CHAT_MODELS.append("qwen2-instruct")
 
 if VLLM_INSTALLED and vllm.__version__ >= "0.3.2":
     VLLM_SUPPORTED_CHAT_MODELS.append("gemma-it")
@@ -140,6 +145,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.3":
 
 if VLLM_INSTALLED and vllm.__version__ >= "0.4.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-moe-chat")
+    VLLM_SUPPORTED_CHAT_MODELS.append("qwen2-moe-instruct")
     VLLM_SUPPORTED_CHAT_MODELS.append("c4ai-command-r-v01")
 
 
@@ -0,0 +1 @@
+from .core import Chat
@@ -0,0 +1,200 @@
+
+import os
+import logging
+from functools import partial
+from omegaconf import OmegaConf
+
+import torch
+from vocos import Vocos
+from .model.dvae import DVAE
+from .model.gpt import GPT_warpper
+from .utils.gpu_utils import select_device
+from .utils.infer_utils import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map
+from .utils.io_utils import get_latest_modified_file
+from .infer.api import refine_text, infer_code
+
+from huggingface_hub import snapshot_download
+
+logging.basicConfig(level = logging.INFO)
+
+
+class Chat:
+    def __init__(self, ):
+        self.pretrain_models = {}
+        self.normalizer = {}
+        self.logger = logging.getLogger(__name__)
+
+    def check_model(self, level = logging.INFO, use_decoder = False):
+        not_finish = False
+        check_list = ['vocos', 'gpt', 'tokenizer']
+
+        if use_decoder:
+            check_list.append('decoder')
+        else:
+            check_list.append('dvae')
+
+        for module in check_list:
+            if module not in self.pretrain_models:
+                self.logger.log(logging.WARNING, f'{module} not initialized.')
+                not_finish = True
+
+        if not not_finish:
+            self.logger.log(level, f'All initialized.')
+
+        return not not_finish
+
+    def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>', **kwargs):
+        if source == 'huggingface':
+            hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
+            try:
+                download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
+            except:
+                download_path = None
+            if download_path is None or force_redownload:
+                self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
+                download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
+            else:
+                self.logger.log(logging.INFO, f'Load from cache: {download_path}')
+        elif source == 'local':
+            self.logger.log(logging.INFO, f'Load from local: {local_path}')
+            download_path = local_path
+
+        self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs)
+
+    def _load(
+        self,
+        vocos_config_path: str = None,
+        vocos_ckpt_path: str = None,
+        dvae_config_path: str = None,
+        dvae_ckpt_path: str = None,
+        gpt_config_path: str = None,
+        gpt_ckpt_path: str = None,
+        decoder_config_path: str = None,
+        decoder_ckpt_path: str = None,
+        tokenizer_path: str = None,
+        device: str = None,
+        compile: bool = True,
+    ):
+        if not device:
+            device = select_device(4096)
+            self.logger.log(logging.INFO, f'use {device}')
+
+        if vocos_config_path:
+            vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
+            assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
+            vocos.load_state_dict(torch.load(vocos_ckpt_path))
+            self.pretrain_models['vocos'] = vocos
+            self.logger.log(logging.INFO, 'vocos loaded.')
+
+        if dvae_config_path:
+            cfg = OmegaConf.load(dvae_config_path)
+            dvae = DVAE(**cfg).to(device).eval()
+            assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
+            dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
+            self.pretrain_models['dvae'] = dvae
+            self.logger.log(logging.INFO, 'dvae loaded.')
+
+        if gpt_config_path:
+            cfg = OmegaConf.load(gpt_config_path)
+            gpt = GPT_warpper(**cfg).to(device).eval()
+            assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
+            gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
+            if compile and 'cuda' in str(device):
+                gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
+            self.pretrain_models['gpt'] = gpt
+            spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
+            assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
+            self.pretrain_models['spk_stat'] = torch.load(spk_stat_path).to(device)
+            self.logger.log(logging.INFO, 'gpt loaded.')
+
+        if decoder_config_path:
+            cfg = OmegaConf.load(decoder_config_path)
+            decoder = DVAE(**cfg).to(device).eval()
+            assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
+            decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
+            self.pretrain_models['decoder'] = decoder
+            self.logger.log(logging.INFO, 'decoder loaded.')
+
+        if tokenizer_path:
+            tokenizer = torch.load(tokenizer_path, map_location='cpu')
+            tokenizer.padding_side = 'left'
+            self.pretrain_models['tokenizer'] = tokenizer
+            self.logger.log(logging.INFO, 'tokenizer loaded.')
+
+        self.check_model()
+
+    def infer(
+        self,
+        text,
+        skip_refine_text=False,
+        refine_text_only=False,
+        params_refine_text={},
+        params_infer_code={'prompt':'[speed_5]'},
+        use_decoder=True,
+        do_text_normalization=True,
+        lang=None,
+    ):
+
+        assert self.check_model(use_decoder=use_decoder)
+
+        if not isinstance(text, list):
+            text = [text]
+
+        if do_text_normalization:
+            for i, t in enumerate(text):
+                _lang = detect_language(t) if lang is None else lang
+                self.init_normalizer(_lang)
+                text[i] = self.normalizer[_lang](t)
+                if _lang == 'zh':
+                    text[i] = apply_half2full_map(text[i])
+
+        for i, t in enumerate(text):
+            invalid_characters = count_invalid_characters(t)
+            if len(invalid_characters):
+                self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
+                text[i] = apply_character_map(t)
+
+        if not skip_refine_text:
+            text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
+            text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
+            text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
+            if refine_text_only:
+                return text
+
+        text = [params_infer_code.get('prompt', '') + i for i in text]
+        params_infer_code.pop('prompt', '')
+        result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
+
+        if use_decoder:
+            mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
+        else:
+            mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
+
+        wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
+
+        return wav
+
+    def sample_random_speaker(self, ):
+
+        dim = self.pretrain_models['gpt'].gpt.layers[0].mlp.gate_proj.in_features
+        std, mean = self.pretrain_models['spk_stat'].chunk(2)
+        return torch.randn(dim, device=std.device) * std + mean
+
+    def init_normalizer(self, lang):
+
+        if lang not in self.normalizer:
+            if lang == 'zh':
+                try:
+                    from tn.chinese.normalizer import Normalizer
+                except:
+                    self.logger.log(logging.WARNING, f'Package WeTextProcessing not found! \
+                        Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing')
+                self.normalizer[lang] = Normalizer().normalize
+            else:
+                try:
+                    from nemo_text_processing.text_normalization.normalize import Normalizer
+                except:
+                    self.logger.log(logging.WARNING, f'Package nemo_text_processing not found! \
+                        Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing')
+                self.normalizer[lang] = partial(Normalizer(input_case='cased', lang=lang).normalize, verbose=False, punct_post_process=True)
+
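For orientation, a hypothetical usage sketch of the `Chat` class added in the new file above. The import path is an assumption based on the accompanying `+from .core import Chat` package init; weights are fetched from https://huggingface.co/2Noise/ChatTTS on first use:

# Hypothetical usage of the vendored ChatTTS Chat class; the top-level package name is assumed.
from ChatTTS import Chat

chat = Chat()
chat.load_models(source='huggingface')   # or source='local', local_path='<LOCAL_PATH>'
wavs = chat.infer(["Hello from ChatTTS."], use_decoder=True)
# `wavs` is a list of numpy arrays produced by the vocos decoder, one per input string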
xinference/types.py CHANGED
@@ -284,6 +284,7 @@ class PytorchGenerateConfig(TypedDict, total=False):
     tools: Optional[List[Dict]]
     lora_name: Optional[str]
     stream_options: Optional[Union[dict, None]]
+    request_id: Optional[str]
 
 
 class PytorchModelConfig(TypedDict, total=False):
@@ -297,6 +298,7 @@ class PytorchModelConfig(TypedDict, total=False):
     gptq_groupsize: int
     gptq_act_order: bool
     trust_remote_code: bool
+    max_num_seqs: int
 
 
 def get_pydantic_model_from_method(
@@ -361,6 +363,7 @@ class CreateCompletionTorch(BaseModel):
     top_p: float = top_p_field
     top_k: int = top_k_field
     lora_name: Optional[str]
+    request_id: Optional[str]
 
 
 CreateCompletionLlamaCpp: BaseModel
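The xinference/types.py changes above add `request_id` to PytorchGenerateConfig and CreateCompletionTorch, and `max_num_seqs` to PytorchModelConfig. A hedged sketch of how the new fields could be populated (plain TypedDicts from xinference.types; the values are illustrative only):

from xinference.types import PytorchGenerateConfig, PytorchModelConfig

# Both TypedDicts are declared with total=False, so partial dicts are valid.
gen_cfg: PytorchGenerateConfig = {
    "max_tokens": 128,
    "stream": True,
    "request_id": "req-0001",   # new in 0.12.0 (per the diff above)
}
model_cfg: PytorchModelConfig = {
    "trust_remote_code": True,
    "max_num_seqs": 16,         # new in 0.12.0 (per the diff above)
}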