xinference 0.14.4.post1__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (149)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +5 -39
  4. xinference/client/restful/restful_client.py +3 -24
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/model.py +82 -31
  11. xinference/core/scheduler.py +37 -37
  12. xinference/core/status_guard.py +1 -1
  13. xinference/core/supervisor.py +11 -10
  14. xinference/core/utils.py +80 -22
  15. xinference/core/worker.py +17 -16
  16. xinference/deploy/cmdline.py +8 -16
  17. xinference/deploy/local.py +1 -1
  18. xinference/deploy/supervisor.py +1 -1
  19. xinference/deploy/utils.py +1 -1
  20. xinference/deploy/worker.py +1 -1
  21. xinference/model/audio/cosyvoice.py +86 -41
  22. xinference/model/embedding/core.py +52 -31
  23. xinference/model/image/stable_diffusion/core.py +18 -1
  24. xinference/model/llm/__init__.py +21 -11
  25. xinference/model/llm/llama_cpp/core.py +16 -33
  26. xinference/model/llm/llm_family.json +619 -1297
  27. xinference/model/llm/llm_family.py +31 -52
  28. xinference/model/llm/llm_family_csghub.json +18 -35
  29. xinference/model/llm/llm_family_modelscope.json +573 -1119
  30. xinference/model/llm/lmdeploy/core.py +56 -88
  31. xinference/model/llm/mlx/core.py +46 -69
  32. xinference/model/llm/sglang/core.py +33 -18
  33. xinference/model/llm/transformers/chatglm.py +167 -305
  34. xinference/model/llm/transformers/cogvlm2.py +36 -63
  35. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  36. xinference/model/llm/transformers/core.py +49 -50
  37. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  38. xinference/model/llm/transformers/glm4v.py +55 -111
  39. xinference/model/llm/transformers/intern_vl.py +39 -70
  40. xinference/model/llm/transformers/internlm2.py +32 -54
  41. xinference/model/llm/transformers/minicpmv25.py +22 -55
  42. xinference/model/llm/transformers/minicpmv26.py +158 -68
  43. xinference/model/llm/transformers/omnilmm.py +5 -28
  44. xinference/model/llm/transformers/qwen2_vl.py +208 -0
  45. xinference/model/llm/transformers/qwen_vl.py +34 -86
  46. xinference/model/llm/transformers/utils.py +32 -38
  47. xinference/model/llm/transformers/yi_vl.py +32 -72
  48. xinference/model/llm/utils.py +195 -489
  49. xinference/model/llm/vllm/core.py +153 -100
  50. xinference/model/rerank/core.py +41 -8
  51. xinference/model/rerank/model_spec.json +7 -0
  52. xinference/model/rerank/model_spec_modelscope.json +7 -1
  53. xinference/model/utils.py +1 -31
  54. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  55. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  56. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  57. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  58. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  59. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  60. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  61. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  62. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  63. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  64. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  65. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  66. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  67. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  69. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  70. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  71. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  72. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  73. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  74. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  75. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  76. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  77. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  78. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  79. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  80. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  81. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +34 -0
  82. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  83. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  84. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  85. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  88. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  89. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  90. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  91. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  92. xinference/thirdparty/matcha/VERSION +1 -0
  93. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  94. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  95. xinference/thirdparty/omnilmm/LICENSE +201 -0
  96. xinference/thirdparty/whisper/__init__.py +156 -0
  97. xinference/thirdparty/whisper/__main__.py +3 -0
  98. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  99. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  100. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  101. xinference/thirdparty/whisper/audio.py +157 -0
  102. xinference/thirdparty/whisper/decoding.py +826 -0
  103. xinference/thirdparty/whisper/model.py +314 -0
  104. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  105. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  106. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  107. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  108. xinference/thirdparty/whisper/timing.py +386 -0
  109. xinference/thirdparty/whisper/tokenizer.py +395 -0
  110. xinference/thirdparty/whisper/transcribe.py +605 -0
  111. xinference/thirdparty/whisper/triton_ops.py +109 -0
  112. xinference/thirdparty/whisper/utils.py +316 -0
  113. xinference/thirdparty/whisper/version.py +1 -0
  114. xinference/types.py +7 -49
  115. xinference/web/ui/build/asset-manifest.json +6 -6
  116. xinference/web/ui/build/index.html +1 -1
  117. xinference/web/ui/build/static/css/{main.4bafd904.css → main.632e9148.css} +2 -2
  118. xinference/web/ui/build/static/css/main.632e9148.css.map +1 -0
  119. xinference/web/ui/build/static/js/main.9cfafbd6.js +3 -0
  120. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.9cfafbd6.js.LICENSE.txt} +2 -0
  121. xinference/web/ui/build/static/js/main.9cfafbd6.js.map +1 -0
  122. xinference/web/ui/node_modules/.cache/babel-loader/01d6d198156bacbd436c51435edbd4b2cacd47a79db929105eba30f74b67d48d.json +1 -0
  123. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  124. xinference/web/ui/node_modules/.cache/babel-loader/59eb25f514afcc4fefd1b309d192b2455f1e0aec68a9de598ca4b2333fe2c774.json +1 -0
  125. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  126. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  127. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  128. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  129. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  130. xinference/web/ui/node_modules/.package-lock.json +37 -0
  131. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  132. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  133. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  134. xinference/web/ui/package-lock.json +38 -0
  135. xinference/web/ui/package.json +1 -0
  136. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/METADATA +8 -8
  137. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/RECORD +141 -87
  138. xinference/model/llm/transformers/llama_2.py +0 -108
  139. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  140. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  141. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  142. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  144. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  145. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  146. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/LICENSE +0 -0
  147. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/WHEEL +0 -0
  148. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/entry_points.txt +0 -0
  149. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/vllm/core.py CHANGED
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import asyncio
16
- import json
17
16
  import logging
