xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Files changed (194)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +209 -40
  4. xinference/client/restful/restful_client.py +7 -26
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/image_interface.py +28 -0
  11. xinference/core/model.py +110 -31
  12. xinference/core/scheduler.py +37 -37
  13. xinference/core/status_guard.py +1 -1
  14. xinference/core/supervisor.py +17 -10
  15. xinference/core/utils.py +80 -22
  16. xinference/core/worker.py +17 -16
  17. xinference/deploy/cmdline.py +8 -16
  18. xinference/deploy/local.py +1 -1
  19. xinference/deploy/supervisor.py +1 -1
  20. xinference/deploy/utils.py +1 -1
  21. xinference/deploy/worker.py +1 -1
  22. xinference/model/audio/cosyvoice.py +86 -41
  23. xinference/model/audio/fish_speech.py +9 -9
  24. xinference/model/audio/model_spec.json +9 -9
  25. xinference/model/audio/whisper.py +4 -1
  26. xinference/model/embedding/core.py +52 -31
  27. xinference/model/image/core.py +2 -1
  28. xinference/model/image/model_spec.json +16 -4
  29. xinference/model/image/model_spec_modelscope.json +16 -4
  30. xinference/model/image/sdapi.py +136 -0
  31. xinference/model/image/stable_diffusion/core.py +164 -19
  32. xinference/model/llm/__init__.py +29 -11
  33. xinference/model/llm/llama_cpp/core.py +16 -33
  34. xinference/model/llm/llm_family.json +1011 -1296
  35. xinference/model/llm/llm_family.py +34 -53
  36. xinference/model/llm/llm_family_csghub.json +18 -35
  37. xinference/model/llm/llm_family_modelscope.json +981 -1122
  38. xinference/model/llm/lmdeploy/core.py +56 -88
  39. xinference/model/llm/mlx/core.py +46 -69
  40. xinference/model/llm/sglang/core.py +36 -18
  41. xinference/model/llm/transformers/chatglm.py +168 -306
  42. xinference/model/llm/transformers/cogvlm2.py +36 -63
  43. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  44. xinference/model/llm/transformers/core.py +55 -50
  45. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  46. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  47. xinference/model/llm/transformers/glm4v.py +55 -111
  48. xinference/model/llm/transformers/intern_vl.py +39 -70
  49. xinference/model/llm/transformers/internlm2.py +32 -54
  50. xinference/model/llm/transformers/minicpmv25.py +22 -55
  51. xinference/model/llm/transformers/minicpmv26.py +158 -68
  52. xinference/model/llm/transformers/omnilmm.py +5 -28
  53. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  54. xinference/model/llm/transformers/qwen2_vl.py +234 -0
  55. xinference/model/llm/transformers/qwen_vl.py +34 -86
  56. xinference/model/llm/transformers/utils.py +32 -38
  57. xinference/model/llm/transformers/yi_vl.py +32 -72
  58. xinference/model/llm/utils.py +280 -554
  59. xinference/model/llm/vllm/core.py +161 -100
  60. xinference/model/rerank/core.py +41 -8
  61. xinference/model/rerank/model_spec.json +7 -0
  62. xinference/model/rerank/model_spec_modelscope.json +7 -1
  63. xinference/model/utils.py +1 -31
  64. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  65. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  66. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  67. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  68. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  69. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  70. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  71. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  72. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  73. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  74. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  77. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  78. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  79. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  80. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  81. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  82. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  83. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  84. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  85. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  86. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  87. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  88. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  89. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  90. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  91. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
  92. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  93. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  94. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  95. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  96. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  97. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  98. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  99. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  100. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  101. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  102. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  103. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  104. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  105. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  107. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  108. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  109. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  110. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  111. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  112. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  113. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  114. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  115. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  116. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  117. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  118. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  122. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  123. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  124. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  126. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  127. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  128. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  129. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  130. xinference/thirdparty/matcha/VERSION +1 -0
  131. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  132. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  133. xinference/thirdparty/omnilmm/LICENSE +201 -0
  134. xinference/thirdparty/whisper/__init__.py +156 -0
  135. xinference/thirdparty/whisper/__main__.py +3 -0
  136. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  137. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  138. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  139. xinference/thirdparty/whisper/audio.py +157 -0
  140. xinference/thirdparty/whisper/decoding.py +826 -0
  141. xinference/thirdparty/whisper/model.py +314 -0
  142. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  143. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  144. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  145. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  146. xinference/thirdparty/whisper/timing.py +386 -0
  147. xinference/thirdparty/whisper/tokenizer.py +395 -0
  148. xinference/thirdparty/whisper/transcribe.py +605 -0
  149. xinference/thirdparty/whisper/triton_ops.py +109 -0
  150. xinference/thirdparty/whisper/utils.py +316 -0
  151. xinference/thirdparty/whisper/version.py +1 -0
  152. xinference/types.py +14 -53
  153. xinference/web/ui/build/asset-manifest.json +6 -6
  154. xinference/web/ui/build/index.html +1 -1
  155. xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
  156. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  157. xinference/web/ui/build/static/js/main.754740c0.js +3 -0
  158. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
  159. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  160. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  161. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  162. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  163. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  164. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  165. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  166. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  167. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  168. xinference/web/ui/node_modules/.package-lock.json +37 -0
  169. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  170. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  171. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  172. xinference/web/ui/package-lock.json +38 -0
  173. xinference/web/ui/package.json +1 -0
  174. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
  175. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
  176. xinference/model/llm/transformers/llama_2.py +0 -108
  177. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  178. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  179. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  180. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  181. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  182. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  183. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  184. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  185. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  186. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  187. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  188. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  189. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  190. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  191. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/vllm/core.py CHANGED
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import asyncio
-import json
 import logging
 import multiprocessing
 import os
@@ -24,9 +23,9 @@ from typing import (
     Any,
     AsyncGenerator,
     Dict,
-    Iterable,
     List,
     Optional,
+    Tuple,
     TypedDict,
     Union,
 )
@@ -34,18 +33,20 @@ from typing import (
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
-    ChatCompletionMessage,
     Completion,
     CompletionChoice,
     CompletionChunk,
     CompletionUsage,
     LoRA,
-    ToolCallFunction,
-    ToolCalls,
 )
 from .. import LLM, LLMFamilyV1, LLMSpecV1
 from ..llm_family import CustomLLMFamilyV1
-from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin
+from ..utils import (
+    QWEN_TOOL_CALL_FAMILY,
+    QWEN_TOOL_CALL_SYMBOLS,
+    ChatModelMixin,
+    generate_completion_chunk,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -103,6 +104,7 @@ VLLM_SUPPORTED_MODELS = [
     "code-llama-python",
     "deepseek",
     "deepseek-coder",
+    "yi-coder",
 ]
 VLLM_SUPPORTED_CHAT_MODELS = [
     "llama-2-chat",
@@ -129,6 +131,7 @@ VLLM_SUPPORTED_CHAT_MODELS = [
     "codegeex4",
     "deepseek-chat",
     "deepseek-coder-instruct",
+    "yi-coder-chat",
 ]
 if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat")
@@ -148,6 +151,12 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.4.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen2-moe-instruct")
     VLLM_SUPPORTED_CHAT_MODELS.append("c4ai-command-r-v01")
 
+if VLLM_INSTALLED and vllm.__version__ >= "0.5.1":
+    VLLM_SUPPORTED_CHAT_MODELS.append("deepseek-v2-chat")
+    VLLM_SUPPORTED_CHAT_MODELS.append("deepseek-v2-chat-0628")
+    VLLM_SUPPORTED_CHAT_MODELS.append("deepseek-v2.5")
+
+
 if VLLM_INSTALLED and vllm.__version__ >= "0.5.3":
     VLLM_SUPPORTED_CHAT_MODELS.append("gemma-2-it")
     VLLM_SUPPORTED_CHAT_MODELS.append("mistral-nemo-instruct")
@@ -363,23 +372,28 @@ class VLLMModel(LLM):
     @staticmethod
     def _convert_request_output_to_completion_chunk(
         request_id: str, model: str, request_output: "RequestOutput"
-    ) -> CompletionChunk:
+    ) -> Tuple[CompletionChunk, Optional[str]]:
         choices: List[CompletionChoice] = []
+        finish_reason = None
         for output in request_output.outputs:
             choices.append(
                 CompletionChoice(
                     text=output.text,
                     index=output.index,
                     logprobs=None,  # TODO: support logprobs.
-                    finish_reason=output.finish_reason,
+                    finish_reason=None,
                 )
             )
-        return CompletionChunk(
-            id=request_id,
-            object="text_completion",
-            created=int(time.time()),
-            model=model,
-            choices=choices,
+            finish_reason = output.finish_reason
+        return (
+            CompletionChunk(
+                id=request_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=model,
+                choices=choices,
+            ),
+            finish_reason,
         )
 
     @staticmethod
@@ -420,6 +434,7 @@ class VLLMModel(LLM):
         prompt: Union[str, Dict[str, Any]],
         generate_config: Optional[Dict] = None,
         tools: object = False,
+        request_id: Optional[str] = None,
    ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]:
         try:
             from vllm.sampling_params import SamplingParams
@@ -454,7 +469,8 @@
             else False
         )
         sampling_params = SamplingParams(**sanitized_generate_config)
-        request_id = str(uuid.uuid1())
+        if not request_id:
+            request_id = str(uuid.uuid1())
 
         assert self._engine is not None
         results_generator = self._engine.generate(
@@ -463,10 +479,14 @@
 
         async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
             previous_texts = [""] * sanitized_generate_config["n"]
-            tools_token_filter = ChatModelMixin._tools_token_filter(self.model_family)
             prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+            complete_response = ""
+            match_tool_call_tmp_results = []
+            is_match_tool_call = False
+            chunk = None
+            finish_reason = None
             async for _request_output in results_generator:
-                chunk = self._convert_request_output_to_completion_chunk(
+                chunk, finish_reason = self._convert_request_output_to_completion_chunk(
                     request_id=request_id,
                     model=self.model_uid,
                     request_output=_request_output,
@@ -476,40 +496,8 @@
                     delta = choice["text"][len(previous_texts[i]) :]
                     previous_texts[i] = choice["text"]
                     choice["text"] = delta
+                    complete_response += delta
 
-                if tools:
-                    # only handle the first choice
-                    choice = chunk["choices"][0]
-                    if choice["finish_reason"] is not None:
-                        # use previous text for evaluation temporarily
-                        choice_delta = choice["text"]
-                        choice["text"] = previous_texts[0]
-                        _content, func, args = ChatModelMixin._eval_tool_arguments(
-                            self.model_family, chunk, tools
-                        )
-                        choice["text"] = tools_token_filter(
-                            tokens=previous_texts[0], delta=choice_delta
-                        )
-                        if func is not None:
-                            choice["text"] = None
-                            choice["finish_reason"] = "tool_calls"
-                            choice["tool_calls"] = [
-                                ToolCalls(
-                                    id=str(uuid.uuid4()),
-                                    type="function",
-                                    function=ToolCallFunction(
-                                        name=func,
-                                        arguments=json.dumps(args, ensure_ascii=False),
-                                    ),
-                                )
-                            ]
-                    else:
-                        # use a filter function to skip Qwen's react thought process
-                        choice["text"] = tools_token_filter(
-                            tokens=previous_texts[0], delta=choice["text"]
-                        )
-                        if not choice["text"]:
-                            continue
                 prompt_tokens = len(_request_output.prompt_token_ids)
                 completion_tokens = sum(
                     len(output.token_ids) for output in _request_output.outputs
@@ -520,7 +508,59 @@
                     completion_tokens=completion_tokens,
                     total_tokens=total_tokens,
                 )
+
+                if tools:
+                    """
+                    The qwen2 tool call returns format like this:
+                    <tool_call>
+                    {...}
+                    </tool_call>
+                    Here is to match this.
+                    """
+                    if (len(QWEN_TOOL_CALL_SYMBOLS[0]) > len(complete_response)) and (
+                        not QWEN_TOOL_CALL_SYMBOLS[0].startswith(complete_response)
+                    ):
+                        for c in match_tool_call_tmp_results:
+                            yield c
+                        match_tool_call_tmp_results.clear()
+                        yield chunk
+                    elif (len(QWEN_TOOL_CALL_SYMBOLS[0]) > len(complete_response)) and (
+                        QWEN_TOOL_CALL_SYMBOLS[0].startswith(complete_response)
+                    ):
+                        match_tool_call_tmp_results.append(chunk)
+                    else:
+                        assert len(QWEN_TOOL_CALL_SYMBOLS[0]) <= len(complete_response)
+                        if not is_match_tool_call and complete_response.startswith(
+                            QWEN_TOOL_CALL_SYMBOLS[0]
+                        ):
+                            is_match_tool_call = True
+                            match_tool_call_tmp_results.clear()
+
+                        if not is_match_tool_call:
+                            for c in match_tool_call_tmp_results:
+                                yield c
+                            match_tool_call_tmp_results.clear()
+                            yield chunk
+                        else:
+                            chunk["choices"][0]["text"] = complete_response
+                else:
+                    yield chunk
+
+            if is_match_tool_call:
+                assert chunk is not None
                 yield chunk
+
+            # match OpenAI API stream
+            yield generate_completion_chunk(
+                chunk_text="",
+                finish_reason=finish_reason,
+                chunk_id=request_id,
+                model_uid=self.model_uid,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+
             if include_usage:
                 chunk = CompletionChunk(
                     id=request_id,
@@ -586,59 +626,74 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
     ) -> Dict:
         if not generate_config:
             generate_config = {}
-        if self.model_family.prompt_style:
-            if (
-                not generate_config.get("stop")
-            ) and self.model_family.prompt_style.stop:
-                generate_config["stop"] = self.model_family.prompt_style.stop.copy()
-            if self.model_family.prompt_style.stop_token_ids:
-                generate_config.setdefault(
-                    "stop_token_ids",
-                    self.model_family.prompt_style.stop_token_ids.copy(),
-                )
+        if not generate_config.get("stop") and self.model_family.stop:
+            generate_config["stop"] = self.model_family.stop.copy()
+        if (
+            not generate_config.get("stop_token_ids")
+            and self.model_family.stop_token_ids
+        ):
+            generate_config["stop_token_ids"] = self.model_family.stop_token_ids.copy()
         return generate_config
 
+    @staticmethod
+    def is_tool_call_chunk(chunk):
+        return chunk["choices"][0]["text"].startswith(QWEN_TOOL_CALL_SYMBOLS[0])
+
+    async def _async_to_tool_completion_chunks(
+        self,
+        chunks: AsyncGenerator[CompletionChunk, None],
+    ) -> AsyncGenerator[ChatCompletionChunk, None]:
+        i = 0
+        async for chunk in chunks:
+            if i == 0:
+                yield self._get_first_chat_completion_chunk(chunk)
+            # usage
+            choices = chunk.get("choices")
+            if not choices:
+                yield self._get_final_chat_completion_chunk(chunk)
+            else:
+                if self.is_tool_call_chunk(chunk):
+                    yield self._tool_calls_completion_chunk(
+                        self.model_family, self.model_uid, chunk
+                    )
+                else:
+                    yield self._to_chat_completion_chunk(chunk)
+            i += 1
+
     async def async_chat(
         self,
-        prompt: str,
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[Dict] = None,
+        request_id: Optional[str] = None,
     ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
-        assert self.model_family.prompt_style is not None
-        prompt_style = self.model_family.prompt_style.copy()
-        if system_prompt:
-            prompt_style.system_prompt = system_prompt
-        chat_history = chat_history or []
         tools = generate_config.pop("tools", []) if generate_config else None
-        full_prompt = self.get_prompt(prompt, chat_history, prompt_style, tools=tools)
-
-        generate_config = self._sanitize_chat_config(generate_config)
-        # TODO(codingl2k1): qwen hacky to set stop for function call.
         model_family = self.model_family.model_family or self.model_family.model_name
+        full_context_kwargs = {}
         if tools and model_family in QWEN_TOOL_CALL_FAMILY:
-            stop = generate_config.get("stop")
-            if isinstance(stop, str):
-                generate_config["stop"] = [stop, "Observation:"]
-            elif isinstance(stop, Iterable):
-                assert not isinstance(stop, str)
-                generate_config["stop"] = list(stop) + ["Observation:"]
-            else:
-                generate_config["stop"] = "Observation:"
+            full_context_kwargs["tools"] = tools
+        assert self.model_family.chat_template is not None
+        full_prompt = self.get_full_context(
+            messages, self.model_family.chat_template, **full_context_kwargs
+        )
 
+        generate_config = self._sanitize_chat_config(generate_config)
         stream = generate_config.get("stream", None)
 
         if stream:
-            agen = await self.async_generate(full_prompt, generate_config, tools)
+            agen = await self.async_generate(
+                full_prompt, generate_config, tools, request_id=request_id
+            )
             assert isinstance(agen, AsyncGenerator)
+            if tools:
+                return self._async_to_tool_completion_chunks(agen)
             return self._async_to_chat_completion_chunks(agen)
         else:
-            c = await self.async_generate(full_prompt, generate_config)
+            c = await self.async_generate(
+                full_prompt, generate_config, request_id=request_id
+            )
             assert not isinstance(c, AsyncGenerator)
             if tools:
-                return self._tool_calls_completion(
-                    self.model_family, self.model_uid, c, tools
-                )
+                return self._tool_calls_completion(self.model_family, self.model_uid, c)
             return self._to_chat_completion(c)
 
 
@@ -666,28 +721,30 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
         self,
         generate_config: Optional[Dict] = None,
     ) -> Dict:
+        from ..utils import get_stop_token_ids_from_config_file
+
         if not generate_config:
             generate_config = {}
-        if self.model_family.prompt_style:
-            if self.model_family.prompt_style.stop_token_ids:
-                generate_config.setdefault(
-                    "stop_token_ids",
-                    self.model_family.prompt_style.stop_token_ids.copy(),
-                )
+        if generate_config.get("stop_token_ids", None) is None:
+            stop_token_ids = get_stop_token_ids_from_config_file(self.model_path)
+            if stop_token_ids is not None:
+                generate_config.setdefault("stop_token_ids", stop_token_ids)
+            else:
+                if self.model_family.stop_token_ids:
+                    generate_config.setdefault(
+                        "stop_token_ids", self.model_family.stop_token_ids.copy()
+                    )
         return generate_config
 
     async def async_chat(
         self,
-        prompt: str,
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[Dict] = None,
+        request_id: Optional[str] = None,
     ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
         # only support single image, waiting vllm support multi images
-        assert self.model_family.prompt_style is not None
-        prompt_style = self.model_family.prompt_style.copy()
-        chat_history = chat_history or []
-        prompt, images = self.get_prompt(prompt, chat_history, prompt_style)
+        model_family = self.model_family.model_family or self.model_family.model_name
+        prompt, images = self.get_specific_prompt(model_family, messages)
 
         if len(images) == 0:
             inputs = {
@@ -703,10 +760,14 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
         stream = generate_config.get("stream", None)
 
         if stream:
-            agen = await self.async_generate(inputs, generate_config)
+            agen = await self.async_generate(
+                inputs, generate_config, request_id=request_id
+            )
             assert isinstance(agen, AsyncGenerator)
             return self._async_to_chat_completion_chunks(agen)
         else:
-            c = await self.async_generate(inputs, generate_config)
+            c = await self.async_generate(
+                inputs, generate_config, request_id=request_id
+            )
             assert not isinstance(c, AsyncGenerator)
             return self._to_chat_completion(c)
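The new streaming path in vllm/core.py works by buffering: deltas are held back while the accumulated text could still turn out to be the opening tool-call marker, then either flushed as ordinary chunks or collapsed into a single tool-call payload. Below is a minimal, self-contained sketch of that idea with illustrative names only; it is not xinference's API, and the marker value is an assumption based on the `<tool_call>` format described in the hunk's docstring.

```python
# Illustrative sketch only (not xinference code) of the stream buffering technique.
from typing import Iterable, Iterator, Tuple

TOOL_CALL_START = "<tool_call>"  # assumed value of QWEN_TOOL_CALL_SYMBOLS[0]


def filter_tool_call_stream(deltas: Iterable[str]) -> Iterator[Tuple[str, str]]:
    complete = ""
    held = []            # deltas buffered while the prefix is still ambiguous
    matched = False
    for delta in deltas:
        complete += delta
        if len(complete) < len(TOOL_CALL_START):
            if TOOL_CALL_START.startswith(complete):
                held.append(delta)              # could still become "<tool_call>..."
            else:
                for d in held:
                    yield ("text", d)
                held.clear()
                yield ("text", delta)           # definitely ordinary text
        else:
            if not matched and complete.startswith(TOOL_CALL_START):
                matched = True                  # swallow output from now on
                held.clear()
            if not matched:
                for d in held:
                    yield ("text", d)
                held.clear()
                yield ("text", delta)
    if matched:
        yield ("tool_call", complete)           # emit the whole tool call at the end
    else:
        for d in held:                          # flush anything still buffered
            yield ("text", d)


if __name__ == "__main__":
    print(list(filter_tool_call_stream(["<tool", "_call>", '{"name": "f"}', "</tool_call>"])))
    print(list(filter_tool_call_stream(["Hel", "lo ", "world"])))
```

The first call collapses into one `("tool_call", ...)` item; the second passes every delta through unchanged, mirroring how the diff either rewrites the final chunk's text to the complete response or yields chunks as they arrive.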
xinference/model/rerank/core.py CHANGED
@@ -15,6 +15,7 @@
 import gc
 import logging
 import os
+import threading
 import uuid
 from collections import defaultdict
 from collections.abc import Sequence
@@ -22,6 +23,7 @@ from typing import Dict, List, Literal, Optional, Tuple
 
 import numpy as np
 import torch
+import torch.nn as nn
 
 from ...constants import XINFERENCE_CACHE_DIR
 from ...device_utils import empty_cache
@@ -49,6 +51,7 @@ class RerankModelSpec(CacheableModelSpec):
     model_name: str
     language: List[str]
     type: Optional[str] = "unknown"
+    max_tokens: Optional[int]
     model_id: str
     model_revision: Optional[str]
     model_hub: str = "huggingface"
@@ -102,6 +105,30 @@ def generate_rerank_description(model_spec: RerankModelSpec) -> Dict[str, List[D
     return res
 
 
+class _ModelWrapper:
+    def __init__(self, module: nn.Module):
+        self._module = module
+        self._local_data = threading.local()
+
+    @property
+    def n_tokens(self):
+        return getattr(self._local_data, "n_tokens", 0)
+
+    @n_tokens.setter
+    def n_tokens(self, new_n_tokens):
+        self._local_data.n_tokens = new_n_tokens
+
+    def __getattr__(self, attr):
+        return getattr(self._module, attr)
+
+    def __call__(self, **kwargs):
+        attention_mask = kwargs["attention_mask"]
+        # when batching, the attention mask 1 means there is a token
+        # thus we just sum up it to get the total number of tokens
+        self.n_tokens += attention_mask.sum().item()
+        return self._module(**kwargs)
+
+
 class RerankModel:
     def __init__(
         self,
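A standalone sketch of the `_ModelWrapper` idea above, outside xinference: a proxy forwards every attribute to the wrapped module and sums the attention mask on each forward call, since every 1 in the mask is a real (non-padding) token. Names here are illustrative; the real wrapper additionally keeps the counter in `threading.local()` so concurrent requests do not share it.

```python
# Illustrative only: counting processed tokens by proxying the model's forward call.
import torch


class CountingWrapper:
    def __init__(self, module):
        self._module = module
        self.n_tokens = 0

    def __getattr__(self, attr):
        # only reached for attributes the wrapper itself does not define
        return getattr(self._module, attr)

    def __call__(self, **kwargs):
        # each 1 in the attention mask marks a real token in the padded batch
        self.n_tokens += kwargs["attention_mask"].sum().item()
        return self._module(**kwargs)


if __name__ == "__main__":
    fake_forward = lambda **kw: kw["input_ids"].shape  # stand-in for a HF model
    model = CountingWrapper(fake_forward)
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 5 real tokens, 3 padding
    model(input_ids=torch.zeros_like(mask), attention_mask=mask)
    print(model.n_tokens)  # -> 5
```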
@@ -166,6 +193,7 @@ class RerankModel:
             self._model_path,
             device=self._device,
             trust_remote_code=True,
+            max_length=getattr(self._model_spec, "max_tokens"),
             **self._model_config,
         )
         if self._use_fp16:
@@ -189,6 +217,8 @@
 
             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
         self._model = FlagReranker(self._model_path, use_fp16=self._use_fp16)
+        # Wrap transformers model to record number of tokens
+        self._model.model = _ModelWrapper(self._model.model)
 
     def rerank(
         self,
@@ -200,17 +230,14 @@
         return_len: Optional[bool],
         **kwargs,
     ) -> Rerank:
-        self._counter += 1
-        if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
-            logger.debug("Empty rerank cache.")
-            gc.collect()
-            empty_cache()
         assert self._model is not None
         if kwargs:
             raise ValueError("rerank hasn't support extra parameter.")
         if max_chunks_per_doc is not None:
             raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
         sentence_combinations = [[query, doc] for doc in documents]
+        # reset n tokens
+        self._model.model.n_tokens = 0
         if self._model_spec.type == "normal":
             similarity_scores = self._model.predict(
                 sentence_combinations, convert_to_numpy=False, convert_to_tensor=True
@@ -245,9 +272,7 @@
             for arg in sim_scores_argsort
         ]
         if return_len:
-            tokenizer = self._get_tokenizer(self._model_path)
-            input_len = sum([len(tokenizer.tokenize(t)) for t in documents])
-
+            input_len = self._model.model.n_tokens
             # Rerank Model output is just score or documents
             # while return_documents = True
             output_len = input_len
@@ -265,6 +290,14 @@
             "warnings": None,
         }
 
+        del similarity_scores
+        # clear cache if possible
+        self._counter += 1
+        if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
+            logger.debug("Empty rerank cache.")
+            gc.collect()
+            empty_cache()
+
         return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)
 
 
xinference/model/rerank/model_spec.json CHANGED
@@ -3,6 +3,7 @@
     "model_name": "bge-reranker-large",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "BAAI/bge-reranker-large",
     "model_revision": "27c9168d479987529781de8474dff94d69beca11"
   },
@@ -10,6 +11,7 @@
     "model_name": "bge-reranker-base",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "BAAI/bge-reranker-base",
     "model_revision": "465b4b7ddf2be0a020c8ad6e525b9bb1dbb708ae"
   },
@@ -17,6 +19,7 @@
     "model_name": "bce-reranker-base_v1",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "maidalun1020/bce-reranker-base_v1",
     "model_revision": "eaa31a577a0574e87a08959bd229ca14ce1b5496"
   },
@@ -24,6 +27,7 @@
     "model_name": "bge-reranker-v2-m3",
     "type": "normal",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 8192,
     "model_id": "BAAI/bge-reranker-v2-m3",
     "model_revision": "12e974610ba9083ed95f3edf08d7e899581f4de4"
   },
@@ -31,6 +35,7 @@
     "model_name": "bge-reranker-v2-gemma",
     "type": "LLM-based",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 8192,
     "model_id": "BAAI/bge-reranker-v2-gemma",
     "model_revision": "1787044f8b6fb740a9de4557c3a12377f84d9e17"
   },
@@ -38,6 +43,7 @@
     "model_name": "bge-reranker-v2-minicpm-layerwise",
     "type": "LLM-based layerwise",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 2048,
     "model_id": "BAAI/bge-reranker-v2-minicpm-layerwise",
     "model_revision": "47b5332b296c4d8cb6ee2c60502cc62a0d708881"
   },
@@ -45,6 +51,7 @@
     "model_name": "jina-reranker-v2",
     "type": "normal",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 1024,
     "model_id": "jinaai/jina-reranker-v2-base-multilingual",
     "model_revision": "298e48cada4a9318650d7fbd795f63827f884087"
   }
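For context on the new `max_tokens` field: as the rerank core.py hunk earlier in this diff shows, the spec value is forwarded to sentence-transformers' CrossEncoder as `max_length`, so query/document pairs are truncated at that many tokens. A hedged usage sketch outside xinference, with the model name and length taken from the bge-reranker-v2-m3 entry above:

```python
# Sketch of how the spec value is used, assuming a plain sentence-transformers setup.
from sentence_transformers import CrossEncoder

reranker = CrossEncoder(
    "BAAI/bge-reranker-v2-m3",
    max_length=8192,          # corresponds to "max_tokens" in the model spec
    trust_remote_code=True,
)
scores = reranker.predict([
    ["what is a panda?", "The giant panda is a bear species endemic to China."],
    ["what is a panda?", "Paris is the capital of France."],
])
print(scores)  # higher score for the relevant passage
```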
xinference/model/rerank/model_spec_modelscope.json CHANGED
@@ -3,6 +3,7 @@
     "model_name": "bge-reranker-base",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "Xorbits/bge-reranker-base",
     "model_revision": "v0.0.1",
     "model_hub": "modelscope"
@@ -11,6 +12,7 @@
     "model_name": "bge-reranker-large",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "Xorbits/bge-reranker-large",
     "model_revision": "v0.0.1",
     "model_hub": "modelscope"
@@ -19,6 +21,7 @@
     "model_name": "bce-reranker-base_v1",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "maidalun/bce-reranker-base_v1",
     "model_revision": "v0.0.1",
     "model_hub": "modelscope"
@@ -26,6 +29,7 @@
   {
     "model_name": "bge-reranker-v2-m3",
     "type": "normal",
+    "max_tokens": 8192,
     "language": ["en", "zh", "multilingual"],
     "model_id": "AI-ModelScope/bge-reranker-v2-m3",
     "model_hub": "modelscope"
@@ -34,6 +38,7 @@
     "model_name": "bge-reranker-v2-gemma",
     "type": "LLM-based",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 8192,
     "model_id": "AI-ModelScope/bge-reranker-v2-gemma",
     "model_hub": "modelscope"
   },
@@ -41,7 +46,8 @@
     "model_name": "bge-reranker-v2-minicpm-layerwise",
     "type": "LLM-based layerwise",
     "language": ["en", "zh", "multilingual"],
-    "model_id": "zfffff/bge-reranker-v2-minicpm-layerwise",
+    "max_tokens": 2048,
+    "model_id": "mirror013/bge-reranker-v2-minicpm-layerwise",
     "model_hub": "modelscope"
   }
 ]
xinference/model/utils.py CHANGED
@@ -11,10 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import functools
-import gc
-import inspect
 import json
 import logging
 import os
@@ -28,7 +24,7 @@ import numpy as np
 import torch
 
 from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_ENV_MODEL_SRC
-from ..device_utils import empty_cache, get_available_device, is_device_available
+from ..device_utils import get_available_device, is_device_available
 from .core import CacheableModelSpec
 
 logger = logging.getLogger(__name__)
@@ -357,32 +353,6 @@ def convert_float_to_int_or_str(model_size: float) -> Union[int, str]:
     return str(model_size)
 
 
-def ensure_cache_cleared(func: Callable):
-    assert not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
-        func
-    )
-    if inspect.isgeneratorfunction(func):
-
-        @functools.wraps(func)
-        def inner(*args, **kwargs):
-            for obj in func(*args, **kwargs):
-                yield obj
-            gc.collect()
-            empty_cache()
-
-    else:
-
-        @functools.wraps(func)
-        def inner(*args, **kwargs):
-            try:
-                return func(*args, **kwargs)
-            finally:
-                gc.collect()
-                empty_cache()
-
-    return inner
-
-
 def set_all_random_seed(seed: int):
     random.seed(seed)
     np.random.seed(seed)