xinference 0.14.1.post1__py3-none-any.whl → 0.14.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Files changed (87)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +15 -34
  3. xinference/client/restful/restful_client.py +2 -2
  4. xinference/core/chat_interface.py +44 -9
  5. xinference/core/model.py +4 -4
  6. xinference/core/scheduler.py +1 -2
  7. xinference/core/worker.py +1 -1
  8. xinference/deploy/cmdline.py +2 -2
  9. xinference/deploy/test/test_cmdline.py +7 -7
  10. xinference/model/llm/__init__.py +20 -27
  11. xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
  12. xinference/model/llm/llm_family.json +448 -1153
  13. xinference/model/llm/llm_family.py +14 -139
  14. xinference/model/llm/llm_family_modelscope.json +230 -313
  15. xinference/model/llm/memory.py +9 -9
  16. xinference/model/llm/sglang/core.py +2 -2
  17. xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
  18. xinference/model/llm/{pytorch → transformers}/core.py +2 -10
  19. xinference/model/llm/transformers/intern_vl.py +457 -0
  20. xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
  21. xinference/model/llm/{pytorch → transformers}/minicpmv26.py +67 -22
  22. xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
  23. xinference/model/llm/utils.py +76 -70
  24. xinference/model/llm/vllm/core.py +110 -11
  25. xinference/model/utils.py +1 -95
  26. xinference/thirdparty/internvl/__init__.py +0 -0
  27. xinference/thirdparty/internvl/conversation.py +393 -0
  28. xinference/thirdparty/omnilmm/model/utils.py +16 -1
  29. xinference/web/ui/build/asset-manifest.json +3 -3
  30. xinference/web/ui/build/index.html +1 -1
  31. xinference/web/ui/build/static/js/main.ffc26121.js +3 -0
  32. xinference/web/ui/build/static/js/main.ffc26121.js.map +1 -0
  33. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
  34. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
  35. xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
  36. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
  37. xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
  38. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
  40. xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
  41. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
  42. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
  43. xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
  44. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
  45. {xinference-0.14.1.post1.dist-info → xinference-0.14.2.dist-info}/METADATA +5 -8
  46. {xinference-0.14.1.post1.dist-info → xinference-0.14.2.dist-info}/RECORD +63 -70
  47. xinference/locale/utils.py +0 -39
  48. xinference/locale/zh_CN.json +0 -26
  49. xinference/model/llm/ggml/tools/__init__.py +0 -15
  50. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
  51. xinference/model/llm/ggml/tools/gguf.py +0 -884
  52. xinference/model/llm/pytorch/__init__.py +0 -13
  53. xinference/model/llm/pytorch/baichuan.py +0 -81
  54. xinference/model/llm/pytorch/falcon.py +0 -138
  55. xinference/model/llm/pytorch/intern_vl.py +0 -352
  56. xinference/model/llm/pytorch/vicuna.py +0 -69
  57. xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
  58. xinference/web/ui/build/static/js/main.17ca0398.js.map +0 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
  63. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
  64. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
  65. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
  66. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
  67. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
  68. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
  69. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
  70. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
  71. /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
  72. /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
  73. /xinference/model/llm/{pytorch → transformers}/cogvlm2.py +0 -0
  74. /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
  75. /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
  76. /xinference/model/llm/{pytorch → transformers}/glm4v.py +0 -0
  77. /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
  78. /xinference/model/llm/{pytorch → transformers}/minicpmv25.py +0 -0
  79. /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
  80. /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
  81. /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
  82. /xinference/model/llm/{pytorch → transformers}/yi_vl.py +0 -0
  83. /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.ffc26121.js.LICENSE.txt} +0 -0
  84. {xinference-0.14.1.post1.dist-info → xinference-0.14.2.dist-info}/LICENSE +0 -0
  85. {xinference-0.14.1.post1.dist-info → xinference-0.14.2.dist-info}/WHEEL +0 -0
  86. {xinference-0.14.1.post1.dist-info → xinference-0.14.2.dist-info}/entry_points.txt +0 -0
  87. {xinference-0.14.1.post1.dist-info → xinference-0.14.2.dist-info}/top_level.txt +0 -0
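Taken together, the renames above move the PyTorch-based LLM backend from xinference/model/llm/pytorch to xinference/model/llm/transformers, replace the old ggml code path with a dedicated llama_cpp package, and add a new transformers implementation for InternVL2 (intern_vl.py plus the bundled conversation template under thirdparty/internvl). A minimal usage sketch against a running 0.14.2 server follows; the model name, size, and launch parameters are assumptions read off this file list and the llm_family.json changes, not values verified against the release. Selected hunks from the diff follow the sketch.

# Hypothetical sketch: launching the new InternVL2 transformers model through the
# RESTful client. Names and defaults are assumptions; adjust to your deployment.
from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model_uid = client.launch_model(
    model_name="internvl2",          # assumed registry name for the new family
    model_engine="transformers",     # backend renamed from "pytorch" in this release
    model_format="pytorch",
    model_size_in_billions=8,
    quantization="none",
)
model = client.get_model(model_uid)
response = model.chat(
    prompt=[
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ],
    generate_config={"max_tokens": 256},
)
print(response["choices"][0]["message"]["content"])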
@@ -189,7 +189,7 @@ class SGLANGModel(LLM):
             return False
         if not cls._is_linux():
             return False
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
@@ -378,7 +378,7 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
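Both SGLang hunks above make the same change: the backend now also accepts fp8 checkpoints alongside pytorch, gptq, and awq. A toy illustration of the widened gate is sketched below; it uses a SimpleNamespace stand-in instead of a real LLMSpecV1, purely to show the membership test.

# Illustration only: SimpleNamespace stands in for an LLMSpecV1 instance so the
# widened format check can be exercised without constructing a real model spec.
from types import SimpleNamespace

def sglang_accepts(spec) -> bool:
    return spec.model_format in ["pytorch", "gptq", "awq", "fp8"]

print(sglang_accepts(SimpleNamespace(model_format="fp8")))     # True as of 0.14.2
print(sglang_accepts(SimpleNamespace(model_format="ggufv2")))  # still False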
@@ -344,7 +344,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         return kwargs, tools

     @torch.inference_mode()
-    def stream_chat(
+    def _stream_chat(
         self,
         tokenizer,
         query: str,
@@ -399,7 +399,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             yield new_response, new_history

     @torch.inference_mode()
-    def non_stream_chat(
+    def _non_stream_chat(
         self,
         tokenizer,
         query: str,
@@ -475,10 +475,6 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         if stream and (
             not tools or self.model_family.model_name in GLM4_TOOL_CALL_FAMILY
         ):
-            if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
-                stream_chat = self.stream_chat
-            else:
-                stream_chat = self._model.stream_chat

             def _stream_generator():
                 last_chunk_text_length = 0
@@ -487,7 +483,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                 inputs = self._tokenizer([prompt], return_tensors="pt")
                 inputs = inputs.to(self._model.device)
                 prompt_tokens = len(inputs["input_ids"][0])
-                for chunk_text, _ in stream_chat(
+                for chunk_text, _ in self._stream_chat(
                     self._tokenizer, prompt, chat_history, **kwargs
                 ):
                     if tools and isinstance(chunk_text, dict):
@@ -548,12 +544,9 @@ class ChatglmPytorchChatModel(PytorchChatModel):

             return self._to_chat_completion_chunks(_stream_generator())
         else:
-            if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
-                chat = self.non_stream_chat
-            else:
-                chat = self._model.chat
-
-            response = chat(self._tokenizer, prompt, chat_history, **kwargs)
+            response = self._non_stream_chat(
+                self._tokenizer, prompt, chat_history, **kwargs
+            )
             if tools:
                 return self._tool_calls_completion(
                     self.model_family, self.model_uid, response, tools
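The five ChatGLM hunks above are one refactor: the private _stream_chat/_non_stream_chat helpers (renamed from stream_chat/non_stream_chat) are now used for every ChatGLM family, instead of branching between them and the stream_chat/chat methods shipped with the checkpoint's remote code. The feature this serves is GLM4 tool calling; a hedged client-side sketch is below, with the server URL, model uid, and tool schema all assumptions.

# Hypothetical sketch: GLM4 tool calling through the OpenAI-compatible endpoint
# served by restful_api.py. base_url, model uid, and tool schema are assumptions;
# the model must already be launched (e.g. via `xinference launch`).
from openai import OpenAI

client = OpenAI(api_key="not-used", base_url="http://127.0.0.1:9997/v1")
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]
completion = client.chat.completions.create(
    model="glm4-chat",
    messages=[{"role": "user", "content": "What's the weather in Hangzhou?"}],
    tools=tools,
)
print(completion.choices[0].message.tool_calls)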
@@ -47,15 +47,6 @@ from .utils import get_context_length, get_max_src_len, pad_prefill_tokens
 logger = logging.getLogger(__name__)

 NON_DEFAULT_MODEL_LIST: List[str] = [
-    "baichuan-chat",
-    "baichuan-2-chat",
-    "vicuna-v1.3",
-    "falcon",
-    "falcon-instruct",
-    "chatglm",
-    "chatglm2",
-    "chatglm2-32k",
-    "chatglm2-128k",
     "chatglm3",
     "chatglm3-32k",
     "chatglm3-128k",
@@ -64,12 +55,13 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
     "llama-2",
     "llama-2-chat",
     "internlm2-chat",
+    "internlm2.5-chat",
     "qwen-vl-chat",
     "OmniLMM",
     "yi-vl-chat",
     "deepseek-vl-chat",
     "internvl-chat",
-    "mini-internvl-chat",
+    "internvl2",
     "cogvlm2",
     "MiniCPM-Llama3-V-2_5",
     "MiniCPM-V-2.6",
@@ -0,0 +1,457 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, Iterator, List, Optional, Union
+
+import torch
+
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    Completion,
+    CompletionChoice,
+    CompletionChunk,
+    CompletionUsage,
+)
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
+from .core import PytorchChatModel, PytorchGenerateConfig
+
+logger = logging.getLogger(__name__)
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+def _message_content_to_intern(content, image_cnt):
+    if not isinstance(content, str):
+        texts = []
+        image_urls = []
+        for c in content:
+            c_type = c.get("type")
+            if c_type == "text":
+                texts.append(c["text"])
+            elif c_type == "image_url":
+                image_urls.append(c["image_url"]["url"])
+        image_futures = []
+        with ThreadPoolExecutor() as executor:
+            for image_url in image_urls:
+                fut = executor.submit(_decode_image, image_url)
+                image_futures.append(fut)
+        images = [fut.result() for fut in image_futures]
+        prefix = ""
+        for i, _ in enumerate(images):
+            prefix += f"Image-{image_cnt + i + 1}: <image>\n\n"
+        text = prefix + " ".join(texts)
+        if len(images) == 0:
+            return text, []
+        else:
+            return text, images
+    return content, []
+
+
+def _get_prompt_and_chat_history(
+    prompt: Union[str, List[Dict]],
+    chat_history: Optional[List[ChatCompletionMessage]] = None,
+):
+    # Convert openai history to intern vl history
+    images = []
+    history = []
+    image_cnt = 0
+    for h1, h2 in zip(*[iter(chat_history or [])] * 2):
+        content1, img = _message_content_to_intern(h1["content"], image_cnt)
+        content2, _ = _message_content_to_intern(h2["content"], image_cnt)
+        history.append([content1, content2])
+        images.extend(img)
+        image_cnt += len(img)
+
+    question, img = _message_content_to_intern(prompt, image_cnt)
+    images.extend(img)
+    return question, history, images
+
+
+def _build_transform(input_size=448):
+    import torchvision.transforms as T
+    from torchvision.transforms.functional import InterpolationMode
+
+    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+    transform = T.Compose(
+        [
+            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+            T.ToTensor(),
+            T.Normalize(mean=MEAN, std=STD),
+        ]
+    )
+    return transform
+
+
+def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+    best_ratio_diff = float("inf")
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    return best_ratio
+
+
+def _dynamic_preprocess(
+    image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
+):
+    orig_width, orig_height = image.size
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = set(
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if i * j <= max_num and i * j >= min_num
+    )
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = _find_closest_aspect_ratio(
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )
+
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = (
+            (i % (target_width // image_size)) * image_size,
+            (i // (target_width // image_size)) * image_size,
+            ((i % (target_width // image_size)) + 1) * image_size,
+            ((i // (target_width // image_size)) + 1) * image_size,
+        )
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert len(processed_images) == blocks
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+    return processed_images
+
+
+def _load_image(image_file, input_size=448, max_num=12):
+    image = image_file.convert("RGB")
+    transform = _build_transform(input_size=input_size)
+    images = _dynamic_preprocess(
+        image, image_size=input_size, use_thumbnail=True, max_num=max_num
+    )
+    pixel_values = [transform(image) for image in images]
+    pixel_values = torch.stack(pixel_values)
+    return pixel_values
+
+
+class InternVLChatModel(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._tokenizer = None
+        self._model = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        family = model_family.model_family or model_family.model_name
+        if "internvl" not in family.lower():
+            return False
+        if "pytorch" not in model_spec.model_format:
+            return False
+        return True
+
+    def _get_model_class(self):
+        from transformers import AutoModel
+
+        return AutoModel
+
+    # Copy from InternVL page
+    # reference: https://huggingface.co/OpenGVLab/InternVL2-8B
+    def _split_model(self):
+        import math
+
+        device_map = {}
+        world_size = torch.cuda.device_count()
+        # single gpu
+        if world_size == 1:
+            return None
+        model_size = f"{self.model_spec.model_size_in_billions}B"
+        num_layers = {
+            "1B": 24,
+            "2B": 24,
+            "4B": 32,
+            "8B": 32,
+            "26B": 48,
+            "40B": 60,
+            "76B": 80,
+        }[model_size]
+        # Since the first GPU will be used for ViT, treat it as half a GPU.
+        num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
+        num_layers_per_gpu = [num_layers_per_gpu] * world_size
+        num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
+        layer_cnt = 0
+        for i, num_layer in enumerate(num_layers_per_gpu):
+            for j in range(num_layer):
+                device_map[f"language_model.model.layers.{layer_cnt}"] = i
+                layer_cnt += 1
+        device_map["vision_model"] = 0
+        device_map["mlp1"] = 0
+        device_map["language_model.model.tok_embeddings"] = 0
+        device_map["language_model.model.embed_tokens"] = 0
+        device_map["language_model.output"] = 0
+        device_map["language_model.model.norm"] = 0
+        device_map["language_model.lm_head"] = 0
+        device_map[f"language_model.model.layers.{num_layers - 1}"] = 0
+        return device_map
+
+    def load(self, **kwargs):
+        from transformers import AutoModel, AutoTokenizer
+
+        if self._check_tensorizer_integrity():
+            self._model, self._tokenizer = self._load_tensorizer()
+            return
+
+        device = self._split_model()
+
+        kwargs = {
+            "torch_dtype": torch.bfloat16,
+            "low_cpu_mem_usage": True,
+            "trust_remote_code": True,
+        }
+
+        if device is not None:
+            kwargs["device_map"] = device
+
+        if "8-bit" in self.quantization.lower():
+            kwargs["load_in_8bit"] = True
+        elif "4-bit" in self.quantization.lower():
+            kwargs["load_in_4bit"] = True
+
+        self._model = AutoModel.from_pretrained(self.model_path, **kwargs).eval()
+
+        if device is None and "none" in self.quantization.lower():
+            self._model.cuda()
+
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+            use_fast=False,
+        )
+
+    def chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        from ....thirdparty.internvl.conversation import get_conv_template
+
+        IMG_START_TOKEN = "<img>"
+        IMG_END_TOKEN = "</img>"
+        IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
+
+        generation_config = {
+            "max_new_tokens": generate_config.get("max_tokens", 1024)
+            if generate_config
+            else 1024,
+            "do_sample": False,
+        }
+
+        stream = (
+            generate_config.get("stream", False)
+            if isinstance(generate_config, dict)
+            else False
+        )
+        stream_options = (
+            generate_config.get("stream_options", None)
+            if isinstance(generate_config, dict)
+            else False
+        )
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
+
+        content, history, images = _get_prompt_and_chat_history(prompt, chat_history)
+
+        num_patches_list = []
+        if len(images) == 1:
+            content = content.replace("Image-1: <image>\n\n", "<image>\n")
+            history = [
+                [item[0].replace("Image-1: <image>\n\n", "<image>\n"), item[1]]
+                for item in history
+            ]
+            pixel_values = _load_image(images[-1], max_num=12).to(torch.bfloat16).cuda()
+            num_patches_list = (
+                [pixel_values.shape[0]] if pixel_values is not None else []
+            )
+        elif len(images) > 1:
+            pixel_values = [
+                _load_image(img, max_num=12).to(torch.bfloat16).cuda() for img in images
+            ]
+            num_patches_list = [values.size(0) for values in pixel_values]
+            pixel_values = torch.cat(pixel_values, dim=0)
+        else:
+            pixel_values = None
+
+        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
+
+        img_context_token_id = self._tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+        self._model.img_context_token_id = img_context_token_id
+
+        template = get_conv_template(self._model.template)
+        template.system_message = self._model.system_message
+        eos_token_id = self._tokenizer.convert_tokens_to_ids(template.sep)
+
+        history = [] if history is None else history
+        for old_question, old_answer in history:
+            template.append_message(template.roles[0], old_question)
+            template.append_message(template.roles[1], old_answer)
+        template.append_message(template.roles[0], content)
+        template.append_message(template.roles[1], None)
+        query = template.get_prompt()
+
+        for num_patches in num_patches_list:
+            image_tokens = (
+                IMG_START_TOKEN
+                + IMG_CONTEXT_TOKEN * self._model.num_image_token * num_patches
+                + IMG_END_TOKEN
+            )
+            query = query.replace("<image>", image_tokens, 1)
+
+        model_inputs = self._tokenizer(query, return_tensors="pt")
+        input_ids = model_inputs["input_ids"].cuda()
+        attention_mask = model_inputs["attention_mask"].cuda()
+        generation_config["eos_token_id"] = eos_token_id
+        generate_kwargs = {
+            "pixel_values": pixel_values,
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        generate_kwargs.update(generation_config)
+
+        if stream:
+            chunk = self._generate_stream(generate_kwargs, input_ids, include_usage)
+            return self._to_chat_completion_chunks(chunk)
+        else:
+            chunk = self._generate(generate_kwargs, input_ids, template)
+            return self._to_chat_completion(chunk)
+
+    def _generate(self, generate_kwargs, input_ids, template):
+        prompt_tokens = len(input_ids[0])
+        generation_output = self._model.generate(**generate_kwargs)
+        completion_tokens = len(generation_output[0])
+        response = self._tokenizer.batch_decode(
+            generation_output, skip_special_tokens=True
+        )[0]
+        response = response.split(template.sep)[0].strip()
+        chunk = Completion(
+            id=str(uuid.uuid1()),
+            object="text_completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[
+                CompletionChoice(
+                    index=0, text=response, finish_reason="stop", logprobs=None
+                )
+            ],
+            usage=CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
+            ),
+        )
+        return chunk
+
+    def _generate_stream(self, generate_kwargs, input_ids, include_usage):
+        from threading import Thread
+
+        from transformers import TextIteratorStreamer
+
+        # Initialize the streamer
+        streamer = TextIteratorStreamer(
+            self._tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
+        )
+        # Define the generation configuration
+        generate_kwargs["streamer"] = streamer
+        # Start the model chat in a separate thread
+        thread = Thread(
+            target=self._model.generate,
+            kwargs=generate_kwargs,
+        )
+        thread.start()
+
+        completion_id = str(uuid.uuid1())
+        prompt_tokens = len(input_ids[0])
+        completion_tokens = 0
+        # Loop through the streamer to get the new text as it is generated
+        for i, new_text in enumerate(streamer):
+            if new_text == self._model.conv_template.sep:
+                break
+            completion_choice = CompletionChoice(
+                text=new_text, index=0, logprobs=None, finish_reason=None
+            )
+            chunk = CompletionChunk(
+                id=completion_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[completion_choice],
+            )
+            completion_tokens = max(completion_tokens, len(streamer.token_cache))
+            total_tokens = prompt_tokens + completion_tokens
+            completion_usage = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            chunk["usage"] = completion_usage
+            yield chunk
+
+        if include_usage:
+            chunk = CompletionChunk(
+                id=completion_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[],
+            )
+            chunk["usage"] = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            yield chunk
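The core of the new file above is the dynamic tiling pipeline: _load_image resizes an input to the closest grid of 448x448 tiles chosen by _find_closest_aspect_ratio (at most max_num tiles, plus an optional thumbnail) and stacks them into pixel_values for the vision tower. The grid arithmetic can be checked in isolation; the sketch below re-implements just the ratio-selection math from the code above, with an assumed 1920x1080 input.

# Standalone check of the tile-grid arithmetic used by _dynamic_preprocess above.
# Pure math only (no PIL/torch); the 1920x1080 input size is an assumed example.
def closest_grid(width, height, image_size=448, min_num=1, max_num=12):
    aspect_ratio = width / height
    area = width * height
    ratios = sorted(
        {
            (i, j)
            for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if min_num <= i * j <= max_num
        },
        key=lambda r: r[0] * r[1],
    )
    best, best_diff = (1, 1), float("inf")
    for i, j in ratios:
        diff = abs(aspect_ratio - i / j)
        if diff < best_diff:
            best, best_diff = (i, j), diff
        elif diff == best_diff and area > 0.5 * image_size * image_size * i * j:
            best = (i, j)  # on a tie, prefer more tiles for large images
    return best

grid = closest_grid(1920, 1080)
print(grid, "->", grid[0] * grid[1], "tiles of 448x448, plus a thumbnail patch")
# (4, 2) -> 8 tiles, so pixel_values would stack 9 patches including the thumbnail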
@@ -85,14 +85,10 @@ class Internlm2PytorchChatModel(PytorchChatModel):
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if llm_spec.model_format != "pytorch":
-            return False
         model_family = llm_family.model_family or llm_family.model_name
-        if model_family != "internlm2-chat":
-            return False
-        if "chat" not in llm_family.model_ability:
-            return False
-        return True
+        if model_family in ["internlm2-chat", "internlm2.5-chat"]:
+            return True
+        return False

     def prepare_sanitize_generate_config(self, req: InferenceRequest):
         """
@@ -153,7 +149,7 @@ class Internlm2PytorchChatModel(PytorchChatModel):
             inputs = inputs.to(self._model.device)
             prompt_tokens = len(inputs["input_ids"][0])
             for chunk_text, _ in self._model.stream_chat(
-                self._tokenizer, prompt, chat_history, **kwargs
+                self._tokenizer, prompt, input_history, **kwargs
             ):
                 completion_tokens = completion_tokens + 1
                 total_tokens = prompt_tokens + completion_tokens