18
17
  import multiprocessing
19
18
  import os
@@ -24,9 +23,9 @@ from typing import (
24
23
  Any,
25
24
  AsyncGenerator,
26
25
  Dict,
27
- Iterable,
28
26
  List,
29
27
  Optional,
28
+ Tuple,
30
29
  TypedDict,
31
30
  Union,
32
31
  )
@@ -34,18 +33,20 @@ from typing import (
34
33
  from ....types import (
35
34
  ChatCompletion,
36
35
  ChatCompletionChunk,
37
- ChatCompletionMessage,
38
36
  Completion,
39
37
  CompletionChoice,
40
38
  CompletionChunk,
41
39
  CompletionUsage,
42
40
  LoRA,
43
- ToolCallFunction,
44
- ToolCalls,
45
41
  )
46
42
  from .. import LLM, LLMFamilyV1, LLMSpecV1
47
43
  from ..llm_family import CustomLLMFamilyV1
48
- from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin
44
+ from ..utils import (
45
+ QWEN_TOOL_CALL_FAMILY,
46
+ QWEN_TOOL_CALL_SYMBOLS,
47
+ ChatModelMixin,
48
+ generate_completion_chunk,
49
+ )
49
50
 
50
51
  logger = logging.getLogger(__name__)
51
52
 
@@ -363,23 +364,28 @@ class VLLMModel(LLM):
363
364
  @staticmethod
364
365
  def _convert_request_output_to_completion_chunk(
365
366
  request_id: str, model: str, request_output: "RequestOutput"
366
- ) -> CompletionChunk:
367
+ ) -> Tuple[CompletionChunk, Optional[str]]:
367
368
  choices: List[CompletionChoice] = []
369
+ finish_reason = None
368
370
  for output in request_output.outputs:
369
371
  choices.append(
370
372
  CompletionChoice(
371
373
  text=output.text,
372
374
  index=output.index,
373
375
  logprobs=None, # TODO: support logprobs.
374
- finish_reason=output.finish_reason,
376
+ finish_reason=None,
375
377
  )
376
378
  )
