xinference 0.11.1__py3-none-any.whl → 0.11.2.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (31)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +30 -0
  3. xinference/client/restful/restful_client.py +29 -0
  4. xinference/core/cache_tracker.py +12 -1
  5. xinference/core/supervisor.py +30 -2
  6. xinference/core/utils.py +12 -0
  7. xinference/core/worker.py +4 -1
  8. xinference/deploy/cmdline.py +126 -0
  9. xinference/deploy/test/test_cmdline.py +24 -0
  10. xinference/model/llm/__init__.py +2 -0
  11. xinference/model/llm/llm_family.json +501 -6
  12. xinference/model/llm/llm_family.py +84 -10
  13. xinference/model/llm/llm_family_modelscope.json +198 -7
  14. xinference/model/llm/memory.py +332 -0
  15. xinference/model/llm/pytorch/core.py +2 -0
  16. xinference/model/llm/pytorch/intern_vl.py +347 -0
  17. xinference/model/llm/utils.py +13 -0
  18. xinference/model/llm/vllm/core.py +5 -2
  19. xinference/model/rerank/core.py +23 -1
  20. xinference/model/utils.py +17 -7
  21. xinference/thirdparty/deepseek_vl/models/processing_vlm.py +1 -1
  22. xinference/thirdparty/deepseek_vl/models/siglip_vit.py +2 -2
  23. xinference/thirdparty/llava/mm_utils.py +3 -2
  24. xinference/thirdparty/llava/model/llava_arch.py +1 -1
  25. xinference/thirdparty/omnilmm/chat.py +6 -5
  26. {xinference-0.11.1.dist-info → xinference-0.11.2.post1.dist-info}/METADATA +8 -7
  27. {xinference-0.11.1.dist-info → xinference-0.11.2.post1.dist-info}/RECORD +31 -29
  28. {xinference-0.11.1.dist-info → xinference-0.11.2.post1.dist-info}/LICENSE +0 -0
  29. {xinference-0.11.1.dist-info → xinference-0.11.2.post1.dist-info}/WHEEL +0 -0
  30. {xinference-0.11.1.dist-info → xinference-0.11.2.post1.dist-info}/entry_points.txt +0 -0
  31. {xinference-0.11.1.dist-info → xinference-0.11.2.post1.dist-info}/top_level.txt +0 -0

xinference/model/llm/memory.py (new file)
@@ -0,0 +1,332 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # NOTE:
+ #
+ # The algorithm is ported from https://github.com/RahulSChand/gpu_poor
+ #
+ # Improvement:
+ #
+ # The original JS code computes the KV cache size with float32 only, whereas in most
+ # cases we run models with float16.
+ #
+ # Known Issues:
+ #
+ # * On vLLM, some MHA models use less memory than calculated (qwen1.5-7B-chat-gptq-int4,
+ #   qwen1.5-14B-chat-gptq-int4 with large activation_mem).
+ #
+ # * On vLLM, the gemma-it-7B pytorch-format model uses more GPU memory than calculated.
+
+ import json
+ import math
+ from dataclasses import dataclass
+ from logging import getLogger
+ from math import ceil
+ from typing import Any, Optional, Union
+
+ from .llm_family import convert_model_size_to_float
+
+ logger = getLogger(__name__)
+
+
+ @dataclass
+ class ModelLayersInfo:
+     vocab_size: int
+     heads: int  # num_attention_heads, num_heads or n_head
+     hidden_dim: int  # hidden_size, d_model, or n_embd
+     inter_dim: int  # intermediate_size, n_inner or d_ff
+     num_layers: int  # num_layers, num_hidden_layers or n_layer
+
+
+ @dataclass
+ class ModelMemInfo:
+     """Memory required by the model, in MB"""
+
+     model_mem: int
+     kv_cache_mem: int
+     activation_mem: int
+     overhead: int
+     total: int
+
+
+ QUANT_NORMALIZE = {"int4": "4-bit", "int8": "8-bit", "4-bit": "4-bit", "8-bit": "8-bit"}
+
+ GGML_MULTI_FACTOR_DICT = {
+     "q4_0": 18,
+     "q4_1": 20,
+     "q5_0": 22,
+     "q5_1": 24,
+     "q8_0": 34,
+     "q8_1": 40,
+ }
+
+ GGML_MULTI_FACTOR_DICT_64 = {
+     "q6_K": 54.0,
+     "q3": 26.0,
+     "q4": 38.0,
+     "q5": 46.0,
+ }
+
+ GGML_MULTI_FACTOR_DICT_COMBINE = {
+     "q3_K_L": [38.0, 26.0],
+     "q3_K_M": [46.0, 26.0],
+     "q4_K_S": [46.0, 38.0],
+     "q4_K_M": [54.0, 38.0],
+     "q5_K_M": [54.0, 46.0],
+     "q2_K": [26.0, 22.0],
+ }
+
+
+ # Returns GPU memory in MB.
+ def estimate_llm_gpu_memory(
+     model_size_in_billions: Union[str, int],
+     quantization: Optional[str],
+     context_length: int,  # input + output
+     model_format: str,
+     model_name: Optional[str] = None,
+     kv_cache_dtype: int = 16,
+ ) -> Optional[ModelMemInfo]:
+     """
+     model_size_in_billions: must be a str like 1_8 or 46_7 to match the llm.
+     """
+     info = get_model_layers_info(
+         model_size_in_billions,
+         model_name,
+         model_format,
+         quantization,
+     )
+     if info is None:
+         return None
+     size_in_billions = convert_model_size_to_float(model_size_in_billions)
+     return estimate_llm_gpu_memory_details(
+         info,
+         size_in_billions,
+         quantization,
+         context_length,
+         model_format,
+         kv_cache_dtype,
+     )
+
+
+ def estimate_llm_gpu_memory_details(
+     info: ModelLayersInfo,
+     size_in_billions: float,
+     quantization: Optional[str],
+     context_length: int,  # input + output
+     model_format: str,
+     kv_cache_dtype: int = 16,
+ ) -> ModelMemInfo:
+     """Return model_mem, kv_cache, overhead, activation_mem."""
+     if kv_cache_dtype not in [8, 16, 32]:
+         raise ValueError(f"Invalid kv_cache_dtype {kv_cache_dtype}")
+     if kv_cache_dtype == 8:
+         kv_dtype_size = 1
+     elif kv_cache_dtype == 16:
+         kv_dtype_size = 2
+     else:
+         kv_dtype_size = 4
+     overhead = 650.0
+     if model_format == "ggmlv3":
+         assert quantization is not None and quantization != "none"
+         model_size_in_mb = _compute_model_size_ggml(info, quantization)
+         inference_mem = float(
+             context_length * kv_dtype_size * info.hidden_dim * info.num_layers
+         )
+         inference_mem = inference_mem / 1024.0 / 1024.0
+         activation_mem = _compute_inference_only_activation_memory(context_length, info)
+         overhead = overhead + context_length * 0.1
+     else:
+         if quantization is not None:
+             assert isinstance(quantization, str)
+             quantization = QUANT_NORMALIZE[quantization.lower()]
+             assert quantization is not None
+
+         model_size = size_in_billions * 1000000000.0
+         model_size_in_mb = _convert_to_mb_model_size(model_size, quantization)
+         # KV cache
+         inference_mem = float(
+             context_length * 2 * kv_dtype_size * info.hidden_dim * info.num_layers
+         )
+         inference_mem = inference_mem / 1024.0 / 1024.0
+         activation_mem = _compute_inference_only_activation_memory(context_length, info)
+
+     total_mem = ceil(inference_mem + model_size_in_mb + overhead + activation_mem)
+     return ModelMemInfo(
+         model_mem=ceil(model_size_in_mb),
+         kv_cache_mem=ceil(inference_mem),
+         activation_mem=ceil(activation_mem),
+         overhead=ceil(overhead),
+         total=total_mem,
+     )
+
+
+ def _load_item_from_json(config_data: Any, *keys: str) -> str:
+     assert len(keys) > 0
+     for key in keys:
+         v = config_data.get(key)
+         if v is not None:
+             return v
+     raise ValueError("load ModelLayersInfo: missing %s" % (keys[0]))
+
+
+ def load_model_config_json(config_path: str) -> ModelLayersInfo:
+     with open(config_path, "r") as f:
+         config_data = json.load(f)
+     return ModelLayersInfo(
+         vocab_size=int(_load_item_from_json(config_data, "vocab_size")),
+         heads=int(
+             _load_item_from_json(
+                 config_data, "num_key_value_heads", "num_attention_heads"
+             )
+         ),
+         hidden_dim=int(
+             _load_item_from_json(config_data, "hidden_size", "d_model", "n_embd")
+         ),
+         inter_dim=int(_load_item_from_json(config_data, "intermediate_size")),
+         num_layers=int(
+             _load_item_from_json(
+                 config_data, "num_hidden_layers", "num_layers", "n_layer"
+             )
+         ),
+     )
+
+
+ def get_model_layers_info(
+     model_size_in_billions: Union[str, int],
+     model_name: Optional[str],
+     model_format: Optional[str],
+     quantization: Optional[str],
+ ) -> Optional[ModelLayersInfo]:
+     from . import match_llm
+     from .llm_family import cache_model_config
+
+     if not model_name:
+         logger.debug("get_model_layers_info by default size=%s", model_size_in_billions)
+         size_in_billions = convert_model_size_to_float(model_size_in_billions)
+         return _get_default_layers_from_size(size_in_billions)
+     match_result = match_llm(
+         model_name=model_name,
+         model_format=model_format,
+         model_size_in_billions=model_size_in_billions,
+         quantization=quantization,
+     )
+     if not match_result:
+         return None
+     llm_family, llm_spec, _quant = match_result
+     config_path = cache_model_config(llm_family, llm_spec)
+     return load_model_config_json(config_path)
+
+
+ def _get_default_layers_from_size(size_in_billion: float) -> ModelLayersInfo:
+     if size_in_billion < 5:
+         vocab_size = 32000
+         heads = 32
+         num_layers = 24
+     elif size_in_billion < 10:
+         vocab_size = 32000
+         heads = 32
+         num_layers = 32
+     elif size_in_billion < 24:
+         vocab_size = 32000
+         heads = 40
+         num_layers = 40
+     elif size_in_billion < 55:
+         vocab_size = 32000
+         heads = 60
+         num_layers = 48
+     else:
+         vocab_size = 32000
+         heads = 64
+         num_layers = 80
+
+     model_size = int(size_in_billion * 1000000000)
+     # Solve A * h**2 + B * h + C = 0 for the hidden dim h that matches the parameter count.
+     A = num_layers * 4 + 3 * 4 * num_layers
+     B = 2 * vocab_size
+     C = -1 * model_size
+     h = (-B + math.sqrt(B**2 - 4 * A * C)) / (2 * A)
+     h = math.ceil(h)
+     return ModelLayersInfo(
+         vocab_size=vocab_size,
+         heads=heads,
+         hidden_dim=h,
+         inter_dim=4 * h,
+         num_layers=num_layers,
+     )
+
+
+ def _convert_to_mb_model_size(model_size: float, quantization: Optional[str]) -> float:
+     extra = 0.0
+     fB = 2.0
+     size = (model_size * fB) / (1024.0 * 1024.0)
+     # bnb_q4 == 4-bit ?
+     if quantization == "8-bit" or quantization == "4-bit":
+         extra = 0.06 * size
+     if quantization == "8-bit":
+         size = size / 2
+     if quantization == "4-bit":
+         size = size / 4
+     return size + extra
+
+
+ def _compute_inference_only_activation_memory(
+     context_length: int, info: ModelLayersInfo
+ ) -> float:
+     hidden_dim = info.hidden_dim
+     heads = info.heads
+     ret = (
+         (context_length * hidden_dim * 5 * 2 + (context_length**2) * heads * 2)
+         / 1024
+         / 1024
+     )
+     return ret
+
+
+ def _compute_model_size_ggml(info: ModelLayersInfo, quantization: str) -> float:
+     assert quantization is not None
+     vocab_size = info.vocab_size
+     num_layers = info.num_layers
+     hidden_dim = info.hidden_dim
+     inter_dim = info.inter_dim
+     total_params = int(
+         vocab_size * hidden_dim * 2
+         + num_layers * 4 * (hidden_dim**2)
+         + num_layers * 3 * inter_dim * hidden_dim
+     )
+     other_v_down_params = (
+         num_layers * (hidden_dim**2) + num_layers * hidden_dim * inter_dim
+     )
+     other_param_q2k = (
+         total_params - (hidden_dim**2) * num_layers * 2 + 2 * vocab_size * hidden_dim
+     )
+
+     total = 0.0
+     v1 = GGML_MULTI_FACTOR_DICT.get(quantization)
+     if v1 is not None:
+         total = (v1 * total_params) / (32 * 1024 * 1024)
+     v2 = GGML_MULTI_FACTOR_DICT_64.get(quantization)
+     if v2 is not None:
+         total = (v2 * total_params) / (64 * 1024 * 1024)
+     v3 = GGML_MULTI_FACTOR_DICT_COMBINE.get(quantization)
+     if v3 is not None:
+         factors = v3
+         if quantization == "q2_K":
+             total = (
+                 (total_params - other_param_q2k) * factors[1]
+                 + other_param_q2k * factors[0]
+             ) / (64 * 1024 * 1024)
+         else:
+             total = (
+                 (total_params - other_v_down_params) * factors[1]
+                 + other_v_down_params * factors[0]
+             ) / (64 * 1024 * 1024)
+     return total
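
For reference, a minimal sketch of how the new estimator might be invoked. The function and field names come directly from memory.py above; the concrete model size, format and context length are illustrative assumptions, not values taken from the package:

from xinference.model.llm.memory import estimate_llm_gpu_memory

# Hypothetical 7B pytorch-format model, 4-bit quantization, 4096-token context
# (prompt + generation). With model_name=None the built-in default layer
# heuristics are used instead of a cached config.json.
mem = estimate_llm_gpu_memory(
    model_size_in_billions=7,
    quantization="4-bit",
    context_length=4096,
    model_format="pytorch",
    model_name=None,
    kv_cache_dtype=16,
)
if mem is not None:
    # All fields are in MB.
    print(mem.model_mem, mem.kv_cache_mem, mem.activation_mem, mem.overhead, mem.total)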

xinference/model/llm/__init__.py
@@ -60,6 +60,8 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
      "OmniLMM",
      "yi-vl-chat",
      "deepseek-vl-chat",
+     "internvl-chat",
+     "mini-internvl-chat",
  ]

xinference/model/llm/pytorch/intern_vl.py (new file)
@@ -0,0 +1,347 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import base64
+ import logging
+ import time
+ import uuid
+ from concurrent.futures import ThreadPoolExecutor
+ from io import BytesIO
+ from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+ import requests
+ import torch
+ from PIL import Image
+
+ from ....model.utils import select_device
+ from ....types import (
+     ChatCompletion,
+     ChatCompletionChunk,
+     ChatCompletionMessage,
+     Completion,
+     CompletionChoice,
+     CompletionUsage,
+ )
+ from ..llm_family import LLMFamilyV1, LLMSpecV1
+ from .core import PytorchChatModel, PytorchGenerateConfig
+
+ logger = logging.getLogger(__name__)
+
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+ class InternVLChatModel(PytorchChatModel):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._tokenizer = None
+         self._model = None
+
+     @classmethod
+     def match(
+         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+     ) -> bool:
+         family = model_family.model_family or model_family.model_name
+         if "internvl" in family.lower():
+             return True
+         return False
+
+     def load(self, **kwargs):
+         from transformers import AutoModel, AutoTokenizer
+         from transformers.generation import GenerationConfig
+
+         device = self._pytorch_model_config.get("device", "auto")
+         device = select_device(device)
+         # For multiple GPUs, set back to "auto" so that all devices are used.
+         device = "auto" if device == "cuda" else device
+
+         self._tokenizer = AutoTokenizer.from_pretrained(
+             self.model_path,
+             trust_remote_code=True,
+         )
+
+         kwargs = {
+             "torch_dtype": torch.bfloat16,
+             "low_cpu_mem_usage": True,
+             "trust_remote_code": True,
+             "device_map": device,
+         }
+
+         if "Int8" in self.model_spec.quantizations:
+             kwargs.update(
+                 {
+                     "load_in_8bit": True,
+                     "device_map": device,
+                 }
+             )
+         elif "mini" in self.model_family.model_name:
+             kwargs.pop("device_map")
+
+         self._model = AutoModel.from_pretrained(self.model_path, **kwargs).eval()
+
+         if "Int8" not in self.model_spec.quantizations:
+             self._model.cuda()
+
+         # Specify hyperparameters for generation.
+         self._model.generation_config = GenerationConfig.from_pretrained(
+             self.model_path,
+             trust_remote_code=True,
+         )
+
+     def _message_content_to_intern(self, content):
+         def _load_image(_url):
+             if _url.startswith("data:"):
+                 logging.info("Parse data URL with the base64 decoder.")
+                 # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                 # e.g. f"data:image/jpeg;base64,{base64_image}"
+                 _type, data = _url.split(";")
+                 _, ext = _type.split("/")
+                 data = data[len("base64,") :]
+                 data = base64.b64decode(data.encode("utf-8"))
+                 return Image.open(BytesIO(data)).convert("RGB")
+             else:
+                 try:
+                     response = requests.get(_url)
+                 except requests.exceptions.MissingSchema:
+                     return Image.open(_url).convert("RGB")
+                 else:
+                     return Image.open(BytesIO(response.content)).convert("RGB")
+
+         if not isinstance(content, str):
+             texts = []
+             image_urls = []
+             for c in content:
+                 c_type = c.get("type")
+                 if c_type == "text":
+                     texts.append(c["text"])
+                 elif c_type == "image_url":
+                     image_urls.append(c["image_url"]["url"])
+             image_futures = []
+             with ThreadPoolExecutor() as executor:
+                 for image_url in image_urls:
+                     fut = executor.submit(_load_image, image_url)
+                     image_futures.append(fut)
+             images = [fut.result() for fut in image_futures]
+             text = " ".join(texts)
+             if len(images) == 0:
+                 return text, None
+             else:
+                 return text, images
+         return content, None
+
+     # Convert OpenAI-style chat history into InternVL (question, answer) tuples
+     # plus the stacked pixel values of the images in the first user message.
+     def _history_content_to_intern(
+         self,
+         chat_history: List[ChatCompletionMessage],
+         IMG_START_TOKEN="<img>",
+         IMG_END_TOKEN="</img>",
+         IMG_CONTEXT_TOKEN="<IMG_CONTEXT>",
+     ):
+         def _image_to_piexl_values(images):
+             load_images = []
+             for image in images:
+                 if image.startswith("data:"):
+                     logging.info("Parse data URL with the base64 decoder.")
+                     # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                     # e.g. f"data:image/jpeg;base64,{base64_image}"
+                     _type, data = image.split(";")
+                     _, ext = _type.split("/")
+                     data = data[len("base64,") :]
+                     data = base64.b64decode(data.encode("utf-8"))
+                     img = Image.open(BytesIO(data)).convert("RGB")
+                     pixel_value = (
+                         self._load_image(img, max_num=6).to(torch.bfloat16).cuda()
+                     )
+                     load_images.append(pixel_value)
+                 else:
+                     try:
+                         response = requests.get(image)
+                     except requests.exceptions.MissingSchema:
+                         img = Image.open(image).convert("RGB")
+                     else:
+                         img = Image.open(BytesIO(response.content)).convert("RGB")
+                     pixel_value = (
+                         self._load_image(img, max_num=6).to(torch.bfloat16).cuda()
+                     )
+                     load_images.append(pixel_value)
+             return torch.cat(tuple(load_images), dim=0)
+
+         history: List[Tuple] = []
+         pixel_values = None
+         for i in range(0, len(chat_history), 2):
+             tmp = []
+             images: List[str] = []
+             user = chat_history[i]["content"]
+             if isinstance(user, List):
+                 for content in user:
+                     c_type = content.get("type")
+                     if c_type == "text":
+                         tmp.append(content["text"])
+                     elif c_type == "image_url" and not history:
+                         images.append(content["image_url"]["url"])
+                 if not history:
+                     pixel_values = _image_to_piexl_values(images)
+                     image_bs = pixel_values.shape[0]
+                     image_tokens = (
+                         IMG_START_TOKEN
+                         + IMG_CONTEXT_TOKEN * self._model.num_image_token * image_bs
+                         + IMG_END_TOKEN
+                     )
+                     tmp[0] = image_tokens + "\n" + tmp[0]
+             else:
+                 tmp.append(user)
+             tmp.append(chat_history[i + 1]["content"])
+             history.append(tuple(tmp))
+         return history, pixel_values
+
+     def _find_closest_aspect_ratio(
+         self, aspect_ratio, target_ratios, width, height, image_size
+     ):
+         best_ratio_diff = float("inf")
+         best_ratio = (1, 1)
+         area = width * height
+         for ratio in target_ratios:
+             target_aspect_ratio = ratio[0] / ratio[1]
+             ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+             if ratio_diff < best_ratio_diff:
+                 best_ratio_diff = ratio_diff
+                 best_ratio = ratio
+             elif ratio_diff == best_ratio_diff:
+                 if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                     best_ratio = ratio
+         return best_ratio
+
+     def _dynamic_preprocess(
+         self, image, min_num=1, max_num=6, image_size=448, use_thumbnail=False
+     ):
+         orig_width, orig_height = image.size
+         aspect_ratio = orig_width / orig_height
+
+         # calculate the existing image aspect ratio
+         target_ratios = set(
+             (i, j)
+             for n in range(min_num, max_num + 1)
+             for i in range(1, n + 1)
+             for j in range(1, n + 1)
+             if i * j <= max_num and i * j >= min_num
+         )
+         target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+         # find the closest aspect ratio to the target
+         target_aspect_ratio = self._find_closest_aspect_ratio(
+             aspect_ratio, target_ratios, orig_width, orig_height, image_size
+         )
+
+         # calculate the target width and height
+         target_width = image_size * target_aspect_ratio[0]
+         target_height = image_size * target_aspect_ratio[1]
+         blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+         # resize the image
+         resized_img = image.resize((target_width, target_height))
+         processed_images = []
+         for i in range(blocks):
+             box = (
+                 (i % (target_width // image_size)) * image_size,
+                 (i // (target_width // image_size)) * image_size,
+                 ((i % (target_width // image_size)) + 1) * image_size,
+                 ((i // (target_width // image_size)) + 1) * image_size,
+             )
+             # split the image
+             split_img = resized_img.crop(box)
+             processed_images.append(split_img)
+         assert len(processed_images) == blocks
+         if use_thumbnail and len(processed_images) != 1:
+             thumbnail_img = image.resize((image_size, image_size))
+             processed_images.append(thumbnail_img)
+         return processed_images
+
+     def _build_transform(self, input_size):
+         import torchvision.transforms as T
+         from torchvision.transforms.functional import InterpolationMode
+
+         MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+         transform = T.Compose(
+             [
+                 T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+                 T.Resize(
+                     (input_size, input_size), interpolation=InterpolationMode.BICUBIC
+                 ),
+                 T.ToTensor(),
+                 T.Normalize(mean=MEAN, std=STD),
+             ]
+         )
+         return transform
+
+     def _load_image(self, image_file, input_size=448, max_num=6):
+         transform = self._build_transform(input_size=input_size)
+         images = self._dynamic_preprocess(
+             image_file, image_size=input_size, use_thumbnail=True, max_num=max_num
+         )
+         pixel_values = [transform(image) for image in images]
+         pixel_values = torch.stack(pixel_values)
+         return pixel_values
+
+     def chat(
+         self,
+         prompt: Union[str, List[Dict]],
+         system_prompt: Optional[str] = None,
+         chat_history: Optional[List[ChatCompletionMessage]] = None,
+         generate_config: Optional[PytorchGenerateConfig] = None,
+     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+         if generate_config and generate_config.pop("stream", None):
+             raise Exception(
+                 f"Chat with model {self.model_family.model_name} does not support stream."
+             )
+         sanitized_config = {
+             "num_beams": 1,
+             "max_new_tokens": generate_config.get("max_tokens", 512)
+             if generate_config
+             else 512,
+             "do_sample": False,
+         }
+
+         content, image = self._message_content_to_intern(prompt)
+
+         history = None
+         if chat_history:
+             history, pixel_values = self._history_content_to_intern(chat_history)
+         else:
+             # No history: build pixel values from the images in the current prompt.
+             load_images = []
+             for img in image:
+                 pixel_value = self._load_image(img, max_num=6).to(torch.bfloat16).cuda()
+                 load_images.append(pixel_value)
+             pixel_values = torch.cat(tuple(load_images), dim=0)
+
+         response, history = self._model.chat(
+             self._tokenizer,
+             pixel_values,
+             content,
+             sanitized_config,
+             history=history,
+             return_history=True,
+         )
+         chunk = Completion(
+             id=str(uuid.uuid1()),
+             object="text_completion",
+             created=int(time.time()),
+             model=self.model_uid,
+             choices=[
+                 CompletionChoice(
+                     index=0, text=response, finish_reason="stop", logprobs=None
+                 )
+             ],
+             usage=CompletionUsage(
+                 prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+             ),
+         )
+         return self._to_chat_completion(chunk)
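
A minimal usage sketch for the new InternVL model: it consumes OpenAI-style vision messages (a list of "text" and "image_url" parts, exactly the shape _message_content_to_intern parses above). The endpoint, launch parameters and image URL below are placeholder assumptions; the calls follow the usual xinference RESTful client workflow:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed local endpoint
model_uid = client.launch_model(model_name="internvl-chat", model_format="pytorch")
model = client.get_model(model_uid)

# One text part plus one image_url part, matching the parser in intern_vl.py.
prompt = [
    {"type": "text", "text": "Describe this image."},
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
]
print(model.chat(prompt=prompt))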

xinference/model/llm/utils.py
@@ -456,6 +456,19 @@ Begin!"""
                  ret += f"<|{role}|>{prompt_style.intra_message_sep}"
              ret += "<|assistant|>\n"
              return ret
+         elif prompt_style.style_name == "c4ai-command-r":
+             ret = (
+                 f"<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>"
+                 f"{prompt_style.system_prompt}{prompt_style.inter_message_sep}"
+             )
+             for i, message in enumerate(chat_history):
+                 role = get_role(message["role"])
+                 content = message["content"]
+                 if content:
+                     ret += f"{role}{content}{prompt_style.inter_message_sep}"
+                 else:
+                     ret += role
+             return ret
          else:
              raise ValueError(f"Invalid prompt style: {prompt_style.style_name}")
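
To make the new "c4ai-command-r" branch concrete, here is a standalone sketch of the string it assembles. The role tokens and inter_message_sep values below are assumptions for illustration; the real values come from the model family's prompt_style definition, which is not shown in this hunk:

# Assumed prompt_style fields -- for illustration only.
system_prompt = "You are Command-R."
inter_message_sep = "<|END_OF_TURN_TOKEN|>"
roles = {
    "user": "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>",
    "assistant": "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
}
chat_history = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": ""},  # empty content: only the role token is emitted
]

# Same assembly logic as the added branch above.
ret = (
    "<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>"
    f"{system_prompt}{inter_message_sep}"
)
for message in chat_history:
    role = roles[message["role"]]
    content = message["content"]
    ret += f"{role}{content}{inter_message_sep}" if content else role
print(ret)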