xinference 0.11.0__py3-none-any.whl → 0.11.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (56)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +30 -0
  3. xinference/client/restful/restful_client.py +29 -0
  4. xinference/core/cache_tracker.py +12 -1
  5. xinference/core/chat_interface.py +10 -4
  6. xinference/core/model.py +2 -2
  7. xinference/core/supervisor.py +30 -2
  8. xinference/core/utils.py +12 -0
  9. xinference/core/worker.py +4 -1
  10. xinference/deploy/cmdline.py +126 -0
  11. xinference/deploy/test/test_cmdline.py +24 -0
  12. xinference/fields.py +3 -1
  13. xinference/model/llm/__init__.py +2 -0
  14. xinference/model/llm/ggml/chatglm.py +98 -13
  15. xinference/model/llm/ggml/llamacpp.py +49 -2
  16. xinference/model/llm/llm_family.json +633 -9
  17. xinference/model/llm/llm_family.py +84 -10
  18. xinference/model/llm/llm_family_modelscope.json +337 -10
  19. xinference/model/llm/memory.py +332 -0
  20. xinference/model/llm/pytorch/chatglm.py +48 -0
  21. xinference/model/llm/pytorch/core.py +25 -6
  22. xinference/model/llm/pytorch/deepseek_vl.py +35 -9
  23. xinference/model/llm/pytorch/intern_vl.py +387 -0
  24. xinference/model/llm/pytorch/internlm2.py +32 -1
  25. xinference/model/llm/pytorch/qwen_vl.py +38 -11
  26. xinference/model/llm/pytorch/utils.py +38 -1
  27. xinference/model/llm/pytorch/yi_vl.py +42 -14
  28. xinference/model/llm/sglang/core.py +31 -9
  29. xinference/model/llm/utils.py +38 -5
  30. xinference/model/llm/vllm/core.py +87 -5
  31. xinference/model/rerank/core.py +23 -1
  32. xinference/model/utils.py +17 -7
  33. xinference/thirdparty/deepseek_vl/models/processing_vlm.py +1 -1
  34. xinference/thirdparty/deepseek_vl/models/siglip_vit.py +2 -2
  35. xinference/thirdparty/llava/mm_utils.py +3 -2
  36. xinference/thirdparty/llava/model/llava_arch.py +1 -1
  37. xinference/thirdparty/omnilmm/chat.py +6 -5
  38. xinference/types.py +10 -1
  39. xinference/web/ui/build/asset-manifest.json +3 -3
  40. xinference/web/ui/build/index.html +1 -1
  41. xinference/web/ui/build/static/js/{main.8e44da4b.js → main.551aa479.js} +3 -3
  42. xinference/web/ui/build/static/js/main.551aa479.js.map +1 -0
  43. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +1 -0
  44. xinference/web/ui/node_modules/.cache/babel-loader/23caf6f1e52c43e983ca3bfd4189f41dbd645fa78f2dfdcd7f6b69bc41678665.json +1 -0
  45. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +1 -0
  46. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/METADATA +10 -8
  47. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/RECORD +52 -50
  48. xinference/web/ui/build/static/js/main.8e44da4b.js.map +0 -1
  49. xinference/web/ui/node_modules/.cache/babel-loader/1870cd6f7054d04e049e363c0a85526584fe25519378609d2838e28d7492bbf1.json +0 -1
  50. xinference/web/ui/node_modules/.cache/babel-loader/5393569d846332075b93b55656716a34f50e0a8c970be789502d7e6c49755fd7.json +0 -1
  51. xinference/web/ui/node_modules/.cache/babel-loader/ddaec68b88e5eff792df1e39a4b4b8b737bfc832293c015660c3c69334e3cf5c.json +0 -1
  52. /xinference/web/ui/build/static/js/{main.8e44da4b.js.LICENSE.txt → main.551aa479.js.LICENSE.txt} +0 -0
  53. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/LICENSE +0 -0
  54. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/WHEEL +0 -0
  55. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/entry_points.txt +0 -0
  56. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/yi_vl.py CHANGED
@@ -139,6 +139,12 @@ class YiVLChatModel(PytorchChatModel):
             generate_config = {}

         stream = generate_config.get("stream", False)
+        stream_options = generate_config.pop("stream_options", None)
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )

         from ....thirdparty.llava.conversation import conv_templates
         from ....thirdparty.llava.mm_utils import (
@@ -166,11 +172,11 @@ class YiVLChatModel(PytorchChatModel):
         )

         images = state.get_images(return_pil=True)
-        image = images[0]
-
-        image_tensor = self._image_processor.preprocess(image, return_tensors="pt")[
-            "pixel_values"
-        ][0]
+        if images:
+            image = images[0]
+            image_tensor = self._image_processor.preprocess(image, return_tensors="pt")[
+                "pixel_values"
+            ][0]

         stop_str = state.sep
         keywords = [stop_str]
@@ -187,7 +193,9 @@ class YiVLChatModel(PytorchChatModel):
             "input_ids": input_ids,
             "images": image_tensor.unsqueeze(0)
             .to(dtype=torch.bfloat16)
-            .to(self._model.device),
+            .to(self._model.device)
+            if images
+            else None,
             "streamer": streamer,
             "do_sample": True,
             "top_p": float(top_p),
@@ -200,7 +208,7 @@ class YiVLChatModel(PytorchChatModel):
         t.start()

         if stream:
-            it = self._generate_stream(streamer, stop_str)
+            it = self._generate_stream(streamer, stop_str, input_ids, include_usage)
             return self._to_chat_completion_chunks(it)
         else:
             c = self._generate(streamer, stop_str)
@@ -229,8 +237,12 @@ class YiVLChatModel(PytorchChatModel):
         )
         return c

-    def _generate_stream(self, streamer, stop_str) -> Iterator[CompletionChunk]:
+    def _generate_stream(
+        self, streamer, stop_str, input_ids, include_usage
+    ) -> Iterator[CompletionChunk]:
         completion_id = str(uuid.uuid1())
+        prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+        prompt_tokens = len(input_ids[0])
         for i, new_text in enumerate(streamer):
             if not new_text.endswith(stop_str):
                 completion_choice = CompletionChoice(
@@ -243,10 +255,12 @@ class YiVLChatModel(PytorchChatModel):
                     model=self.model_uid,
                     choices=[completion_choice],
                 )
+                completion_tokens = i
+                total_tokens = prompt_tokens + completion_tokens
                 completion_usage = CompletionUsage(
-                    prompt_tokens=-1,
-                    completion_tokens=-1,
-                    total_tokens=-1,
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=total_tokens,
                 )
                 chunk["usage"] = completion_usage
                 yield chunk
@@ -262,9 +276,23 @@ class YiVLChatModel(PytorchChatModel):
             choices=[completion_choice],
         )
         completion_usage = CompletionUsage(
-            prompt_tokens=-1,
-            completion_tokens=-1,
-            total_tokens=-1,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
        )
         chunk["usage"] = completion_usage
         yield chunk
+        if include_usage:
+            chunk = CompletionChunk(
+                id=completion_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[],
+            )
+            chunk["usage"] = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            yield chunk
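
Note: the stream_options handling added above (and mirrored in the sglang and vLLM backends below) follows OpenAI's include_usage convention, where the final streamed chunk carries usage statistics and an empty choices list. A minimal client-side sketch, assuming xinference's OpenAI-compatible endpoint, the openai Python package (>=1.26), and placeholder model name and address:

# Illustrative only; "my-chat-model" and the endpoint URL are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9997/v1", api_key="not-used")

stream = client.chat.completions.create(
    model="my-chat-model",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    # Requests the trailing usage-only chunk produced by the code above.
    stream_options={"include_usage": True},
)
for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage is not None:
        print("\n", chunk.usage)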
xinference/model/llm/sglang/core.py CHANGED
@@ -53,6 +53,7 @@ class SGLANGGenerateConfig(TypedDict, total=False):
     stop: Optional[Union[str, List[str]]]
     ignore_eos: bool
     stream: bool
+    stream_options: Optional[Union[dict, None]]


 try:
@@ -157,6 +158,8 @@ class SGLANGModel(LLM):
         )
         generate_config.setdefault("stop", [])
         generate_config.setdefault("stream", False)
+        stream_options = generate_config.get("stream_options")
+        generate_config.setdefault("stream_options", stream_options)
         generate_config.setdefault("ignore_eos", False)

         return generate_config
@@ -192,7 +195,7 @@ class SGLANGModel(LLM):

     @staticmethod
     def _convert_state_to_completion_chunk(
-        request_id: str, model: str, output_text: str, meta_info: Dict
+        request_id: str, model: str, output_text: str
     ) -> CompletionChunk:
         choices: List[CompletionChoice] = [
             CompletionChoice(
@@ -209,13 +212,6 @@ class SGLANGModel(LLM):
             model=model,
             choices=choices,
         )
-        prompt_tokens = meta_info["prompt_tokens"]
-        completion_tokens = meta_info["completion_tokens"]
-        chunk["usage"] = CompletionUsage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        )
         return chunk

     @staticmethod
@@ -272,6 +268,9 @@ class SGLANGModel(LLM):
             "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
         )
         stream = sanitized_generate_config.pop("stream")
+        stream_options = sanitized_generate_config.pop("stream_options")
+        if isinstance(stream_options, dict):
+            include_usage = stream_options.pop("include_usage", False)
         request_id = str(uuid.uuid1())
         state = pipeline.run(
             question=prompt,
@@ -289,11 +288,34 @@ class SGLANGModel(LLM):
         else:

             async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
+                prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
                 async for out, meta_info in state.text_async_iter(
                     var_name="answer", return_meta_data=True
                 ):
                     chunk = self._convert_state_to_completion_chunk(
-                        request_id, self.model_uid, output_text=out, meta_info=meta_info
+                        request_id, self.model_uid, output_text=out
+                    )
+                    prompt_tokens = meta_info["prompt_tokens"]
+                    completion_tokens = meta_info["completion_tokens"]
+                    total_tokens = prompt_tokens + completion_tokens
+                    chunk["usage"] = CompletionUsage(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=total_tokens,
+                    )
+                    yield chunk
+                if include_usage:
+                    chunk = CompletionChunk(
+                        id=request_id,
+                        object="text_completion",
+                        created=int(time.time()),
+                        model=self.model_uid,
+                        choices=[],
+                    )
+                    chunk["usage"] = CompletionUsage(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=total_tokens,
                     )
                     yield chunk

xinference/model/llm/utils.py CHANGED
@@ -456,6 +456,19 @@ Begin!"""
                 ret += f"<|{role}|>{prompt_style.intra_message_sep}"
             ret += "<|assistant|>\n"
             return ret
+        elif prompt_style.style_name == "c4ai-command-r":
+            ret = (
+                f"<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>"
+                f"{prompt_style.system_prompt}{prompt_style.inter_message_sep}"
+            )
+            for i, message in enumerate(chat_history):
+                role = get_role(message["role"])
+                content = message["content"]
+                if content:
+                    ret += f"{role}{content}{prompt_style.inter_message_sep}"
+                else:
+                    ret += role
+            return ret
         else:
             raise ValueError(f"Invalid prompt style: {prompt_style.style_name}")

@@ -482,9 +495,6 @@ Begin!"""
                 for i, choice in enumerate(chunk["choices"])
             ],
         }
-        usage = chunk.get("usage")
-        if usage is not None:
-            chat_chunk["usage"] = usage
         return cast(ChatCompletionChunk, chat_chunk)

     @classmethod
@@ -508,6 +518,19 @@ Begin!"""
                 for i, choice in enumerate(chunk["choices"])
             ],
         }
+        return cast(ChatCompletionChunk, chat_chunk)
+
+    @classmethod
+    def _get_final_chat_completion_chunk(
+        cls, chunk: CompletionChunk
+    ) -> ChatCompletionChunk:
+        chat_chunk = {
+            "id": "chat" + chunk["id"],
+            "model": chunk["model"],
+            "created": chunk["created"],
+            "object": "chat.completion.chunk",
+            "choices": [],
+        }
         usage = chunk.get("usage")
         if usage is not None:
             chat_chunk["usage"] = usage
@@ -521,7 +544,12 @@ Begin!"""
         for i, chunk in enumerate(chunks):
             if i == 0:
                 yield cls._get_first_chat_completion_chunk(chunk)
-            yield cls._to_chat_completion_chunk(chunk)
+            # usage
+            choices = chunk.get("choices")
+            if not choices:
+                yield cls._get_final_chat_completion_chunk(chunk)
+            else:
+                yield cls._to_chat_completion_chunk(chunk)

     @classmethod
     async def _async_to_chat_completion_chunks(
@@ -532,7 +560,12 @@ Begin!"""
         async for chunk in chunks:
             if i == 0:
                 yield cls._get_first_chat_completion_chunk(chunk)
-            yield cls._to_chat_completion_chunk(chunk)
+            # usage
+            choices = chunk.get("choices")
+            if not choices:
+                yield cls._get_final_chat_completion_chunk(chunk)
+            else:
+                yield cls._to_chat_completion_chunk(chunk)
             i += 1

     @staticmethod
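
For orientation, a standalone, hypothetical rendering of the new c4ai-command-r branch; the role tokens and separator below are placeholders standing in for whatever get_role and prompt_style.inter_message_sep resolve to in the model's prompt style configuration, which this diff does not show:

# Hypothetical values; the real ones come from the model family's prompt style.
system_prompt = "You are a helpful assistant."
inter_message_sep = "<|END_OF_TURN_TOKEN|>"
chat_history = [
    {"role": "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "content": "Hi there"},
    {"role": "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", "content": ""},
]

ret = (
    f"<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>"
    f"{system_prompt}{inter_message_sep}"
)
for message in chat_history:
    role = message["role"]  # stands in for get_role(message["role"])
    content = message["content"]
    if content:
        ret += f"{role}{content}{inter_message_sep}"
    else:
        ret += role  # an empty trailing message leaves the turn open for generation
print(ret)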
xinference/model/llm/vllm/core.py CHANGED
@@ -37,6 +37,7 @@ from ....types import (
     CompletionChoice,
     CompletionChunk,
     CompletionUsage,
+    LoRA,
     ToolCallFunction,
     ToolCalls,
 )
@@ -64,16 +65,19 @@ class VLLMModelConfig(TypedDict, total=False):


 class VLLMGenerateConfig(TypedDict, total=False):
+    lora_name: Optional[str]
     n: int
     best_of: Optional[int]
     presence_penalty: float
     frequency_penalty: float
     temperature: float
     top_p: float
+    top_k: int
     max_tokens: int
     stop_token_ids: Optional[List[int]]
     stop: Optional[Union[str, List[str]]]
     stream: bool  # non-sampling param, should not be passed to the engine.
+    stream_options: Optional[Union[dict, None]]


 try:
@@ -90,8 +94,11 @@ VLLM_SUPPORTED_MODELS = [
     "internlm-16k",
     "mistral-v0.1",
     "Yi",
+    "Yi-1.5",
     "code-llama",
     "code-llama-python",
+    "deepseek",
+    "deepseek-coder",
 ]
 VLLM_SUPPORTED_CHAT_MODELS = [
     "llama-2-chat",
@@ -106,6 +113,7 @@ VLLM_SUPPORTED_CHAT_MODELS = [
     "internlm2-chat",
     "qwen-chat",
     "Yi-chat",
+    "Yi-1.5-chat",
     "code-llama-instruct",
     "mistral-instruct-v0.1",
     "mistral-instruct-v0.2",
@@ -119,6 +127,7 @@ VLLM_SUPPORTED_CHAT_MODELS = [
 ]
 if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat")
+    VLLM_SUPPORTED_MODELS.append("codeqwen1.5")
     VLLM_SUPPORTED_CHAT_MODELS.append("codeqwen1.5-chat")

 if VLLM_INSTALLED and vllm.__version__ >= "0.3.2":
@@ -130,8 +139,8 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.3":

 if VLLM_INSTALLED and vllm.__version__ >= "0.4.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-moe-chat")
-    VLLM_SUPPORTED_MODELS.append("c4ai-command-r-v01")
-    VLLM_SUPPORTED_MODELS.append("c4ai-command-r-v01-4bit")
+    VLLM_SUPPORTED_CHAT_MODELS.append("c4ai-command-r-v01")
+    VLLM_SUPPORTED_CHAT_MODELS.append("c4ai-command-r-v01-4bit")


 class VLLMModel(LLM):
@@ -143,16 +152,30 @@ class VLLMModel(LLM):
         quantization: str,
         model_path: str,
         model_config: Optional[VLLMModelConfig],
+        peft_model: Optional[List[LoRA]] = None,
     ):
+        try:
+            from vllm.lora.request import LoRARequest
+        except ImportError:
+            error_message = "Failed to import module 'vllm'"
+            installation_guide = [
+                "Please make sure 'vllm' is installed. ",
+                "You can install it by `pip install vllm`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
         super().__init__(model_uid, model_family, model_spec, quantization, model_path)
         self._model_config = model_config
         self._engine = None
+        self.lora_modules = peft_model
+        self.lora_requests: List[LoRARequest] = []

     def load(self):
         try:
             import vllm
             from vllm.engine.arg_utils import AsyncEngineArgs
             from vllm.engine.async_llm_engine import AsyncLLMEngine
+            from vllm.lora.request import LoRARequest
         except ImportError:
             error_message = "Failed to import module 'vllm'"
             installation_guide = [
@@ -171,11 +194,33 @@ class VLLMModel(LLM):
             multiprocessing.set_start_method("fork", force=True)

         self._model_config = self._sanitize_model_config(self._model_config)
+
+        if self.lora_modules is None:
+            self.lora_requests = []
+        else:
+            self.lora_requests = [
+                LoRARequest(
+                    lora_name=lora.lora_name,
+                    lora_int_id=i,
+                    lora_local_path=lora.local_path,
+                )
+                for i, lora in enumerate(self.lora_modules, start=1)
+            ]
+
+        enable_lora = len(self.lora_requests) > 0
+        max_loras = len(self.lora_requests)
+
         logger.info(
             f"Loading {self.model_uid} with following model config: {self._model_config}"
+            f"Enable lora: {enable_lora}. Lora count: {max_loras}."
         )

-        engine_args = AsyncEngineArgs(model=self.model_path, **self._model_config)
+        engine_args = AsyncEngineArgs(
+            model=self.model_path,
+            enable_lora=enable_lora,
+            max_loras=max_loras,
+            **self._model_config,
+        )
         self._engine = AsyncLLMEngine.from_engine_args(engine_args)

     def _sanitize_model_config(
@@ -206,6 +251,7 @@ class VLLMModel(LLM):
             generate_config = {}

         sanitized = VLLMGenerateConfig()
+        sanitized.setdefault("lora_name", generate_config.get("lora_name", None))
         sanitized.setdefault("n", generate_config.get("n", 1))
         sanitized.setdefault("best_of", generate_config.get("best_of", None))
         sanitized.setdefault(
@@ -216,12 +262,16 @@ class VLLMModel(LLM):
         )
         sanitized.setdefault("temperature", generate_config.get("temperature", 1.0))
         sanitized.setdefault("top_p", generate_config.get("top_p", 1.0))
+        sanitized.setdefault("top_k", generate_config.get("top_k", -1))
         sanitized.setdefault("max_tokens", generate_config.get("max_tokens", 1024))
         sanitized.setdefault("stop", generate_config.get("stop", None))
         sanitized.setdefault(
             "stop_token_ids", generate_config.get("stop_token_ids", None)
         )
-        sanitized.setdefault("stream", generate_config.get("stream", None))
+        sanitized.setdefault("stream", generate_config.get("stream", False))
+        sanitized.setdefault(
+            "stream_options", generate_config.get("stream_options", None)
+        )

         return sanitized

@@ -338,16 +388,34 @@ class VLLMModel(LLM):
             "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
         )

+        lora_model = sanitized_generate_config.pop("lora_name")
+
+        lora_request = None
+        if lora_model is not None:
+            for lora in self.lora_requests:
+                if lora_model == lora.lora_name:
+                    lora_request = lora
+                    break
+
         stream = sanitized_generate_config.pop("stream")
+        stream_options = sanitized_generate_config.pop("stream_options", None)
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
         sampling_params = SamplingParams(**sanitized_generate_config)
         request_id = str(uuid.uuid1())

         assert self._engine is not None
-        results_generator = self._engine.generate(prompt, sampling_params, request_id)
+        results_generator = self._engine.generate(
+            prompt, sampling_params, request_id, lora_request=lora_request
+        )

         async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
             previous_texts = [""] * sanitized_generate_config["n"]
             tools_token_filter = ChatModelMixin._tools_token_filter(self.model_family)
+            prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
             async for _request_output in results_generator:
                 chunk = self._convert_request_output_to_completion_chunk(
                     request_id=request_id,
@@ -398,6 +466,20 @@ class VLLMModel(LLM):
                     total_tokens=total_tokens,
                 )
                 yield chunk
+            if include_usage:
+                chunk = CompletionChunk(
+                    id=request_id,
+                    object="text_completion",
+                    created=int(time.time()),
+                    model=self.model_uid,
+                    choices=[],
+                )
+                chunk["usage"] = CompletionUsage(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=total_tokens,
+                )
+                yield chunk

         if stream:
             return stream_results()
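
The LoRA plumbing above implies that, once a vLLM-backed model is launched with one or more adapters, a request can select an adapter by name through generate_config. A hedged sketch using xinference's Python client; the server address, model uid, and adapter name are placeholders, and it assumes the model was launched with a LoRA adapter registered as "my-adapter":

# Sketch only; "my-model-uid" and "my-adapter" are placeholders.
from xinference.client import Client

client = Client("http://localhost:9997")
model = client.get_model("my-model-uid")
completion = model.generate(
    "Write a haiku about the sea.",
    generate_config={
        "lora_name": "my-adapter",  # matched against LoRARequest.lora_name in load()
        "top_k": 40,                # top_k is newly accepted by VLLMGenerateConfig above
        "max_tokens": 64,
    },
)
print(completion["choices"][0]["text"])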
xinference/model/rerank/core.py CHANGED
@@ -46,7 +46,7 @@ def get_rerank_model_descriptions():
 class RerankModelSpec(CacheableModelSpec):
     model_name: str
     language: List[str]
-    type: Optional[str] = "normal"
+    type: Optional[str] = "unknown"
     model_id: str
     model_revision: Optional[str]
     model_hub: str = "huggingface"
@@ -118,6 +118,28 @@ class RerankModel:
         self._use_fp16 = use_fp16
         self._model = None
         self._counter = 0
+        if model_spec.type == "unknown":
+            model_spec.type = self._auto_detect_type(model_path)
+
+    @staticmethod
+    def _auto_detect_type(model_path):
+        """This method may not be stable due to the fact that the tokenizer name may be changed.
+        Therefore, we only use this method for unknown model types."""
+        from transformers import AutoTokenizer
+
+        type_mapper = {
+            "LlamaTokenizerFast": "LLM-based layerwise",
+            "GemmaTokenizerFast": "LLM-based",
+            "XLMRobertaTokenizerFast": "normal",
+        }
+
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        rerank_type = type_mapper.get(type(tokenizer).__name__)
+        if rerank_type is None:
+            raise Exception(
+                f"Can't determine the rerank type based on the tokenizer {tokenizer}"
+            )
+        return rerank_type

     def load(self):
         if self._model_spec.type == "normal":
xinference/model/utils.py CHANGED
@@ -19,6 +19,7 @@ from json import JSONDecodeError
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple, Union

+import huggingface_hub
 from fsspec import AbstractFileSystem

 from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_ENV_MODEL_SRC
@@ -27,6 +28,7 @@ from .core import CacheableModelSpec

 logger = logging.getLogger(__name__)
 MAX_ATTEMPTS = 3
+IS_NEW_HUGGINGFACE_HUB: bool = huggingface_hub.__version__ >= "0.23.0"


 def is_locale_chinese_simplified() -> bool:
@@ -76,6 +78,13 @@ def symlink_local_file(path: str, local_dir: str, relpath: str) -> str:
     return local_dir_filepath


+def create_symlink(download_dir: str, cache_dir: str):
+    for subdir, dirs, files in os.walk(download_dir):
+        for file in files:
+            relpath = os.path.relpath(os.path.join(subdir, file), download_dir)
+            symlink_local_file(os.path.join(subdir, file), cache_dir, relpath)
+
+
 def retry_download(
     download_func: Callable,
     model_name: str,
@@ -306,22 +315,23 @@ def cache(model_spec: CacheableModelSpec, model_description_type: type):
             model_spec.model_id,
             revision=model_spec.model_revision,
         )
-        for subdir, dirs, files in os.walk(download_dir):
-            for file in files:
-                relpath = os.path.relpath(os.path.join(subdir, file), download_dir)
-                symlink_local_file(os.path.join(subdir, file), cache_dir, relpath)
+        create_symlink(download_dir, cache_dir)
     else:
         from huggingface_hub import snapshot_download as hf_download

-        retry_download(
+        use_symlinks = {}
+        if not IS_NEW_HUGGINGFACE_HUB:
+            use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
+        download_dir = retry_download(
             hf_download,
             model_spec.model_name,
             None,
             model_spec.model_id,
             revision=model_spec.model_revision,
-            local_dir=cache_dir,
-            local_dir_use_symlinks=True,
+            **use_symlinks,
         )
+        if IS_NEW_HUGGINGFACE_HUB:
+            create_symlink(download_dir, cache_dir)
     with open(meta_path, "w") as f:
         import json

xinference/thirdparty/deepseek_vl/models/processing_vlm.py CHANGED
@@ -25,8 +25,8 @@ from PIL.Image import Image
 from transformers import LlamaTokenizerFast
 from transformers.processing_utils import ProcessorMixin

-from .image_processing_vlm import VLMImageProcessor
 from ..utils.conversation import get_conv_template
+from .image_processing_vlm import VLMImageProcessor


 class DictOutput(object):
xinference/thirdparty/deepseek_vl/models/siglip_vit.py CHANGED
@@ -92,7 +92,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
 def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
     # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
     r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
-    convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its orignal dtype.
+    convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
     Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
     from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
     with values outside :math:`[a, b]` redrawn until they are within
@@ -305,7 +305,7 @@ class VisionTransformer(nn.Module):
         img_size: Input image size.
         patch_size: Patch size.
         in_chans: Number of image input channels.
-        num_classes: Mumber of classes for classification head.
+        num_classes: Number of classes for classification head.
         global_pool: Type of global pooling for final sequence (default: 'token').
         embed_dim: Transformer embedding dimension.
         depth: Depth of transformer.
xinference/thirdparty/llava/mm_utils.py CHANGED
@@ -2,11 +2,12 @@ import base64
 from io import BytesIO

 import torch
-from .model import LlavaLlamaForCausalLM
-from .model.constants import IMAGE_TOKEN_INDEX
 from PIL import Image
 from transformers import AutoTokenizer, StoppingCriteria

+from .model import LlavaLlamaForCausalLM
+from .model.constants import IMAGE_TOKEN_INDEX
+

 def load_image_from_base64(image):
     return Image.open(BytesIO(base64.b64decode(image)))
xinference/thirdparty/llava/model/llava_arch.py CHANGED
@@ -17,9 +17,9 @@ import os
 from abc import ABC, abstractmethod

 import torch
-from .constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, key_info

 from .clip_encoder.builder import build_vision_tower
+from .constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, key_info
 from .multimodal_projector.builder import build_vision_projector


xinference/thirdparty/omnilmm/chat.py CHANGED
@@ -7,11 +7,6 @@ import torch
 from PIL import Image
 from transformers import AutoModel, AutoTokenizer

-from .model.omnilmm import OmniLMMForCausalLM
-from .model.utils import build_transform
-from .train.train_utils import omni_preprocess
-from .utils import disable_torch_init
-
 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
 DEFAULT_IM_START_TOKEN = "<im_start>"
@@ -21,6 +16,10 @@ DEFAULT_IM_END_TOKEN = "<im_end>"
 def init_omni_lmm(model_path, device_map):
     from accelerate import init_empty_weights, load_checkpoint_and_dispatch

+    from .model.omnilmm import OmniLMMForCausalLM
+    from .model.utils import build_transform
+    from .utils import disable_torch_init
+
     torch.backends.cuda.matmul.allow_tf32 = True
     disable_torch_init()
     model_name = os.path.expanduser(model_path)
@@ -98,6 +97,8 @@ def expand_question_into_multimodal(


 def wrap_question_for_omni_lmm(question, image_token_len, tokenizer):
+    from .train.train_utils import omni_preprocess
+
     question = expand_question_into_multimodal(
         question,
         image_token_len,