377
- return CompletionChunk(
378
- id=request_id,
379
- object="text_completion",
380
- created=int(time.time()),
381
- model=model,
382
- choices=choices,
379
+ finish_reason = output.finish_reason
380
+ return (
381
+ CompletionChunk(
382
+ id=request_id,
383
+ object="text_completion",
384
+ created=int(time.time()),
385
+ model=model,
386
+ choices=choices,
387
+ ),
388
+ finish_reason,
383
389
  )
384
390
 
385
391
  @staticmethod
@@ -420,6 +426,7 @@ class VLLMModel(LLM):
420
426
  prompt: Union[str, Dict[str, Any]],
421
427
  generate_config: Optional[Dict] = None,
422
428
  tools: object = False,
429
+ request_id: Optional[str] = None,
423
430
  ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]:
424
431
  try:
425
432
  from vllm.sampling_params import SamplingParams
@@ -454,7 +461,8 @@ class VLLMModel(LLM):
454
461
  else False
455
462
  )
456
463
  sampling_params = SamplingParams(**sanitized_generate_config)
457
- request_id = str(uuid.uuid1())
464
+ if not request_id:
465
+ request_id = str(uuid.uuid1())
458
466
 
459
467
  assert self._engine is not None
460
468
  results_generator = self._engine.generate(
@@ -463,10 +471,14 @@ class VLLMModel(LLM):
463
471
 
464
472
  async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
465
473
  previous_texts = [""] * sanitized_generate_config["n"]
466
- tools_token_filter = ChatModelMixin._tools_token_filter(self.model_family)
467
474
  prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
475
+ complete_response = ""
476
+ match_tool_call_tmp_results = []
477
+ is_match_tool_call = False
478
+ chunk = None
479
+ finish_reason = None
468
480
  async for _request_output in results_generator:
469
- chunk = self._convert_request_output_to_completion_chunk(
481
+ chunk, finish_reason = self._convert_request_output_to_completion_chunk(
470
482
  request_id=request_id,
471
483
  model=self.model_uid,
472
484
  request_output=_request_output,
@@ -476,40 +488,8 @@ class VLLMModel(LLM):
476
488
  delta = choice["text"][len(previous_texts[i]) :]
477
489
  previous_texts[i] = choice["text"]
478
490
  choice["text"] = delta
491
+ complete_response += delta
479
492
 
480
- if tools:
481
- # only handle the first choice
482
- choice = chunk["choices"][0]
483
- if choice["finish_reason"] is not None:
484
- # use previous text for evaluation temporarily
485
- choice_delta = choice["text"]
486
- choice["text"] = previous_texts[0]
487
- _content, func, args = ChatModelMixin._eval_tool_arguments(
488
- self.model_family, chunk, tools
489
- )
490
- choice["text"] = tools_token_filter(
491
- tokens=previous_texts[0], delta=choice_delta
492
- )
493
- if func is not None:
494
- choice["text"] = None
495
- choice["finish_reason"] = "tool_calls"
496
- choice["tool_calls"] = [
497
- ToolCalls(
498
- id=str(uuid.uuid4()),
499
- type="function",
500
- function=ToolCallFunction(
501
- name=func,
502
- arguments=json.dumps(args, ensure_ascii=False),
503
- ),
504
- )
505
- ]
506
- else:
507
- # use a filter function to skip Qwen's react thought process
508
- choice["text"] = tools_token_filter(
509
- tokens=previous_texts[0], delta=choice["text"]
510
- )
511
- if not choice["text"]:
512
- continue
513
493
  prompt_tokens = len(_request_output.prompt_token_ids)
514
494
  completion_tokens = sum(
515
495
  len(output.token_ids) for output in _request_output.outputs
@@ -520,7 +500,59 @@ class VLLMModel(LLM):
520
500
  completion_tokens=completion_tokens,
521
501
  total_tokens=total_tokens,
522
502
  )
503
+
504
+ if tools:
505
+ """
506
+ The qwen2 tool call returns format like this:
507
+ <tool_call>
508
+ {...}
509
+ </tool_call>
510
+ Here is to match this.
511
+ """
512
+ if (len(QWEN_TOOL_CALL_SYMBOLS[0]) > len(complete_response)) and (
513
+ not QWEN_TOOL_CALL_SYMBOLS[0].startswith(complete_response)
514
+ ):
515
+ for c in match_tool_call_tmp_results:
516
+ yield c
517
+ match_tool_call_tmp_results.clear()
518
+ yield chunk
519
+ elif (len(QWEN_TOOL_CALL_SYMBOLS[0]) > len(complete_response)) and (
520
+ QWEN_TOOL_CALL_SYMBOLS[0].startswith(complete_response)
521
+ ):
522
+ match_tool_call_tmp_results.append(chunk)
523
+ else:
524
+ assert len(QWEN_TOOL_CALL_SYMBOLS[0]) <= len(complete_response)
525
+ if not is_match_tool_call and complete_response.startswith(
526
+ QWEN_TOOL_CALL_SYMBOLS[0]
527
+ ):
528
+ is_match_tool_call = True
529
+ match_tool_call_tmp_results.clear()
530
+
531
+ if not is_match_tool_call:
532
+ for c in match_tool_call_tmp_results:
533
+ yield c
534
+ match_tool_call_tmp_results.clear()
535
+ yield chunk
536
+ else:
537
+ chunk["choices"][0]["text"] = complete_response
538
+ else:
539
+ yield chunk
540
+
541
+ if is_match_tool_call:
542
+ assert chunk is not None
523
543
  yield chunk
544
+
545
+ # match OpenAI API stream
546
+ yield generate_completion_chunk(
547
+ chunk_text="",
548
+ finish_reason=finish_reason,
549
+ chunk_id=request_id,
550
+ model_uid=self.model_uid,
551
+ prompt_tokens=prompt_tokens,
552
+ completion_tokens=completion_tokens,
553
+ total_tokens=total_tokens,
554
+ )
555
+
524
556
  if include_usage:
525
557
  chunk = CompletionChunk(
526
558
  id=request_id,
@@ -586,59 +618,74 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
586
618
  ) -> Dict:
587
619
  if not generate_config:
588
620
  generate_config = {}
589
- if self.model_family.prompt_style:
590
- if (
591
- not generate_config.get("stop")
592
- ) and self.model_family.prompt_style.stop:
593
- generate_config["stop"] = self.model_family.prompt_style.stop.copy()
594
- if self.model_family.prompt_style.stop_token_ids:
595
- generate_config.setdefault(
596
- "stop_token_ids",
597
- self.model_family.prompt_style.stop_token_ids.copy(),
598
- )
621
+ if not generate_config.get("stop") and self.model_family.stop:
622
+ generate_config["stop"] = self.model_family.stop.copy()
623
+ if (
624
+ not generate_config.get("stop_token_ids")
625
+ and self.model_family.stop_token_ids
626
+ ):
627
+ generate_config["stop_token_ids"] = self.model_family.stop_token_ids.copy()
599
628
  return generate_config
600
629
 
630
+ @staticmethod
631
+ def is_tool_call_chunk(chunk):
632
+ return chunk["choices"][0]["text"].startswith(QWEN_TOOL_CALL_SYMBOLS[0])
633
+
634
+ async def _async_to_tool_completion_chunks(
635
+ self,
636
+ chunks: AsyncGenerator[CompletionChunk, None],
637
+ ) -> AsyncGenerator[ChatCompletionChunk, None]:
638
+ i = 0
639
+ async for chunk in chunks:
640
+ if i == 0:
641
+ yield self._get_first_chat_completion_chunk(chunk)
642
+ # usage
643
+ choices = chunk.get("choices")
644
+ if not choices:
645
+ yield self._get_final_chat_completion_chunk(chunk)
646
+ else:
647
+ if self.is_tool_call_chunk(chunk):
648
+ yield self._tool_calls_completion_chunk(
649
+ self.model_family, self.model_uid, chunk
650
+ )
651
+ else:
652
+ yield self._to_chat_completion_chunk(chunk)
653
+ i += 1
654
+
601
655
  async def async_chat(
602
656
  self,
603
- prompt: str,
604
- system_prompt: Optional[str] = None,
605
- chat_history: Optional[List[ChatCompletionMessage]] = None,
657
+ messages: List[Dict],
606
658
  generate_config: Optional[Dict] = None,
659
+ request_id: Optional[str] = None,
607
660
  ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
608
- assert self.model_family.prompt_style is not None
609
- prompt_style = self.model_family.prompt_style.copy()
610
- if system_prompt:
611
- prompt_style.system_prompt = system_prompt
612
- chat_history = chat_history or []
613
661
  tools = generate_config.pop("tools", []) if generate_config else None
614
- full_prompt = self.get_prompt(prompt, chat_history, prompt_style, tools=tools)
615
-
616
- generate_config = self._sanitize_chat_config(generate_config)
617
- # TODO(codingl2k1): qwen hacky to set stop for function call.
618
662
  model_family = self.model_family.model_family or self.model_family.model_name
663
+ full_context_kwargs = {}
619
664
  if tools and model_family in QWEN_TOOL_CALL_FAMILY:
620
- stop = generate_config.get("stop")
621
- if isinstance(stop, str):
622
- generate_config["stop"] = [stop, "Observation:"]
623
- elif isinstance(stop, Iterable):
624
- assert not isinstance(stop, str)
625
- generate_config["stop"] = list(stop) + ["Observation:"]
626
- else:
627
- generate_config["stop"] = "Observation:"
665
+ full_context_kwargs["tools"] = tools
666
+ assert self.model_family.chat_template is not None
667
+ full_prompt = self.get_full_context(
668
+ messages, self.model_family.chat_template, **full_context_kwargs
669
+ )
628
670
 
671
+ generate_config = self._sanitize_chat_config(generate_config)
629
672
  stream = generate_config.get("stream", None)
630
673
 
631
674
  if stream:
632
- agen = await self.async_generate(full_prompt, generate_config, tools)
675
+ agen = await self.async_generate(
676
+ full_prompt, generate_config, tools, request_id=request_id
677
+ )
633
678
  assert isinstance(agen, AsyncGenerator)
679
+ if tools:
680
+ return self._async_to_tool_completion_chunks(agen)
634
681
  return self._async_to_chat_completion_chunks(agen)
635
682
  else:
636
- c = await self.async_generate(full_prompt, generate_config)
683
+ c = await self.async_generate(
684
+ full_prompt, generate_config, request_id=request_id
685
+ )
637
686
  assert not isinstance(c, AsyncGenerator)
638
687
  if tools:
639
- return self._tool_calls_completion(
640
- self.model_family, self.model_uid, c, tools
641
- )
688
+ return self._tool_calls_completion(self.model_family, self.model_uid, c)
642
689
  return self._to_chat_completion(c)
643
690
 
644
691
 
@@ -666,28 +713,30 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
666
713
  self,
667
714
  generate_config: Optional[Dict] = None,
668
715
  ) -> Dict:
716
+ from ..utils import get_stop_token_ids_from_config_file
717
+
669
718
  if not generate_config:
670
719
  generate_config = {}
671
- if self.model_family.prompt_style:
672
- if self.model_family.prompt_style.stop_token_ids:
673
- generate_config.setdefault(
674
- "stop_token_ids",
675
- self.model_family.prompt_style.stop_token_ids.copy(),
676
- )
720
+ if generate_config.get("stop_token_ids", None) is None:
721
+ stop_token_ids = get_stop_token_ids_from_config_file(self.model_path)
722
+ if stop_token_ids is not None:
723
+ generate_config.setdefault("stop_token_ids", stop_token_ids)
724
+ else:
725
+ if self.model_family.stop_token_ids:
726
+ generate_config.setdefault(
727
+ "stop_token_ids", self.model_family.stop_token_ids.copy()
728
+ )
677
729
  return generate_config
678
730
 
679
731
  async def async_chat(
680
732
  self,
681
- prompt: str,
682
- system_prompt: Optional[str] = None,
683
- chat_history: Optional[List[ChatCompletionMessage]] = None,
733
+ messages: List[Dict],
684
734
  generate_config: Optional[Dict] = None,
735
+ request_id: Optional[str] = None,
685
736
  ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
686
737
  # only support single image, waiting vllm support multi images
687
- assert self.model_family.prompt_style is not None
688
- prompt_style = self.model_family.prompt_style.copy()
689
- chat_history = chat_history or []
690
- prompt, images = self.get_prompt(prompt, chat_history, prompt_style)
738
+ model_family = self.model_family.model_family or self.model_family.model_name
739
+ prompt, images = self.get_specific_prompt(model_family, messages)
691
740
 
692
741
  if len(images) == 0:
693
742
  inputs = {
@@ -703,10 +752,14 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
703
752
  stream = generate_config.get("stream", None)
704
753
 
705
754
  if stream:
706
- agen = await self.async_generate(inputs, generate_config)
755
+ agen = await self.async_generate(
756
+ inputs, generate_config, request_id=request_id
757
+ )
707
758
  assert isinstance(agen, AsyncGenerator)
708
759
  return self._async_to_chat_completion_chunks(agen)
709
760
  else:
710
- c = await self.async_generate(inputs, generate_config)
761
+ c = await self.async_generate(
762
+ inputs, generate_config, request_id=request_id
763
+ )
711
764
  assert not isinstance(c, AsyncGenerator)
712
765
  return self._to_chat_completion(c)
xinference/model/rerank/core.py CHANGED
@@ -15,6 +15,7 @@
15
15
  import gc
16
16
  import logging
17
17
  import os
18
+ import threading
18
19
  import uuid
19
20
  from collections import defaultdict
20
21
  from collections.abc import Sequence
@@ -22,6 +23,7 @@ from typing import Dict, List, Literal, Optional, Tuple
22
23
 
23
24
  import numpy as np
24
25
  import torch
26
+ import torch.nn as nn
25
27
 
26
28
  from ...constants import XINFERENCE_CACHE_DIR
27
29
  from ...device_utils import empty_cache
@@ -49,6 +51,7 @@ class RerankModelSpec(CacheableModelSpec):
49
51
  model_name: str
50
52
  language: List[str]
51
53
  type: Optional[str] = "unknown"
54
+ max_tokens: Optional[int]
52
55
  model_id: str
53
56
  model_revision: Optional[str]
54
57
  model_hub: str = "huggingface"
@@ -102,6 +105,30 @@ def generate_rerank_description(model_spec: RerankModelSpec) -> Dict[str, List[D
102
105
  return res
103
106
 
104
107
 
108
+ class _ModelWrapper:
109
+ def __init__(self, module: nn.Module):
110
+ self._module = module
111
+ self._local_data = threading.local()
112
+
113
+ @property
114
+ def n_tokens(self):
115
+ return getattr(self._local_data, "n_tokens", 0)
116
+
117
+ @n_tokens.setter
118
+ def n_tokens(self, new_n_tokens):
119
+ self._local_data.n_tokens = new_n_tokens
120
+
121
+ def __getattr__(self, attr):
122
+ return getattr(self._module, attr)
123
+
124
+ def __call__(self, **kwargs):
125
+ attention_mask = kwargs["attention_mask"]
126
+ # when batching, the attention mask 1 means there is a token
127
+ # thus we just sum up it to get the total number of tokens
128
+ self.n_tokens += attention_mask.sum().item()
129
+ return self._module(**kwargs)
130
+
131
+
105
132
  class RerankModel:
106
133
  def __init__(
107
134
  self,
@@ -166,6 +193,7 @@ class RerankModel:
166
193
  self._model_path,
167
194
  device=self._device,
168
195
  trust_remote_code=True,
196
+ max_length=getattr(self._model_spec, "max_tokens"),
169
197
  **self._model_config,
170
198
  )
171
199
  if self._use_fp16:
@@ -189,6 +217,8 @@ class RerankModel:
189
217
 
190
218
  raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
191
219
  self._model = FlagReranker(self._model_path, use_fp16=self._use_fp16)
220
+ # Wrap transformers model to record number of tokens
221
+ self._model.model = _ModelWrapper(self._model.model)
192
222
 
193
223
  def rerank(
194
224
  self,
@@ -200,17 +230,14 @@ class RerankModel:
200
230
  return_len: Optional[bool],
201
231
  **kwargs,
202
232
  ) -> Rerank:
203
- self._counter += 1
204
- if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
205
- logger.debug("Empty rerank cache.")
206
- gc.collect()
207
- empty_cache()
208
233
  assert self._model is not None
209
234
  if kwargs:
210
235
  raise ValueError("rerank hasn't support extra parameter.")
211
236
  if max_chunks_per_doc is not None:
212
237
  raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
213
238
  sentence_combinations = [[query, doc] for doc in documents]
239
+ # reset n tokens
240
+ self._model.model.n_tokens = 0
214
241
  if self._model_spec.type == "normal":
215
242
  similarity_scores = self._model.predict(
216
243
  sentence_combinations, convert_to_numpy=False, convert_to_tensor=True
@@ -245,9 +272,7 @@ class RerankModel:
245
272
  for arg in sim_scores_argsort
246
273
  ]
247
274
  if return_len:
248
- tokenizer = self._get_tokenizer(self._model_path)
249
- input_len = sum([len(tokenizer.tokenize(t)) for t in documents])
250
-
275
+ input_len = self._model.model.n_tokens
251
276
  # Rerank Model output is just score or documents
252
277
  # while return_documents = True
253
278
  output_len = input_len
@@ -265,6 +290,14 @@ class RerankModel:
265
290
  "warnings": None,
266
291
  }
267
292
 
293
+ del similarity_scores
294
+ # clear cache if possible
295
+ self._counter += 1
296
+ if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
297
+ logger.debug("Empty rerank cache.")
298
+ gc.collect()
299
+ empty_cache()
300
+
268
301
  return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)
269
302
 
270
303
 
xinference/model/rerank/model_spec.json CHANGED
@@ -3,6 +3,7 @@
3
3
  "model_name": "bge-reranker-large",
4
4
  "type": "normal",
5
5
  "language": ["en", "zh"],
6
+ "max_tokens": 512,
6
7
  "model_id": "BAAI/bge-reranker-large",
7
8
  "model_revision": "27c9168d479987529781de8474dff94d69beca11"
8
9
  },
@@ -10,6 +11,7 @@
10
11
  "model_name": "bge-reranker-base",
11
12
  "type": "normal",
12
13
  "language": ["en", "zh"],
14
+ "max_tokens": 512,
13
15
  "model_id": "BAAI/bge-reranker-base",
14
16
  "model_revision": "465b4b7ddf2be0a020c8ad6e525b9bb1dbb708ae"
15
17
  },
@@ -17,6 +19,7 @@
17
19
  "model_name": "bce-reranker-base_v1",
18
20
  "type": "normal",
19
21
  "language": ["en", "zh"],
22
+ "max_tokens": 512,
20
23
  "model_id": "maidalun1020/bce-reranker-base_v1",
21
24
  "model_revision": "eaa31a577a0574e87a08959bd229ca14ce1b5496"
22
25
  },
@@ -24,6 +27,7 @@
24
27
  "model_name": "bge-reranker-v2-m3",
25
28
  "type": "normal",
26
29
  "language": ["en", "zh", "multilingual"],
30
+ "max_tokens": 8192,
27
31
  "model_id": "BAAI/bge-reranker-v2-m3",
28
32
  "model_revision": "12e974610ba9083ed95f3edf08d7e899581f4de4"
29
33
  },
@@ -31,6 +35,7 @@
31
35
  "model_name": "bge-reranker-v2-gemma",
32
36
  "type": "LLM-based",
33
37
  "language": ["en", "zh", "multilingual"],
38
+ "max_tokens": 8192,
34
39
  "model_id": "BAAI/bge-reranker-v2-gemma",
35
40
  "model_revision": "1787044f8b6fb740a9de4557c3a12377f84d9e17"
36
41
  },
