xinference 0.12.3__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (71)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +6 -6
  3. xinference/client/restful/restful_client.py +0 -2
  4. xinference/core/model.py +21 -4
  5. xinference/core/scheduler.py +2 -0
  6. xinference/core/worker.py +74 -45
  7. xinference/deploy/utils.py +33 -2
  8. xinference/model/llm/__init__.py +5 -0
  9. xinference/model/llm/llm_family.json +240 -1
  10. xinference/model/llm/llm_family.py +32 -8
  11. xinference/model/llm/llm_family_modelscope.json +192 -0
  12. xinference/model/llm/mlx/__init__.py +13 -0
  13. xinference/model/llm/mlx/core.py +408 -0
  14. xinference/model/llm/pytorch/chatglm.py +2 -9
  15. xinference/model/llm/pytorch/cogvlm2.py +206 -21
  16. xinference/model/llm/pytorch/core.py +213 -40
  17. xinference/model/llm/pytorch/glm4v.py +171 -15
  18. xinference/model/llm/pytorch/qwen_vl.py +168 -7
  19. xinference/model/llm/pytorch/utils.py +53 -62
  20. xinference/model/llm/utils.py +24 -5
  21. xinference/model/rerank/core.py +5 -0
  22. xinference/thirdparty/deepseek_vl/serve/__init__.py +13 -0
  23. xinference/thirdparty/deepseek_vl/serve/app_deepseek.py +510 -0
  24. xinference/thirdparty/deepseek_vl/serve/app_modules/__init__.py +13 -0
  25. xinference/thirdparty/deepseek_vl/serve/app_modules/gradio_utils.py +94 -0
  26. xinference/thirdparty/deepseek_vl/serve/app_modules/overwrites.py +81 -0
  27. xinference/thirdparty/deepseek_vl/serve/app_modules/presets.py +96 -0
  28. xinference/thirdparty/deepseek_vl/serve/app_modules/utils.py +229 -0
  29. xinference/thirdparty/deepseek_vl/serve/inference.py +170 -0
  30. xinference/web/ui/build/asset-manifest.json +3 -3
  31. xinference/web/ui/build/index.html +1 -1
  32. xinference/web/ui/build/static/js/main.0fb6f3ab.js +3 -0
  33. xinference/web/ui/build/static/js/main.0fb6f3ab.js.map +1 -0
  34. xinference/web/ui/node_modules/.cache/babel-loader/0f6b391abec76271137faad13a3793fe7acc1024e8cd2269c147b653ecd3a73b.json +1 -0
  35. xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +1 -0
  36. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +1 -0
  37. xinference/web/ui/node_modules/.cache/babel-loader/2c63090c842376cdd368c3ded88a333ef40d94785747651343040a6f7872a223.json +1 -0
  38. xinference/web/ui/node_modules/.cache/babel-loader/30a0c79d8025d6441eb75b2df5bc2750a14f30119c869ef02570d294dff65c2f.json +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/40486e655c3c5801f087e2cf206c0b5511aaa0dfdba78046b7181bf9c17e54c5.json +1 -0
  40. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +1 -0
  41. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +1 -0
  42. xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +1 -0
  43. xinference/web/ui/node_modules/.cache/babel-loader/b5507cd57f16a3a230aa0128e39fe103e928de139ea29e2679e4c64dcbba3b3a.json +1 -0
  44. xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +1 -0
  45. xinference/web/ui/node_modules/.cache/babel-loader/d779b915f83f9c7b5a72515b6932fdd114f1822cef90ae01cc0d12bca59abc2d.json +1 -0
  46. xinference/web/ui/node_modules/.cache/babel-loader/d87824cb266194447a9c0c69ebab2d507bfc3e3148976173760d18c035e9dd26.json +1 -0
  47. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +1 -0
  49. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/METADATA +4 -1
  50. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/RECORD +55 -44
  51. xinference/web/ui/build/static/js/main.77dd47c3.js +0 -3
  52. xinference/web/ui/build/static/js/main.77dd47c3.js.map +0 -1
  53. xinference/web/ui/node_modules/.cache/babel-loader/0cd591866aa345566e0b63fb51ff2043e163a770af6fdc2f3bad395d046353e2.json +0 -1
  54. xinference/web/ui/node_modules/.cache/babel-loader/37c1476717199863bbba1530e3513a9368f8f73001b75b4a85c2075956308027.json +0 -1
  55. xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +0 -1
  56. xinference/web/ui/node_modules/.cache/babel-loader/3fa1f69162f9c6dc0f6a6e21b64d49d6b8e6fa8dfa59a82cf829931c5f97d99f.json +0 -1
  57. xinference/web/ui/node_modules/.cache/babel-loader/46edc1fe657dfedb2e673148332bb442c6eb98f09f2592c389209e376510afa5.json +0 -1
  58. xinference/web/ui/node_modules/.cache/babel-loader/62e257ed9016471035fa1a7da57c9e2a4250974ed566b4d1295873d747c68eb2.json +0 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/82db357f3fd5b32215d747ee593f69ff06c95ad6cde37f71a96c8290aaab64c0.json +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/bc6da27195ec4607bb472bf61f97c928ad4966fa64e4c2247661bedb7400abba.json +0 -1
  63. xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +0 -1
  64. xinference/web/ui/node_modules/.cache/babel-loader/e606671420d2937102c3c34b4b04056c11736408c1d3347b8cf42dfe61fb394b.json +0 -1
  65. xinference/web/ui/node_modules/.cache/babel-loader/f118f99c22b713c678c1209c4e1dd43fe86e3f6e801a4c0c35d3bbf41fd05fe6.json +0 -1
  66. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +0 -1
  67. /xinference/web/ui/build/static/js/{main.77dd47c3.js.LICENSE.txt → main.0fb6f3ab.js.LICENSE.txt} +0 -0
  68. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/LICENSE +0 -0
  69. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/WHEEL +0 -0
  70. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/entry_points.txt +0 -0
  71. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/top_level.txt +0 -0
@@ -16,9 +16,14 @@ import logging
 import operator
 import tempfile
 import time
+import typing
 import uuid
-from typing import Dict, Iterator, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Tuple, Union

+import torch
+from transformers import PreTrainedTokenizer
+
+from ....core.scheduler import InferenceRequest
 from ....model.utils import select_device
 from ....types import (
     ChatCompletion,
@@ -31,6 +36,7 @@ from ....types import (
 )
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import pad_prefill_tokens

 logger = logging.getLogger(__name__)

@@ -40,6 +46,7 @@ class QwenVLChatModel(PytorchChatModel):
         super().__init__(*args, **kwargs)
         self._tokenizer = None
         self._model = None
+        self._device = None

     @classmethod
     def match(
@@ -62,6 +69,7 @@ class QwenVLChatModel(PytorchChatModel):

         device = self._pytorch_model_config.get("device", "auto")
         device = select_device(device)
+        self._device = device
         # for multiple GPU, set back to auto to make multiple devices work
         device = "auto" if device == "cuda" else device

@@ -120,13 +128,11 @@ class QwenVLChatModel(PytorchChatModel):
             return self._tokenizer.from_list_format(content)
         return content

-    def chat(
+    def _get_prompt_and_chat_history(
         self,
         prompt: Union[str, List[Dict]],
-        system_prompt: Optional[str] = None,
         chat_history: Optional[List[ChatCompletionMessage]] = None,
-        generate_config: Optional[PytorchGenerateConfig] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+    ):
         prompt = self._message_content_to_qwen(prompt)
         # Convert openai history to qwen vl history
         qwen_history = []
@@ -141,6 +147,18 @@ class QwenVLChatModel(PytorchChatModel):
             if len(query_to_response) == 2:
                 qwen_history.append(query_to_response)
                 query_to_response = []
+        return prompt, qwen_history
+
+    def chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        prompt, qwen_history = self._get_prompt_and_chat_history(
+            prompt, chat_history=chat_history
+        )

         stream = generate_config.get("stream", False) if generate_config else False
         stream_options = (
@@ -152,10 +170,10 @@ class QwenVLChatModel(PytorchChatModel):
             else False
         )
         if stream:
-            it = self._generate_stream(prompt, qwen_history, include_usage)
+            it = self._generate_stream(prompt, qwen_history, include_usage)  # type: ignore
             return self._to_chat_completion_chunks(it)
         else:
-            c = self._generate(prompt, qwen_history)
+            c = self._generate(prompt, qwen_history)  # type: ignore
             return self._to_chat_completion(c)

     def _generate(self, prompt: str, qwen_history: List) -> Completion:
@@ -244,3 +262,146 @@ class QwenVLChatModel(PytorchChatModel):
                 total_tokens=total_tokens,
             )
             yield chunk
+
+    @staticmethod
+    def get_batch_size_and_seq_len_indexes_from_kv() -> Tuple[int, int]:
+        """
+        Qwen-vl is very special for its kv_cache impl.
+        Its dimension is `bs * seq_len * head_num * dim`.
+        See https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/modeling_qwen.py
+        """
+        return 0, 1
+
+    @staticmethod
+    @typing.no_type_check
+    def make_context(
+        tokenizer: PreTrainedTokenizer,
+        query: str,
+        history: List[Tuple[str, str]] = None,
+        system: str = "",
+        max_window_size: int = 6144,
+        chat_format: str = "chatml",
+    ):
+        """
+        This function is from https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py.
+        Use this function to get input_ids with image.
+        """
+        if history is None:
+            history = []
+
+        if chat_format == "chatml":
+            im_start, im_end = "<|im_start|>", "<|im_end|>"
+            im_start_tokens = [tokenizer.im_start_id]
+            im_end_tokens = [tokenizer.im_end_id]
+            nl_tokens = tokenizer.encode("\n")
+
+            def _tokenize_str(role, content):
+                return f"{role}\n{content}", tokenizer.encode(
+                    role, allowed_special=set(tokenizer.IMAGE_ST)
+                ) + nl_tokens + tokenizer.encode(
+                    content, allowed_special=set(tokenizer.IMAGE_ST)
+                )
+
+            system_text, system_tokens_part = _tokenize_str("system", system)
+            system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
+
+            raw_text = ""
+            context_tokens = []
+
+            for turn_query, turn_response in reversed(history):
+                query_text, query_tokens_part = _tokenize_str("user", turn_query)
+                query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
+                if turn_response is not None:
+                    response_text, response_tokens_part = _tokenize_str(
+                        "assistant", turn_response
+                    )
+                    response_tokens = (
+                        im_start_tokens + response_tokens_part + im_end_tokens
+                    )
+
+                    next_context_tokens = (
+                        nl_tokens + query_tokens + nl_tokens + response_tokens
+                    )
+                    prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
+                else:
+                    next_context_tokens = nl_tokens + query_tokens + nl_tokens
+                    prev_chat = f"\n{im_start}{query_text}{im_end}\n"
+
+                current_context_size = (
+                    len(system_tokens) + len(next_context_tokens) + len(context_tokens)
+                )
+                if current_context_size < max_window_size:
+                    context_tokens = next_context_tokens + context_tokens
+                    raw_text = prev_chat + raw_text
+                else:
+                    break
+
+            context_tokens = system_tokens + context_tokens
+            raw_text = f"{im_start}{system_text}{im_end}" + raw_text
+            context_tokens += (
+                nl_tokens
+                + im_start_tokens
+                + _tokenize_str("user", query)[1]
+                + im_end_tokens
+                + nl_tokens
+                + im_start_tokens
+                + tokenizer.encode("assistant")
+                + nl_tokens
+            )
+            raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
+
+        elif chat_format == "raw":
+            raw_text = query
+            context_tokens = tokenizer.encode(raw_text)
+        else:
+            raise NotImplementedError(f"Unknown chat format {chat_format!r}")
+
+        return raw_text, context_tokens
+
+    def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
+        prompt, qwen_history = self._get_prompt_and_chat_history(
+            prompt, chat_history=chat_history
+        )
+        _, context_tokens = self.make_context(self._tokenizer, prompt, qwen_history)
+        return context_tokens
+
+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        """
+        Refer to https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/generation_config.json
+        """
+        raw_config = req.inference_kwargs.get("raw_params", {})
+        top_p = raw_config.get("top_p", None)
+        if top_p is None:
+            raw_config["top_p"] = 0.3
+        top_k = raw_config.get("top_k", None)
+        if top_k is None:
+            raw_config["top_k"] = 0
+        return raw_config
+
+    def build_prefill_inputs(self, prompts: List, req_list: List[InferenceRequest]):
+        context_len = self.get_context_len()
+        inputs = pad_prefill_tokens(prompts, context_len, req_list)
+        input_ids = torch.as_tensor(
+            pad_prefill_tokens(inputs, context_len, req_list), device=self._device
+        )
+        return input_ids
+
+    def build_prefill_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Qwen-vl fill `1` for position_ids padding
+        """
+        res = []
+        for r in reqs:
+            real_seq_len = seq_length - r.padding_len
+            res.append(
+                torch.cat(
+                    [
+                        torch.full((r.padding_len,), 1, dtype=torch.long),
+                        torch.arange(0, real_seq_len, dtype=torch.long),
+                    ]
+                )
+            )
+            r.extra_kwargs["max_position_id"] = real_seq_len - 1
+        return torch.stack(res).to(self._device)
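
The `make_context` helper added above assembles Qwen-VL's ChatML prompt by hand. As a quick orientation, here is a minimal sketch (not taken from the diff) of the `raw_text` layout it produces for an empty history and empty system prompt; the query string is invented:

# Sketch only: mirrors the raw_text construction in make_context for
# chat_format == "chatml", empty history, empty system prompt.
im_start, im_end = "<|im_start|>", "<|im_end|>"
system, query = "", "Describe this image."

raw_text = f"{im_start}system\n{system}{im_end}"   # system block, empty here
raw_text += f"\n{im_start}user\n{query}{im_end}"   # latest user turn
raw_text += f"\n{im_start}assistant\n"             # generation continues from here
print(raw_text)
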
@@ -17,7 +17,7 @@ import logging
 import os
 import time
 import uuid
-from typing import Dict, Iterable, Iterator, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Tuple

 import torch
 from transformers.cache_utils import DynamicCache
@@ -39,6 +39,10 @@ from ....types import (
     max_tokens_field,
 )

+if TYPE_CHECKING:
+    from ...llm.pytorch.core import PytorchModel
+
+
 logger = logging.getLogger(__name__)


@@ -414,6 +418,19 @@ def get_max_src_len(context_len: int, r: InferenceRequest) -> int:
     return context_len - max_new_tokens - 8


+def pad_prefill_tokens(
+    input_ids: List[List[int]], context_len: int, req_list: List[InferenceRequest]
+):
+    prompt_tokens = []
+    for i, input_id in enumerate(input_ids):
+        req = req_list[i]
+        max_src_len = get_max_src_len(context_len, req)
+        req.prompt_tokens = input_id[-max_src_len:]
+        prompt_tokens.append(req.prompt_tokens)
+    _pad_seqs_inplace(prompt_tokens, req_list, 0)
+    return prompt_tokens
+
+
 def _get_completion_chunk(
     output: str,
     chunk_id: str,
@@ -481,23 +498,33 @@ def _get_completion(
     return completion


+def _get_pad_param(seq_len_idx: int, pad_len: int) -> Tuple:
+    dimensions = [0] * 8
+    dimensions[-2 * (seq_len_idx + 1)] = pad_len
+    return tuple(dimensions)
+
+
 def _merge_kv_cache(
-    past_kv: Tuple[Tuple[torch.Tensor]], new_kv: Tuple[Tuple[torch.Tensor]]
+    xinf_model_obj: "PytorchModel",
+    past_kv: Tuple[Tuple[torch.Tensor]],
+    new_kv: Tuple[Tuple[torch.Tensor]],
 ):
     from torch.nn.functional import pad

+    _, seq_len_idx = xinf_model_obj.get_batch_size_and_seq_len_indexes_from_kv()
     past_cache = DynamicCache.from_legacy_cache(past_kv)
     new_cache = DynamicCache.from_legacy_cache(new_kv)
-    past_seq_len = past_cache.get_seq_length()
-    new_seq_len = new_cache.get_seq_length()
+    past_seq_len = past_kv[0][0].shape[seq_len_idx]
+    new_seq_len = new_kv[0][0].shape[seq_len_idx]
     if past_seq_len != new_seq_len:
         padding_target = new_cache if past_seq_len > new_seq_len else past_cache
         padding_len = abs(past_seq_len - new_seq_len)
+        pad_param = _get_pad_param(seq_len_idx, padding_len)
         for idx in range(len(padding_target)):
             k = padding_target.key_cache[idx]
             v = padding_target.value_cache[idx]
-            _k = pad(k, (0, 0, padding_len, 0))
-            _v = pad(v, (0, 0, padding_len, 0))
+            _k = pad(k, pad_param)
+            _v = pad(v, pad_param)
             padding_target.key_cache[idx] = _k
             padding_target.value_cache[idx] = _v

@@ -509,36 +536,19 @@ def _merge_kv_cache(
     return ret_kv.to_legacy_cache()


-def _get_attention_mask_and_position_ids(kv, reqs: List[InferenceRequest]):
-    batch_size, seq_length, device = (
-        kv[0][0].shape[0],
-        kv[0][0].shape[2],
-        kv[0][0].device,
-    )
-    seq_length = seq_length + 1
-    position_ids = torch.as_tensor([[seq_length - 1]], dtype=torch.long, device=device)
-    attention_mask = torch.ones(
-        (batch_size, seq_length), dtype=torch.long, device=device
-    )
-    padding_lens = torch.as_tensor([r.padding_len for r in reqs])
-    mask = torch.arange(seq_length).expand(
-        batch_size, seq_length
-    ) < padding_lens.unsqueeze(1)
-    attention_mask[mask] = 0
-    return attention_mask, position_ids
+def get_batch_size_and_seq_len_from_kv_cache(kv, xinf_model_obj: "PytorchModel"):
+    bs_idx, seq_len_idx = xinf_model_obj.get_batch_size_and_seq_len_indexes_from_kv()
+    return kv[0][0].shape[bs_idx], kv[0][0].shape[seq_len_idx] + 1


 @torch.inference_mode()
 def _batch_inference_one_step_internal(
+    xinf_model_obj: "PytorchModel",
     req_list: List[InferenceRequest],
     model_uid,
     model,
     tokenizer,
-    device,
-    context_len: int,
-    stop_tokens: Tuple[int],
     decode_round: int = 16,
-    require_attention_mask: bool = False,
     bos_flag: str = "<bos_stream>",
     eos_flag: str = "<eos_stream>",
 ):
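
A note on the `_get_pad_param` helper introduced above: `torch.nn.functional.pad` reads its tuple from the last dimension inward as (left, right) pairs, so index `-2 * (seq_len_idx + 1)` of a length-8 tuple is the left padding of the sequence dimension of a 4-D kv tensor. The old hard-coded `(0, 0, padding_len, 0)` only worked when `seq_len` was the second-to-last dimension. A small illustrative check (not part of the diff):

import torch
from torch.nn.functional import pad

def get_pad_param(seq_len_idx: int, pad_len: int):
    dims = [0] * 8                         # (left, right) pairs, last dim first
    dims[-2 * (seq_len_idx + 1)] = pad_len
    return tuple(dims)

k = torch.zeros(2, 32, 5, 128)             # (bs, head_num, seq_len, dim): seq_len_idx = 2
print(pad(k, get_pad_param(2, 3)).shape)   # torch.Size([2, 32, 8, 128])

q = torch.zeros(2, 5, 32, 128)             # Qwen-VL layout (bs, seq_len, head_num, dim): seq_len_idx = 1
print(pad(q, get_pad_param(1, 3)).shape)   # torch.Size([2, 8, 32, 128])
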
@@ -548,7 +558,9 @@ def _batch_inference_one_step_internal(
     if not valid_req_list:
         return
     generate_config_mapping: Dict[InferenceRequest, Tuple] = {
-        r: r.get_generate_configs(tokenizer.eos_token_id, stop_tokens)
+        r: r.get_generate_configs(
+            tokenizer.eos_token_id, xinf_model_obj.get_builtin_stop_token_ids()
+        )
         for r in valid_req_list
     }
     s_time = time.time()
@@ -564,15 +576,8 @@ def _batch_inference_one_step_internal(
             decode_reqs.append(r)

     if prompts:  # prefill first
-        input_ids: List[List[int]] = tokenizer(prompts, padding=False).input_ids
-        prompt_tokens = []
-        for i, input_id in enumerate(input_ids):
-            req = valid_req_list[i]
-            max_src_len = get_max_src_len(context_len, req)
-            req.prompt_tokens = input_id[-max_src_len:]
-            prompt_tokens.append(req.prompt_tokens)
-        _pad_seqs_inplace(prompt_tokens, valid_req_list, 0)
-        out = model(torch.as_tensor(prompt_tokens, device=device), use_cache=True)
+        prefill_kws = xinf_model_obj.build_prefill_kwargs(prompts, prefill_reqs)
+        out = model(**prefill_kws, use_cache=True)

         logits = out.logits
         past_key_values = out.past_key_values
@@ -599,7 +604,9 @@ def _batch_inference_one_step_internal(
         if decode_reqs:
             decode_kv = decode_reqs[0].kv_cache
             # prefill and decode kv cache need to be merged at `batch_size` and `seq_len` dimensions.
-            merged_kv_cache = _merge_kv_cache(decode_kv, past_key_values)
+            merged_kv_cache = _merge_kv_cache(
+                xinf_model_obj, decode_kv, past_key_values
+            )
             for r in valid_req_list:
                 r.kv_cache = merged_kv_cache
             empty_cache()
@@ -612,20 +619,14 @@ def _batch_inference_one_step_internal(
     output_mapping: Dict[InferenceRequest, str] = {}
     # here, only decode phase, just run some rounds
     for _i in range(decode_round):
+        batch_size, seq_len = get_batch_size_and_seq_len_from_kv_cache(
+            past_key_values, xinf_model_obj
+        )
         decode_tokens: List[List[int]] = [[r.new_tokens[-1]] for r in valid_req_list]
-        inf_kws = {}
-        if require_attention_mask:
-            attention_mask, position_ids = _get_attention_mask_and_position_ids(
-                past_key_values, valid_req_list
-            )
-            inf_kws["position_ids"] = position_ids
-            inf_kws["attention_mask"] = attention_mask
-        out = model(
-            input_ids=torch.as_tensor(decode_tokens, device=device),
-            use_cache=True,
-            past_key_values=past_key_values,
-            **inf_kws,
+        inf_kws = xinf_model_obj.build_decode_kwargs(
+            decode_tokens, valid_req_list, batch_size, seq_len
         )
+        out = model(**inf_kws, use_cache=True, past_key_values=past_key_values)
         logits = out.logits
         past_key_values = out.past_key_values

@@ -755,27 +756,17 @@ def _batch_inference_one_step_internal(


 def batch_inference_one_step(
+    xinf_model_obj: "PytorchModel",
     req_list: List[InferenceRequest],
     model_uid,
     model,
     tokenizer,
-    device,
-    context_len: int,
-    stop_token_ids: Tuple[int],
-    require_attention_mask: bool = False,
 ):
     from ....core.model import OutOfMemoryError

     try:
         _batch_inference_one_step_internal(
-            req_list,
-            model_uid,
-            model,
-            tokenizer,
-            device,
-            context_len,
-            stop_token_ids,
-            require_attention_mask=require_attention_mask,
+            xinf_model_obj, req_list, model_uid, model, tokenizer
        )
     except OutOfMemoryError:
         logger.exception(
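
Taken together, the changes above stop threading `device`, `context_len` and `stop_tokens` through the batch-inference helpers; the loop now asks the model object for whatever it needs. The sketch below only names the hooks visible in this diff; the `BatchingHooks` Protocol itself is illustrative and does not exist in xinference:

from typing import Any, Dict, List, Protocol, Tuple

class BatchingHooks(Protocol):
    # Implicit interface that _batch_inference_one_step_internal now relies on,
    # inferred from the calls shown in this diff.
    def get_builtin_stop_token_ids(self) -> Tuple[int, ...]: ...
    def get_batch_size_and_seq_len_indexes_from_kv(self) -> Tuple[int, int]: ...
    def build_prefill_kwargs(self, prompts: List, reqs: List) -> Dict[str, Any]: ...
    def build_decode_kwargs(
        self, decode_tokens: List[List[int]], reqs: List, batch_size: int, seq_len: int
    ) -> Dict[str, Any]: ...

Models with unusual cache layouts, such as `QwenVLChatModel` above, override the kv-index and prefill hooks instead of patching the shared loop.
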
@@ -47,6 +47,11 @@ QWEN_TOOL_CALL_FAMILY = [
     "qwen2-moe-instruct",
 ]

+GLM4_TOOL_CALL_FAMILY = [
+    "glm4-chat",
+    "glm4-chat-1m",
+]
+

 class ChatModelMixin:
     @staticmethod
@@ -617,9 +622,13 @@ Begin!"""

     @staticmethod
     def _eval_glm_chat_arguments(c, tools):
-        if isinstance(c[0], str):
-            return c[0], None, None
-        return None, c[0]["name"], c[0]["parameters"]
+        try:
+            if isinstance(c[0], str):
+                return c[0], None, None
+            return None, c[0]["name"], c[0]["parameters"]
+        except KeyError:
+            logger.error("Can't parse glm output: %s", c)
+            return str(c), None, None

     @staticmethod
     def _eval_qwen_chat_arguments(c, tools):
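
For orientation, a standalone copy of the parsing logic above (logging stripped); the sample payloads are invented for illustration:

def eval_glm_chat_arguments(c):
    # Mirrors ChatModelMixin._eval_glm_chat_arguments without the logger.
    try:
        if isinstance(c[0], str):
            return c[0], None, None
        return None, c[0]["name"], c[0]["parameters"]
    except KeyError:
        return str(c), None, None

print(eval_glm_chat_arguments(["Hello!"]))                  # plain text reply
print(eval_glm_chat_arguments(
    [{"name": "get_weather", "parameters": {"city": "Beijing"}}]
))                                                          # tool call
print(eval_glm_chat_arguments([{"unexpected": "shape"}]))   # malformed -> stringified
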
@@ -668,7 +677,7 @@ Begin!"""
         family = model_family.model_family or model_family.model_name
         if family in ["gorilla-openfunctions-v1", "gorilla-openfunctions-v2"]:
             content, func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
-        elif family in ["chatglm3", "glm4-chat"]:
+        elif family in ["chatglm3"] + GLM4_TOOL_CALL_FAMILY:
             content, func, args = cls._eval_glm_chat_arguments(c, tools)
         elif family in QWEN_TOOL_CALL_FAMILY:
             content, func, args = cls._eval_qwen_chat_arguments(c, tools)
@@ -756,6 +765,16 @@ Begin!"""
             "usage": usage,
         }

+    @classmethod
+    def get_full_prompt(cls, model_family, prompt, system_prompt, chat_history, tools):
+        assert model_family.prompt_style is not None
+        prompt_style = model_family.prompt_style.copy()
+        if system_prompt:
+            prompt_style.system_prompt = system_prompt
+        chat_history = chat_history or []
+        full_prompt = cls.get_prompt(prompt, chat_history, prompt_style, tools=tools)
+        return full_prompt
+

 def get_file_location(
     llm_family: LLMFamilyV1, spec: LLMSpecV1, quantization: str
@@ -772,7 +791,7 @@ def get_file_location(
         is_cached = cache_status
     assert isinstance(is_cached, bool)

-    if spec.model_format in ["pytorch", "gptq", "awq"]:
+    if spec.model_format in ["pytorch", "gptq", "awq", "mlx"]:
         return cache_dir, is_cached
     elif spec.model_format in ["ggmlv3", "ggufv2"]:
         assert isinstance(spec, GgmlLLMSpecV1)
@@ -17,6 +17,7 @@ import logging
 import os
 import uuid
 from collections import defaultdict
+from collections.abc import Sequence
 from typing import Dict, List, Optional, Tuple

 import numpy as np
@@ -217,7 +218,11 @@ class RerankModel:
             if similarity_scores.dtype == torch.bfloat16:
                 similarity_scores = similarity_scores.float()
         else:
+            # Related issue: https://github.com/xorbitsai/inference/issues/1775
            similarity_scores = self._model.compute_score(sentence_combinations)
+            if not isinstance(similarity_scores, Sequence):
+                similarity_scores = [similarity_scores]
+
         sim_scores_argsort = list(reversed(np.argsort(similarity_scores)))
         if top_n is not None:
             sim_scores_argsort = sim_scores_argsort[:top_n]
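
The rerank change assumes `compute_score` can return a bare float when a single (query, document) pair is scored (see the linked issue), which would trip up the sequence handling that follows. A minimal standalone sketch of the guard:

from collections.abc import Sequence

import numpy as np

def normalize_scores(similarity_scores):
    # Wrap a scalar score so argsort/top-n always sees a sequence.
    if not isinstance(similarity_scores, Sequence):
        similarity_scores = [similarity_scores]
    return similarity_scores

scores = normalize_scores(0.87)              # single pair -> [0.87]
order = list(reversed(np.argsort(scores)))   # a single index, 0
print(scores, order)
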
@@ -0,0 +1,13 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.