xinference 0.11.1__py3-none-any.whl → 0.11.2__py3-none-any.whl

Files changed (31)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +30 -0
  3. xinference/client/restful/restful_client.py +29 -0
  4. xinference/core/cache_tracker.py +12 -1
  5. xinference/core/supervisor.py +30 -2
  6. xinference/core/utils.py +12 -0
  7. xinference/core/worker.py +4 -1
  8. xinference/deploy/cmdline.py +126 -0
  9. xinference/deploy/test/test_cmdline.py +24 -0
  10. xinference/model/llm/__init__.py +2 -0
  11. xinference/model/llm/llm_family.json +501 -6
  12. xinference/model/llm/llm_family.py +84 -10
  13. xinference/model/llm/llm_family_modelscope.json +198 -7
  14. xinference/model/llm/memory.py +332 -0
  15. xinference/model/llm/pytorch/core.py +2 -0
  16. xinference/model/llm/pytorch/intern_vl.py +387 -0
  17. xinference/model/llm/utils.py +13 -0
  18. xinference/model/llm/vllm/core.py +5 -2
  19. xinference/model/rerank/core.py +23 -1
  20. xinference/model/utils.py +17 -7
  21. xinference/thirdparty/deepseek_vl/models/processing_vlm.py +1 -1
  22. xinference/thirdparty/deepseek_vl/models/siglip_vit.py +2 -2
  23. xinference/thirdparty/llava/mm_utils.py +3 -2
  24. xinference/thirdparty/llava/model/llava_arch.py +1 -1
  25. xinference/thirdparty/omnilmm/chat.py +6 -5
  26. {xinference-0.11.1.dist-info → xinference-0.11.2.dist-info}/METADATA +8 -7
  27. {xinference-0.11.1.dist-info → xinference-0.11.2.dist-info}/RECORD +31 -29
  28. {xinference-0.11.1.dist-info → xinference-0.11.2.dist-info}/LICENSE +0 -0
  29. {xinference-0.11.1.dist-info → xinference-0.11.2.dist-info}/WHEEL +0 -0
  30. {xinference-0.11.1.dist-info → xinference-0.11.2.dist-info}/entry_points.txt +0 -0
  31. {xinference-0.11.1.dist-info → xinference-0.11.2.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/intern_vl.py ADDED
@@ -0,0 +1,387 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import base64
+ import logging
+ import time
+ import uuid
+ from concurrent.futures import ThreadPoolExecutor
+ from io import BytesIO
+ from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+ import requests
+ import torch
+ import torchvision.transforms as T
+ from PIL import Image
+ from torchvision.transforms.functional import InterpolationMode
+
+ from ....model.utils import select_device
+ from ....types import (
+     ChatCompletion,
+     ChatCompletionChunk,
+     ChatCompletionMessage,
+     Completion,
+     CompletionChoice,
+     CompletionUsage,
+ )
+ from ..llm_family import LLMFamilyV1, LLMSpecV1
+ from .core import PytorchChatModel, PytorchGenerateConfig
+
+ logger = logging.getLogger(__name__)
+
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+ class InternVLChatModel(PytorchChatModel):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._tokenizer = None
+         self._model = None
+
+     @classmethod
+     def match(
+         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+     ) -> bool:
+         family = model_family.model_family or model_family.model_name
+         if "internvl" in family.lower():
+             return True
+         return False
+
+     def load(self, **kwargs):
+         from transformers import AutoModel, AutoTokenizer
+         from transformers.generation import GenerationConfig
+
+         device = self._pytorch_model_config.get("device", "auto")
+         device = select_device(device)
+         # for multiple GPU, set back to auto to make multiple devices work
+         device = "auto" if device == "cuda" else device
+
+         self._tokenizer = AutoTokenizer.from_pretrained(
+             self.model_path,
+             trust_remote_code=True,
+         )
+
+         kwargs = {
+             "torch_dtype": torch.bfloat16,
+             "low_cpu_mem_usage": True,
+             "trust_remote_code": True,
+             "device_map": device,
+         }
+
+         if "Int8" in self.model_spec.quantizations:
+             kwargs.update(
+                 {
+                     "load_in_8bit": True,
+                     "device_map": device,
+                 }
+             )
+         elif "mini" in self.model_family.model_name:
+             kwargs.pop("device_map")
+
+         self._model = AutoModel.from_pretrained(self.model_path, **kwargs).eval()
+
+         if "Int8" not in self.model_spec.quantizations:
+             self._model.cuda()
+
+         # Specify hyperparameters for generation
+         self._model.generation_config = GenerationConfig.from_pretrained(
+             self.model_path,
+             trust_remote_code=True,
+         )
+
+     def _message_content_to_intern(self, content):
+         def _load_image(_url):
+             if _url.startswith("data:"):
+                 logging.info("Parse url by base64 decoder.")
+                 # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                 # e.g. f"data:image/jpeg;base64,{base64_image}"
+                 _type, data = _url.split(";")
+                 _, ext = _type.split("/")
+                 data = data[len("base64,") :]
+                 data = base64.b64decode(data.encode("utf-8"))
+                 return Image.open(BytesIO(data)).convert("RGB")
+             else:
+                 try:
+                     response = requests.get(_url)
+                 except requests.exceptions.MissingSchema:
+                     return Image.open(_url).convert("RGB")
+                 else:
+                     return Image.open(BytesIO(response.content)).convert("RGB")
+
+         if not isinstance(content, str):
+             texts = []
+             image_urls = []
+             for c in content:
+                 c_type = c.get("type")
+                 if c_type == "text":
+                     texts.append(c["text"])
+                 elif c_type == "image_url":
+                     image_urls.append(c["image_url"]["url"])
+             image_futures = []
+             with ThreadPoolExecutor() as executor:
+                 for image_url in image_urls:
+                     fut = executor.submit(_load_image, image_url)
+                     image_futures.append(fut)
+             images = [fut.result() for fut in image_futures]
+             text = " ".join(texts)
+             if len(images) == 0:
+                 return text, None
+             else:
+                 return text, images
+         return content, None
+
+     def _history_content_to_intern(
+         self,
+         chat_history: List[ChatCompletionMessage],
+         IMG_START_TOKEN="<img>",
+         IMG_END_TOKEN="</img>",
+         IMG_CONTEXT_TOKEN="<IMG_CONTEXT>",
+     ):
+         def _image_to_piexl_values(images):
+             load_images = []
+             for image in images:
+                 if image.startswith("data:"):
+                     logging.info("Parse url by base64 decoder.")
+                     # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                     # e.g. f"data:image/jpeg;base64,{base64_image}"
+                     _type, data = image.split(";")
+                     _, ext = _type.split("/")
+                     data = data[len("base64,") :]
+                     data = base64.b64decode(data.encode("utf-8"))
+                     img = Image.open(BytesIO(data)).convert("RGB")
+                     pixel_value = (
+                         self._load_image(img, max_num=6).to(torch.bfloat16).cuda()
+                     )
+                     load_images.append(pixel_value)
+                 else:
+                     try:
+                         response = requests.get(image)
+                     except requests.exceptions.MissingSchema:
+                         img = Image.open(image).convert("RGB")
+                     else:
+                         img = Image.open(BytesIO(response.content)).convert("RGB")
+                     pixel_value = (
+                         self._load_image(img, max_num=6).to(torch.bfloat16).cuda()
+                     )
+                     load_images.append(pixel_value)
+             return torch.cat(tuple(load_images), dim=0)
+
+         history: List[Tuple] = []
+         pixel_values = None
+         for i in range(0, len(chat_history), 2):
+             tmp = []
+             images: List[str] = []
+             user = chat_history[i]["content"]
+             if isinstance(user, List):
+                 for content in user:
+                     c_type = content.get("type")
+                     if c_type == "text":
+                         tmp.append(content["text"])
+                     elif c_type == "image_url" and not history:
+                         images.append(content["image_url"]["url"])
+                 if not history:
+                     pixel_values = _image_to_piexl_values(images)
+                     image_bs = pixel_values.shape[0]
+                     image_tokens = (
+                         IMG_START_TOKEN
+                         + IMG_CONTEXT_TOKEN * self._model.num_image_token * image_bs
+                         + IMG_END_TOKEN
+                     )
+                     tmp[0] = image_tokens + "\n" + tmp[0]
+             else:
+                 tmp.append(user)
+             tmp.append(chat_history[i + 1]["content"])
+             history.append(tuple(tmp))
+         return history, pixel_values
+
+     def _load_image(_url):
+         if _url.startswith("data:"):
+             logging.info("Parse url by base64 decoder.")
+             # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+             # e.g. f"data:image/jpeg;base64,{base64_image}"
+             _type, data = _url.split(";")
+             _, ext = _type.split("/")
+             data = data[len("base64,") :]
+             data = base64.b64decode(data.encode("utf-8"))
+
+             return Image.open(BytesIO(data)).convert("RGB")
+         else:
+             try:
+                 response = requests.get(_url)
+             except requests.exceptions.MissingSchema:
+                 return Image.open(_url).convert("RGB")
+             else:
+                 return Image.open(BytesIO(response.content)).convert("RGB")
+
+         if not isinstance(content, str):
+             texts = []
+             image_urls = []
+             for c in content:
+                 c_type = c.get("type")
+                 if c_type == "text":
+                     texts.append(c["text"])
+                 elif c_type == "image_url":
+                     image_urls.append(c["image_url"]["url"])
+             image_futures = []
+             with ThreadPoolExecutor() as executor:
+                 for image_url in image_urls:
+                     fut = executor.submit(_load_image, image_url)
+                     image_futures.append(fut)
+             images = [fut.result() for fut in image_futures]
+             text = " ".join(texts)
+             if len(images) == 0:
+                 return text
+             else:
+                 return text, images
+         return content
+
+     def _find_closest_aspect_ratio(
+         self, aspect_ratio, target_ratios, width, height, image_size
+     ):
+         best_ratio_diff = float("inf")
+         best_ratio = (1, 1)
+         area = width * height
+         for ratio in target_ratios:
+             target_aspect_ratio = ratio[0] / ratio[1]
+             ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+             if ratio_diff < best_ratio_diff:
+                 best_ratio_diff = ratio_diff
+                 best_ratio = ratio
+             elif ratio_diff == best_ratio_diff:
+                 if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                     best_ratio = ratio
+         return best_ratio
+
+     def _dynamic_preprocess(
+         self, image, min_num=1, max_num=6, image_size=448, use_thumbnail=False
+     ):
+         orig_width, orig_height = image.size
+         aspect_ratio = orig_width / orig_height
+
+         # calculate the existing image aspect ratio
+         target_ratios = set(
+             (i, j)
+             for n in range(min_num, max_num + 1)
+             for i in range(1, n + 1)
+             for j in range(1, n + 1)
+             if i * j <= max_num and i * j >= min_num
+         )
+         target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+         # find the closest aspect ratio to the target
+         target_aspect_ratio = self._find_closest_aspect_ratio(
+             aspect_ratio, target_ratios, orig_width, orig_height, image_size
+         )
+
+         # calculate the target width and height
+         target_width = image_size * target_aspect_ratio[0]
+         target_height = image_size * target_aspect_ratio[1]
+         blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+         # resize the image
+         resized_img = image.resize((target_width, target_height))
+         processed_images = []
+         for i in range(blocks):
+             box = (
+                 (i % (target_width // image_size)) * image_size,
+                 (i // (target_width // image_size)) * image_size,
+                 ((i % (target_width // image_size)) + 1) * image_size,
+                 ((i // (target_width // image_size)) + 1) * image_size,
+             )
+             # split the image
+             split_img = resized_img.crop(box)
+             processed_images.append(split_img)
+         assert len(processed_images) == blocks
+         if use_thumbnail and len(processed_images) != 1:
+             thumbnail_img = image.resize((image_size, image_size))
+             processed_images.append(thumbnail_img)
+         return processed_images
+
+     def _build_transform(self, input_size):
+         MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+         transform = T.Compose(
+             [
+                 T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+                 T.Resize(
+                     (input_size, input_size), interpolation=InterpolationMode.BICUBIC
+                 ),
+                 T.ToTensor(),
+                 T.Normalize(mean=MEAN, std=STD),
+             ]
+         )
+         return transform
+
+     def _load_image(self, image_file, input_size=448, max_num=6):
+         transform = self._build_transform(input_size=input_size)
+         images = self._dynamic_preprocess(
+             image_file, image_size=input_size, use_thumbnail=True, max_num=max_num
+         )
+         pixel_values = [transform(image) for image in images]
+         pixel_values = torch.stack(pixel_values)
+         return pixel_values
+
+     def chat(
+         self,
+         prompt: Union[str, List[Dict]],
+         system_prompt: Optional[str] = None,
+         chat_history: Optional[List[ChatCompletionMessage]] = None,
+         generate_config: Optional[PytorchGenerateConfig] = None,
+     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+         if generate_config and generate_config.pop("stream"):
+             raise Exception(
+                 f"Chat with model {self.model_family.model_name} does not support stream."
+             )
+         sanitized_config = {
+             "num_beams": 1,
+             "max_new_tokens": generate_config.get("max_tokens", 512)
+             if generate_config
+             else 512,
+             "do_sample": False,
+         }
+
+         content, image = self._message_content_to_intern(prompt)
+
+         history = None
+         if chat_history:
+             history, pixel_values = self._history_content_to_intern(chat_history)
+         else:
+             load_images = []
+             for img in image:
+                 pixel_value = self._load_image(img, max_num=6).to(torch.bfloat16).cuda()
+                 load_images.append(pixel_value)
+             pixel_values = torch.cat(tuple(load_images), dim=0)
+
+         response, history = self._model.chat(
+             self._tokenizer,
+             pixel_values,
+             content,
+             sanitized_config,
+             history=history,
+             return_history=True,
+         )
+         chunk = Completion(
+             id=str(uuid.uuid1()),
+             object="text_completion",
+             created=int(time.time()),
+             model=self.model_uid,
+             choices=[
+                 CompletionChoice(
+                     index=0, text=response, finish_reason="stop", logprobs=None
+                 )
+             ],
+             usage=CompletionUsage(
+                 prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+             ),
+         )
+         return self._to_chat_completion(chunk)
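
For orientation, the new model's `chat()` accepts OpenAI-style vision messages: `_message_content_to_intern` joins the `text` parts and loads every `image_url` entry (HTTP URL, local path, or base64 data URI, fetched concurrently via a thread pool). A minimal sketch of the content shape it expects; the image file name is an illustrative assumption, not something this diff specifies:

```python
import base64

# Hypothetical payload: the OpenAI-style content list that
# _message_content_to_intern() splits into (joined_text, [PIL.Image, ...]).
# Plain string content passes through unchanged and yields no images.
with open("example.jpg", "rb") as f:  # any local JPEG; path is illustrative
    b64_image = base64.b64encode(f.read()).decode()

prompt = [
    {"type": "text", "text": "What is in this picture?"},
    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"}},
]
```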
@@ -456,6 +456,19 @@ Begin!"""
456
456
  ret += f"<|{role}|>{prompt_style.intra_message_sep}"
457
457
  ret += "<|assistant|>\n"
458
458
  return ret
459
+ elif prompt_style.style_name == "c4ai-command-r":
460
+ ret = (
461
+ f"<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>"
462
+ f"{prompt_style.system_prompt}{prompt_style.inter_message_sep}"
463
+ )
464
+ for i, message in enumerate(chat_history):
465
+ role = get_role(message["role"])
466
+ content = message["content"]
467
+ if content:
468
+ ret += f"{role}{content}{prompt_style.inter_message_sep}"
469
+ else:
470
+ ret += role
471
+ return ret
459
472
  else:
460
473
  raise ValueError(f"Invalid prompt style: {prompt_style.style_name}")
461
474
 
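A standalone sketch of what this branch renders. The role tokens and separator below follow Cohere's published Command-R chat template; in xinference they come from the prompt-style configuration (llm_family.json), so treat these literals as assumptions rather than values pinned down by this diff:

```python
# Assumed Command-R conventions: separator and per-role turn tokens.
SEP = "<|END_OF_TURN_TOKEN|>"
ROLES = {
    "user": "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>",
    "assistant": "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
}
chat_history = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": ""},  # empty content emits the role only, opening the generation turn
]

ret = f"<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant.{SEP}"
for message in chat_history:
    role = ROLES[message["role"]]
    content = message["content"]
    ret += f"{role}{content}{SEP}" if content else role
print(ret)
```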
xinference/model/llm/vllm/core.py CHANGED
@@ -97,6 +97,8 @@ VLLM_SUPPORTED_MODELS = [
      "Yi-1.5",
      "code-llama",
      "code-llama-python",
+     "deepseek",
+     "deepseek-coder",
  ]
  VLLM_SUPPORTED_CHAT_MODELS = [
      "llama-2-chat",
@@ -125,6 +127,7 @@ VLLM_SUPPORTED_CHAT_MODELS = [
  ]
  if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
      VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat")
+     VLLM_SUPPORTED_MODELS.append("codeqwen1.5")
      VLLM_SUPPORTED_CHAT_MODELS.append("codeqwen1.5-chat")

  if VLLM_INSTALLED and vllm.__version__ >= "0.3.2":
@@ -136,8 +139,8 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.3":

  if VLLM_INSTALLED and vllm.__version__ >= "0.4.0":
      VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-moe-chat")
-     VLLM_SUPPORTED_MODELS.append("c4ai-command-r-v01")
-     VLLM_SUPPORTED_MODELS.append("c4ai-command-r-v01-4bit")
+     VLLM_SUPPORTED_CHAT_MODELS.append("c4ai-command-r-v01")
+     VLLM_SUPPORTED_CHAT_MODELS.append("c4ai-command-r-v01-4bit")


  class VLLMModel(LLM):
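
One caveat worth noting about these gates: `vllm.__version__ >= "0.4.0"` compares raw strings, which is lexicographic. It gives the right answer for the versions gated here, but misorders two-digit components. A sketch (not part of this diff) of the more robust comparison:

```python
import vllm
from packaging.version import Version  # packaging is pulled in by the transformers/vllm stack

# Lexicographic string comparison breaks once a component reaches two digits:
assert ("0.10.0" >= "0.3.0") is False         # wrong answer as raw strings
assert Version("0.10.0") >= Version("0.3.0")  # correct once parsed

if Version(vllm.__version__) >= Version("0.4.0"):
    print("command-r models can be served through vLLM")
```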
xinference/model/rerank/core.py CHANGED
@@ -46,7 +46,7 @@ def get_rerank_model_descriptions():
  class RerankModelSpec(CacheableModelSpec):
      model_name: str
      language: List[str]
-     type: Optional[str] = "normal"
+     type: Optional[str] = "unknown"
      model_id: str
      model_revision: Optional[str]
      model_hub: str = "huggingface"
@@ -118,6 +118,28 @@ class RerankModel:
          self._use_fp16 = use_fp16
          self._model = None
          self._counter = 0
+         if model_spec.type == "unknown":
+             model_spec.type = self._auto_detect_type(model_path)
+
+     @staticmethod
+     def _auto_detect_type(model_path):
+         """This method may not be stable due to the fact that the tokenizer name may be changed.
+         Therefore, we only use this method for unknown model types."""
+         from transformers import AutoTokenizer
+
+         type_mapper = {
+             "LlamaTokenizerFast": "LLM-based layerwise",
+             "GemmaTokenizerFast": "LLM-based",
+             "XLMRobertaTokenizerFast": "normal",
+         }
+
+         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+         rerank_type = type_mapper.get(type(tokenizer).__name__)
+         if rerank_type is None:
+             raise Exception(
+                 f"Can't determine the rerank type based on the tokenizer {tokenizer}"
+             )
+         return rerank_type

      def load(self):
          if self._model_spec.type == "normal":
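
The keys in `type_mapper` are the fast-tokenizer class names that `AutoTokenizer` resolves for a checkpoint, which is why the docstring hedges: a checkpoint switching tokenizer classes would change its detected type. A quick illustrative check of the heuristic; the model id is an example I would expect to hit the XLMRoberta branch (bge-reranker-v2-m3 is XLM-RoBERTa based), not something this diff pins down:

```python
from transformers import AutoTokenizer

# Resolve the checkpoint's fast-tokenizer class and map it to a rerank type.
tok = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3", trust_remote_code=True)
print(type(tok).__name__)  # expected: "XLMRobertaTokenizerFast" -> type "normal"
```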
xinference/model/utils.py CHANGED
@@ -19,6 +19,7 @@ from json import JSONDecodeError
  from pathlib import Path
  from typing import Any, Callable, Dict, Optional, Tuple, Union

+ import huggingface_hub
  from fsspec import AbstractFileSystem

  from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_ENV_MODEL_SRC
@@ -27,6 +28,7 @@ from .core import CacheableModelSpec

  logger = logging.getLogger(__name__)
  MAX_ATTEMPTS = 3
+ IS_NEW_HUGGINGFACE_HUB: bool = huggingface_hub.__version__ >= "0.23.0"


  def is_locale_chinese_simplified() -> bool:
@@ -76,6 +78,13 @@ def symlink_local_file(path: str, local_dir: str, relpath: str) -> str:
      return local_dir_filepath


+ def create_symlink(download_dir: str, cache_dir: str):
+     for subdir, dirs, files in os.walk(download_dir):
+         for file in files:
+             relpath = os.path.relpath(os.path.join(subdir, file), download_dir)
+             symlink_local_file(os.path.join(subdir, file), cache_dir, relpath)
+
+
  def retry_download(
      download_func: Callable,
      model_name: str,
@@ -306,22 +315,23 @@ def cache(model_spec: CacheableModelSpec, model_description_type: type):
              model_spec.model_id,
              revision=model_spec.model_revision,
          )
-         for subdir, dirs, files in os.walk(download_dir):
-             for file in files:
-                 relpath = os.path.relpath(os.path.join(subdir, file), download_dir)
-                 symlink_local_file(os.path.join(subdir, file), cache_dir, relpath)
+         create_symlink(download_dir, cache_dir)
      else:
          from huggingface_hub import snapshot_download as hf_download

-         retry_download(
+         use_symlinks = {}
+         if not IS_NEW_HUGGINGFACE_HUB:
+             use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
+         download_dir = retry_download(
              hf_download,
              model_spec.model_name,
              None,
              model_spec.model_id,
              revision=model_spec.model_revision,
-             local_dir=cache_dir,
-             local_dir_use_symlinks=True,
+             **use_symlinks,
          )
+         if IS_NEW_HUGGINGFACE_HUB:
+             create_symlink(download_dir, cache_dir)
      with open(meta_path, "w") as f:
          import json

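The version gate here works around huggingface_hub 0.23.0, which reworked `local_dir` downloads and deprecated `local_dir_use_symlinks`. On new hub versions the code therefore downloads into the default HF cache and mirrors the snapshot into Xinference's cache directory with the new `create_symlink` helper. (The `IS_NEW_HUGGINGFACE_HUB` check shares the raw-string version-comparison caveat noted for the vLLM gates above.) A sketch of the two paths, with an illustrative repo id and cache path:

```python
import huggingface_hub
from huggingface_hub import snapshot_download

cache_dir = "/tmp/xinference-cache/demo-model"  # illustrative path
if huggingface_hub.__version__ >= "0.23.0":
    # New hub: local_dir_use_symlinks is deprecated/ignored, so download into
    # the default HF cache; snapshot_download returns the snapshot directory,
    # which create_symlink() above then mirrors into cache_dir.
    download_dir = snapshot_download("BAAI/bge-reranker-v2-m3")
else:
    # Old hub: let the hub place symlinks directly under cache_dir.
    snapshot_download(
        "BAAI/bge-reranker-v2-m3",
        local_dir=cache_dir,
        local_dir_use_symlinks=True,
    )
```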
xinference/thirdparty/deepseek_vl/models/processing_vlm.py CHANGED
@@ -25,8 +25,8 @@ from PIL.Image import Image
  from transformers import LlamaTokenizerFast
  from transformers.processing_utils import ProcessorMixin

- from .image_processing_vlm import VLMImageProcessor
  from ..utils.conversation import get_conv_template
+ from .image_processing_vlm import VLMImageProcessor


  class DictOutput(object):
xinference/thirdparty/deepseek_vl/models/siglip_vit.py CHANGED
@@ -92,7 +92,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
  def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
      # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
      r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
-     convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its orignal dtype.
+     convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
      Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
      from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
      with values outside :math:`[a, b]` redrawn until they are within
@@ -305,7 +305,7 @@ class VisionTransformer(nn.Module):
      img_size: Input image size.
      patch_size: Patch size.
      in_chans: Number of image input channels.
-     num_classes: Mumber of classes for classification head.
+     num_classes: Number of classes for classification head.
      global_pool: Type of global pooling for final sequence (default: 'token').
      embed_dim: Transformer embedding dimension.
      depth: Depth of transformer.
xinference/thirdparty/llava/mm_utils.py CHANGED
@@ -2,11 +2,12 @@ import base64
  from io import BytesIO

  import torch
- from .model import LlavaLlamaForCausalLM
- from .model.constants import IMAGE_TOKEN_INDEX
  from PIL import Image
  from transformers import AutoTokenizer, StoppingCriteria

+ from .model import LlavaLlamaForCausalLM
+ from .model.constants import IMAGE_TOKEN_INDEX
+

  def load_image_from_base64(image):
      return Image.open(BytesIO(base64.b64decode(image)))
xinference/thirdparty/llava/model/llava_arch.py CHANGED
@@ -17,9 +17,9 @@ import os
  from abc import ABC, abstractmethod

  import torch
- from .constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, key_info

  from .clip_encoder.builder import build_vision_tower
+ from .constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, key_info
  from .multimodal_projector.builder import build_vision_projector


xinference/thirdparty/omnilmm/chat.py CHANGED
@@ -7,11 +7,6 @@ import torch
  from PIL import Image
  from transformers import AutoModel, AutoTokenizer

- from .model.omnilmm import OmniLMMForCausalLM
- from .model.utils import build_transform
- from .train.train_utils import omni_preprocess
- from .utils import disable_torch_init
-
  DEFAULT_IMAGE_TOKEN = "<image>"
  DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
  DEFAULT_IM_START_TOKEN = "<im_start>"
@@ -21,6 +16,10 @@ DEFAULT_IM_END_TOKEN = "<im_end>"
  def init_omni_lmm(model_path, device_map):
      from accelerate import init_empty_weights, load_checkpoint_and_dispatch

+     from .model.omnilmm import OmniLMMForCausalLM
+     from .model.utils import build_transform
+     from .utils import disable_torch_init
+
      torch.backends.cuda.matmul.allow_tf32 = True
      disable_torch_init()
      model_name = os.path.expanduser(model_path)
@@ -98,6 +97,8 @@ def expand_question_into_multimodal(


  def wrap_question_for_omni_lmm(question, image_token_len, tokenizer):
+     from .train.train_utils import omni_preprocess
+
      question = expand_question_into_multimodal(
          question,
          image_token_len,
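
The omnilmm edits apply a deferred-import pattern (the deepseek_vl and llava hunks above are plain import reordering): heavy relative imports such as `OmniLMMForCausalLM`, `build_transform`, `disable_torch_init`, and `omni_preprocess` move from module scope into the functions that first need them, so importing the thirdparty module no longer pulls in the whole model stack up front. A minimal generic sketch of the pattern, with placeholder names rather than the project's:

```python
# Module import stays cheap; the heavy dependency loads on first use.
def load_checkpoint(path: str):
    import torch  # deferred: paid on the first call, not at package import time

    return torch.load(path, map_location="cpu")
```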