@@ -38,6 +43,7 @@
38
43
  "model_name": "bge-reranker-v2-minicpm-layerwise",
39
44
  "type": "LLM-based layerwise",
40
45
  "language": ["en", "zh", "multilingual"],
46
+ "max_tokens": 2048,
41
47
  "model_id": "BAAI/bge-reranker-v2-minicpm-layerwise",
42
48
  "model_revision": "47b5332b296c4d8cb6ee2c60502cc62a0d708881"
43
49
  },
@@ -45,6 +51,7 @@
45
51
  "model_name": "jina-reranker-v2",
46
52
  "type": "normal",
47
53
  "language": ["en", "zh", "multilingual"],
54
+ "max_tokens": 1024,
48
55
  "model_id": "jinaai/jina-reranker-v2-base-multilingual",
49
56
  "model_revision": "298e48cada4a9318650d7fbd795f63827f884087"
50
57
  }
xinference/model/rerank/model_spec_modelscope.json CHANGED
@@ -3,6 +3,7 @@
3
3
  "model_name": "bge-reranker-base",
4
4
  "type": "normal",
5
5
  "language": ["en", "zh"],
6
+ "max_tokens": 512,
6
7
  "model_id": "Xorbits/bge-reranker-base",
7
8
  "model_revision": "v0.0.1",
