xinference 0.11.2.post1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Files changed (36)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +83 -8
  3. xinference/client/restful/restful_client.py +70 -0
  4. xinference/constants.py +8 -0
  5. xinference/core/__init__.py +0 -2
  6. xinference/core/cache_tracker.py +22 -1
  7. xinference/core/chat_interface.py +71 -10
  8. xinference/core/model.py +141 -12
  9. xinference/core/scheduler.py +428 -0
  10. xinference/core/supervisor.py +31 -3
  11. xinference/core/worker.py +8 -3
  12. xinference/isolation.py +9 -2
  13. xinference/model/audio/chattts.py +84 -0
  14. xinference/model/audio/core.py +10 -3
  15. xinference/model/audio/model_spec.json +20 -0
  16. xinference/model/llm/__init__.py +6 -0
  17. xinference/model/llm/llm_family.json +1063 -260
  18. xinference/model/llm/llm_family_modelscope.json +686 -13
  19. xinference/model/llm/pytorch/baichuan.py +2 -1
  20. xinference/model/llm/pytorch/chatglm.py +2 -1
  21. xinference/model/llm/pytorch/cogvlm2.py +316 -0
  22. xinference/model/llm/pytorch/core.py +92 -6
  23. xinference/model/llm/pytorch/glm4v.py +258 -0
  24. xinference/model/llm/pytorch/intern_vl.py +5 -10
  25. xinference/model/llm/pytorch/minicpmv25.py +232 -0
  26. xinference/model/llm/pytorch/utils.py +386 -2
  27. xinference/model/llm/vllm/core.py +7 -1
  28. xinference/thirdparty/ChatTTS/__init__.py +1 -0
  29. xinference/thirdparty/ChatTTS/core.py +200 -0
  30. xinference/types.py +3 -0
  31. {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/METADATA +28 -11
  32. {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/RECORD +36 -29
  33. {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/LICENSE +0 -0
  34. {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/WHEEL +0 -0
  35. {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/entry_points.txt +0 -0
  36. {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/baichuan.py

@@ -73,7 +73,8 @@ class BaichuanPytorchChatModel(PytorchChatModel):
     ) -> bool:
         if llm_spec.model_format != "pytorch":
             return False
-        if llm_family.model_name not in ["baichuan-chat", "baichuan-2-chat"]:
+        model_family = llm_family.model_family or llm_family.model_name
+        if model_family not in ["baichuan-chat", "baichuan-2-chat"]:
             return False
         if "chat" not in llm_family.model_ability:
             return False
xinference/model/llm/pytorch/chatglm.py

@@ -82,7 +82,8 @@ class ChatglmPytorchChatModel(PytorchChatModel):
     ) -> bool:
         if llm_spec.model_format != "pytorch":
             return False
-        if "chatglm" not in llm_family.model_name:
+        model_family = llm_family.model_family or llm_family.model_name
+        if "chatglm" not in model_family and "glm4" not in model_family:
             return False
         if "chat" not in llm_family.model_ability:
             return False
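
Both hunks above replace the direct llm_family.model_name check with a model_family-or-model_name fallback, so a custom model registered under an existing family is matched by the same PyTorch chat class as the built-in one. A minimal sketch of that fallback, using SimpleNamespace stand-ins rather than real LLMFamilyV1 objects:

from types import SimpleNamespace

# Illustrative stand-ins only; the real LLMFamilyV1 carries many more fields.
builtin = SimpleNamespace(model_family=None, model_name="baichuan-2-chat")
custom = SimpleNamespace(model_family="baichuan-2-chat", model_name="my-finetuned-baichuan")

def resolved_family(llm_family):
    # Same fallback the patched match() methods use.
    return llm_family.model_family or llm_family.model_name

assert resolved_family(builtin) == "baichuan-2-chat"
assert resolved_family(custom) == "baichuan-2-chat"  # custom registration now matches too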
xinference/model/llm/pytorch/cogvlm2.py (new file)

@@ -0,0 +1,316 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import logging
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from io import BytesIO
+from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+import requests
+import torch
+from PIL import Image
+
+from ....model.utils import select_device
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    Completion,
+    CompletionChoice,
+    CompletionChunk,
+    CompletionUsage,
+)
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from .core import PytorchChatModel, PytorchGenerateConfig
+
+logger = logging.getLogger(__name__)
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+class CogVLM2Model(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._torch_type = None
+        self._device = None
+        self._tokenizer = None
+        self._model = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        family = model_family.model_family or model_family.model_name
+        if "cogvlm" in family.lower():
+            return True
+        return False
+
+    def load(self, **kwargs):
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        from transformers.generation import GenerationConfig
+
+        device = self._pytorch_model_config.get("device", "auto")
+        self._device = select_device(device)
+        self._torch_type = (
+            torch.bfloat16
+            if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
+            else torch.float16
+        )
+
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+        )
+
+        self._model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            torch_dtype=self._torch_type,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True,
+            device_map="auto",
+        ).eval()
+
+        # Specify hyperparameters for generation
+        self._model.generation_config = GenerationConfig.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+        )
+
+    def _message_content_to_cogvlm2(self, content):
+        def _load_image(_url):
+            if _url.startswith("data:"):
+                logging.info("Parse url by base64 decoder.")
+                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                # e.g. f"data:image/jpeg;base64,{base64_image}"
+                _type, data = _url.split(";")
+                _, ext = _type.split("/")
+                data = data[len("base64,") :]
+                data = base64.b64decode(data.encode("utf-8"))
+                return Image.open(BytesIO(data)).convert("RGB")
+            else:
+                try:
+                    response = requests.get(_url)
+                except requests.exceptions.MissingSchema:
+                    return Image.open(_url).convert("RGB")
+                else:
+                    return Image.open(BytesIO(response.content)).convert("RGB")
+
+        if not isinstance(content, str):
+            texts = []
+            image_urls = []
+            for c in content:
+                c_type = c.get("type")
+                if c_type == "text":
+                    texts.append(c["text"])
+                elif c_type == "image_url":
+                    image_urls.append(c["image_url"]["url"])
+            image_futures = []
+            with ThreadPoolExecutor() as executor:
+                for image_url in image_urls:
+                    fut = executor.submit(_load_image, image_url)
+                    image_futures.append(fut)
+            images = [fut.result() for fut in image_futures]
+            text = " ".join(texts)
+            if len(images) == 0:
+                return text, None
+            elif len(images) == 1:
+                return text, images
+            else:
+                raise RuntimeError(
+                    "Only one image per message is supported by CogVLM2."
+                )
+        return content, None
+
+    def _history_content_to_cogvlm2(
+        self, system_prompt: str, chat_history: List[ChatCompletionMessage]
+    ):
+        def _image_to_piexl_values(image):
+            if image.startswith("data:"):
+                logging.info("Parse url by base64 decoder.")
+                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                # e.g. f"data:image/jpeg;base64,{base64_image}"
+                _type, data = image.split(";")
+                _, ext = _type.split("/")
+                data = data[len("base64,") :]
+                data = base64.b64decode(data.encode("utf-8"))
+                return Image.open(BytesIO(data)).convert("RGB")
+            else:
+                try:
+                    response = requests.get(image)
+                except requests.exceptions.MissingSchema:
+                    return Image.open(image).convert("RGB")
+                else:
+                    return Image.open(BytesIO(response.content)).convert("RGB")
+
+        query = system_prompt
+        history: List[Tuple] = []
+        pixel_values = None
+        for i in range(0, len(chat_history), 2):
+            user = chat_history[i]["content"]
+            if isinstance(user, List):
+                for content in user:
+                    c_type = content.get("type")
+                    if c_type == "text":
+                        user = content["text"]
+                    elif c_type == "image_url" and not pixel_values:
+                        pixel_values = _image_to_piexl_values(
+                            content["image_url"]["url"]
+                        )
+            assistant = chat_history[i + 1]["content"]
+            query = query + f" USER: {user} ASSISTANT:"
+            history.append((query, assistant))
+            query = query + f" {assistant}"
+        return query, history, [pixel_values]
+
+    def chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        system_prompt = system_prompt if system_prompt else ""
+        stream = generate_config.get("stream", False) if generate_config else False
+
+        sanitized_config = {
+            "pad_token_id": 128002,
+            "max_new_tokens": generate_config.get("max_tokens", 512)
+            if generate_config
+            else 512,
+        }
+
+        content, image = self._message_content_to_cogvlm2(prompt)
+
+        history = []
+        query = ""
+        history_image = None
+        if chat_history:
+            query, history, history_image = self._history_content_to_cogvlm2(
+                system_prompt, chat_history
+            )
+
+        if image and history_image:
+            history = []
+            query = system_prompt + f" USER: {content} ASSISTANT:"
+        else:
+            image = image if image else history_image
+            query = query + f" USER: {content} ASSISTANT:"
+
+        input_by_model = self._model.build_conversation_input_ids(
+            self._tokenizer,
+            query=query,
+            history=history,
+            images=image,
+            template_version="chat",
+        )
+
+        inputs = {
+            "input_ids": input_by_model["input_ids"].unsqueeze(0).to(self._device),
+            "token_type_ids": input_by_model["token_type_ids"]
+            .unsqueeze(0)
+            .to(self._device),
+            "attention_mask": input_by_model["attention_mask"]
+            .unsqueeze(0)
+            .to(self._device),
+            "images": [
+                [input_by_model["images"][0].to(self._device).to(self._torch_type)]
+            ]
+            if image is not None
+            else None,
+        }
+
+        if stream:
+            it = self._streaming_chat_response(inputs, sanitized_config)
+            return self._to_chat_completion_chunks(it)
+        else:
+            with torch.no_grad():
+                outputs = self._model.generate(**inputs, **sanitized_config)
+                outputs = outputs[:, inputs["input_ids"].shape[1] :]
+                response = self._tokenizer.decode(outputs[0])
+                response = response.split("<|end_of_text|>")[0]
+
+            chunk = Completion(
+                id=str(uuid.uuid1()),
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[
+                    CompletionChoice(
+                        index=0, text=response, finish_reason="stop", logprobs=None
+                    )
+                ],
+                usage=CompletionUsage(
+                    prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+                ),
+            )
+            return self._to_chat_completion(chunk)
+
+    def _streaming_chat_response(
+        self, inputs: Dict, config: Dict
+    ) -> Iterator[CompletionChunk]:
+        from threading import Thread
+
+        from transformers import TextIteratorStreamer
+
+        streamer = TextIteratorStreamer(
+            self._tokenizer, skip_prompt=True, skip_special_tokens=True
+        )
+        generation_kwargs = {
+            "input_ids": inputs["input_ids"],
+            "attention_mask": inputs["attention_mask"],
+            "token_type_ids": inputs["token_type_ids"],
+            "images": inputs["images"],
+            "max_new_tokens": config["max_new_tokens"],
+            "pad_token_id": config["pad_token_id"],
+            "streamer": streamer,
+        }
+
+        thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        completion_id = str(uuid.uuid1())
+        for new_text in streamer:
+            chunk = CompletionChunk(
+                id=completion_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[
+                    CompletionChoice(
+                        index=0, text=new_text, finish_reason=None, logprobs=None
+                    )
+                ],
+                usage=CompletionUsage(
+                    prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+                ),
+            )
+            yield chunk
+
+        completion_choice = CompletionChoice(
+            text="", index=0, logprobs=None, finish_reason="stop"
+        )
+        chunk = CompletionChunk(
+            id=completion_id,
+            object="text_completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[completion_choice],
+            usage=CompletionUsage(
+                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+            ),
+        )
+        yield chunk
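
The new CogVLM2Model accepts OpenAI-style vision messages: _message_content_to_cogvlm2 takes either a plain string or a list of text and image_url parts, and allows at most one image per message. A hedged sketch of such a prompt payload (the URL is a placeholder; only the structure is taken from the code above):

# Hypothetical payload; only the content structure is drawn from the code above.
prompt = [
    {"type": "text", "text": "What is in this picture?"},
    {
        "type": "image_url",
        # A data:image/...;base64,... URI also works, per _load_image above.
        "image_url": {"url": "https://example.com/cat.png"},
    },
]
# Passing two image_url parts in one message would raise
# "Only one image per message is supported by CogVLM2."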
xinference/model/llm/pytorch/core.py

@@ -17,6 +17,7 @@ import logging
 import os
 from typing import Iterable, Iterator, List, Optional, Union
 
+from ....core.scheduler import InferenceRequest
 from ....device_utils import (
     get_device_preferred_dtype,
     gpu_count,
@@ -40,6 +41,7 @@ from ...utils import select_device
 from ..core import LLM
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from ..utils import ChatModelMixin
+from .utils import get_context_length, get_max_src_len
 
 logger = logging.getLogger(__name__)
 
@@ -53,6 +55,11 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
     "chatglm2",
     "chatglm2-32k",
     "chatglm2-128k",
+    "chatglm3",
+    "chatglm3-32k",
+    "chatglm3-128k",
+    "glm4-chat",
+    "glm4-chat-1m",
     "llama-2",
     "llama-2-chat",
     "internlm2-chat",
@@ -62,6 +69,9 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
     "deepseek-vl-chat",
     "internvl-chat",
     "mini-internvl-chat",
+    "cogvlm2",
+    "MiniCPM-Llama3-V-2_5",
+    "glm-4v",
 ]
 
 
@@ -95,6 +105,7 @@ class PytorchModel(LLM):
         pytorch_model_config.setdefault("gptq_act_order", False)
         pytorch_model_config.setdefault("device", "auto")
         pytorch_model_config.setdefault("trust_remote_code", True)
+        pytorch_model_config.setdefault("max_num_seqs", 16)
         return pytorch_model_config
 
     def _sanitize_generate_config(
@@ -453,6 +464,7 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
             pytorch_model_config,
             peft_model,
         )
+        self._context_len = None
 
     def _sanitize_generate_config(
         self,
@@ -496,13 +508,8 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         chat_history: Optional[List[ChatCompletionMessage]] = None,
         generate_config: Optional[PytorchGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        assert self.model_family.prompt_style is not None
-        prompt_style = self.model_family.prompt_style.copy()
-        if system_prompt:
-            prompt_style.system_prompt = system_prompt
-        chat_history = chat_history or []
         tools = generate_config.pop("tools", []) if generate_config else None
-        full_prompt = self.get_prompt(prompt, chat_history, prompt_style, tools=tools)
+        full_prompt = self._get_full_prompt(prompt, system_prompt, chat_history, tools)
 
         generate_config = self._sanitize_generate_config(generate_config)
         # TODO(codingl2k1): qwen hacky to set stop for function call.
@@ -530,3 +537,82 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
             self.model_family, self.model_uid, c, tools
         )
         return self._to_chat_completion(c)
+
+    def load(self):
+        super().load()
+        self._context_len = get_context_length(self._model.config)
+
+    def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
+        assert self.model_family.prompt_style is not None
+        prompt_style = self.model_family.prompt_style.copy()
+        if system_prompt:
+            prompt_style.system_prompt = system_prompt
+        chat_history = chat_history or []
+        full_prompt = ChatModelMixin.get_prompt(
+            prompt, chat_history, prompt_style, tools=tools
+        )
+        return full_prompt
+
+    def get_max_num_seqs(self) -> int:
+        return self._pytorch_model_config.get("max_num_seqs")  # type: ignore
+
+    def batch_inference(self, req_list: List[InferenceRequest]):
+        from .utils import batch_inference_one_step
+
+        for r in req_list:
+            if r.sanitized_generate_config is None:
+                r.sanitized_generate_config = self._sanitize_generate_config(
+                    r.generate_config
+                )
+            if r.is_prefill:
+                # check some generate params
+                max_src_len = get_max_src_len(self._context_len, r)  # type: ignore
+                if max_src_len < 0:
+                    r.stopped = True
+                    r.error_msg = "Max tokens exceeds model's max length"
+                    continue
+                if r.stream_interval <= 0:
+                    r.stopped = True
+                    r.error_msg = "`stream_interval` must be greater than 0"
+                    continue
+                stop_str = r.sanitized_generate_config.get("stop", None)
+                if stop_str and (
+                    not (isinstance(stop_str, str) or isinstance(stop_str, Iterable))
+                ):
+                    r.stopped = True
+                    r.error_msg = "Invalid `stop` field type"
+                    continue
+                r.full_prompt = self._get_full_prompt(
+                    r.prompt, r.system_prompt, r.chat_history, None
+                )
+
+        assert isinstance(self._context_len, int)
+        batch_inference_one_step(
+            req_list,
+            self.model_uid,
+            self._model,
+            self._tokenizer,
+            self._device,
+            self._context_len,
+        )
+        for req in req_list:
+            if req.stream and req.error_msg is None:
+                if req.completion:
+                    results = []
+                    for i, c in enumerate(req.completion):
+                        if c == "<bos_stream>":
+                            results.append(
+                                self._get_first_chat_completion_chunk(
+                                    req.completion[i + 1]
+                                )
+                            )
+                        elif c == "<eos_stream>":
+                            break
+                        else:
+                            results.append(self._to_chat_completion_chunk(c))
+
+                    if req.stopped and req.include_usage:
+                        results.append(
+                            self._get_final_chat_completion_chunk(req.completion[-1])
+                        )
+                    req.completion = results
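
The new max_num_seqs default appears to cap how many requests the batch_inference path handles per step (driven by the scheduler added in xinference/core/scheduler.py), and _sanitize_model_config fills it only when the caller has not supplied a value. A small illustrative sketch of that setdefault behavior, not the full PytorchModelConfig:

# Illustrative only: mirrors the setdefault calls in _sanitize_model_config.
def sanitize(pytorch_model_config: dict) -> dict:
    pytorch_model_config.setdefault("device", "auto")
    pytorch_model_config.setdefault("trust_remote_code", True)
    pytorch_model_config.setdefault("max_num_seqs", 16)  # new in 0.12.0
    return pytorch_model_config

assert sanitize({})["max_num_seqs"] == 16
assert sanitize({"max_num_seqs": 4})["max_num_seqs"] == 4  # caller's value wins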