8
9
  "model_hub": "modelscope"
@@ -11,6 +12,7 @@
11
12
  "model_name": "bge-reranker-large",
12
13
  "type": "normal",
13
14
  "language": ["en", "zh"],
15
+ "max_tokens": 512,
14
16
  "model_id": "Xorbits/bge-reranker-large",
15
17
  "model_revision": "v0.0.1",
16
18
  "model_hub": "modelscope"
@@ -19,6 +21,7 @@
19
21
  "model_name": "bce-reranker-base_v1",
20
22
  "type": "normal",
21
23
  "language": ["en", "zh"],
24
+ "max_tokens": 512,
22
25
  "model_id": "maidalun/bce-reranker-base_v1",
23
26
  "model_revision": "v0.0.1",
24
27
  "model_hub": "modelscope"
@@ -26,6 +29,7 @@
26
29
  {
27
30
  "model_name": "bge-reranker-v2-m3",
28
31
  "type": "normal",
32
+ "max_tokens": 8192,
29
33
  "language": ["en", "zh", "multilingual"],
30
34
  "model_id": "AI-ModelScope/bge-reranker-v2-m3",
31
35
  "model_hub": "modelscope"
@@ -34,6 +38,7 @@
34
38
  "model_name": "bge-reranker-v2-gemma",
35
39
  "type": "LLM-based",
36
40
  "language": ["en", "zh", "multilingual"],
41
+ "max_tokens": 8192,
37
42
  "model_id": "AI-ModelScope/bge-reranker-v2-gemma",
38
43
  "model_hub": "modelscope"
39
44
  },
@@ -41,7 +46,8 @@
41
46
  "model_name": "bge-reranker-v2-minicpm-layerwise",
42
47
  "type": "LLM-based layerwise",
43
48
  "language": ["en", "zh", "multilingual"],
44
- "model_id": "zfffff/bge-reranker-v2-minicpm-layerwise",
49
+ "max_tokens": 2048,
50
+ "model_id": "mirror013/bge-reranker-v2-minicpm-layerwise",
45
51
  "model_hub": "modelscope"
46
52
  }
47
53
  ]
xinference/model/utils.py CHANGED
@@ -11,10 +11,6 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
15
- import functools
16
- import gc
17
- import inspect
18
14
  import json
19
15
  import logging
20
16
  import os
@@ -28,7 +24,7 @@ import numpy as np
28
24
  import torch
29
25
 
30
26
  from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_ENV_MODEL_SRC
31
- from ..device_utils import empty_cache, get_available_device, is_device_available
27
+ from ..device_utils import get_available_device, is_device_available
32
28
  from .core import CacheableModelSpec
33
29
 
34
30
  logger = logging.getLogger(__name__)
@@ -357,32 +353,6 @@ def convert_float_to_int_or_str(model_size: float) -> Union[int, str]:
357
353
  return str(model_size)
358
354
 
359
355
 
360
- def ensure_cache_cleared(func: Callable):
361
- assert not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
362
- func
363
- )
364
- if inspect.isgeneratorfunction(func):
365
-
366
- @functools.wraps(func)
367
- def inner(*args, **kwargs):
368
- for obj in func(*args, **kwargs):
369
- yield obj
370
- gc.collect()
371
- empty_cache()
372
-
373
- else:
374
-
375
- @functools.wraps(func)
376
- def inner(*args, **kwargs):
377
- try:
378
- return func(*args, **kwargs)
379
- finally:
380
- gc.collect()
381
- empty_cache()
382
-
383
- return inner
384
-
385
-
386
356
  def set_all_random_seed(seed: int):
387
357
  random.seed(seed)
388
358
  np.random.seed(seed)