xinference 0.14.4.post1__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (149)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +5 -39
  4. xinference/client/restful/restful_client.py +3 -24
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/model.py +82 -31
  11. xinference/core/scheduler.py +37 -37
  12. xinference/core/status_guard.py +1 -1
  13. xinference/core/supervisor.py +11 -10
  14. xinference/core/utils.py +80 -22
  15. xinference/core/worker.py +17 -16
  16. xinference/deploy/cmdline.py +8 -16
  17. xinference/deploy/local.py +1 -1
  18. xinference/deploy/supervisor.py +1 -1
  19. xinference/deploy/utils.py +1 -1
  20. xinference/deploy/worker.py +1 -1
  21. xinference/model/audio/cosyvoice.py +86 -41
  22. xinference/model/embedding/core.py +52 -31
  23. xinference/model/image/stable_diffusion/core.py +18 -1
  24. xinference/model/llm/__init__.py +21 -11
  25. xinference/model/llm/llama_cpp/core.py +16 -33
  26. xinference/model/llm/llm_family.json +619 -1297
  27. xinference/model/llm/llm_family.py +31 -52
  28. xinference/model/llm/llm_family_csghub.json +18 -35
  29. xinference/model/llm/llm_family_modelscope.json +573 -1119
  30. xinference/model/llm/lmdeploy/core.py +56 -88
  31. xinference/model/llm/mlx/core.py +46 -69
  32. xinference/model/llm/sglang/core.py +33 -18
  33. xinference/model/llm/transformers/chatglm.py +167 -305
  34. xinference/model/llm/transformers/cogvlm2.py +36 -63
  35. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  36. xinference/model/llm/transformers/core.py +49 -50
  37. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  38. xinference/model/llm/transformers/glm4v.py +55 -111
  39. xinference/model/llm/transformers/intern_vl.py +39 -70
  40. xinference/model/llm/transformers/internlm2.py +32 -54
  41. xinference/model/llm/transformers/minicpmv25.py +22 -55
  42. xinference/model/llm/transformers/minicpmv26.py +158 -68
  43. xinference/model/llm/transformers/omnilmm.py +5 -28
  44. xinference/model/llm/transformers/qwen2_vl.py +208 -0
  45. xinference/model/llm/transformers/qwen_vl.py +34 -86
  46. xinference/model/llm/transformers/utils.py +32 -38
  47. xinference/model/llm/transformers/yi_vl.py +32 -72
  48. xinference/model/llm/utils.py +195 -489
  49. xinference/model/llm/vllm/core.py +153 -100
  50. xinference/model/rerank/core.py +41 -8
  51. xinference/model/rerank/model_spec.json +7 -0
  52. xinference/model/rerank/model_spec_modelscope.json +7 -1
  53. xinference/model/utils.py +1 -31
  54. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  55. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  56. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  57. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  58. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  59. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  60. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  61. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  62. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  63. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  64. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  65. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  66. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  67. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  69. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  70. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  71. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  72. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  73. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  74. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  75. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  76. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  77. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  78. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  79. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  80. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  81. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +34 -0
  82. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  83. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  84. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  85. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  88. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  89. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  90. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  91. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  92. xinference/thirdparty/matcha/VERSION +1 -0
  93. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  94. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  95. xinference/thirdparty/omnilmm/LICENSE +201 -0
  96. xinference/thirdparty/whisper/__init__.py +156 -0
  97. xinference/thirdparty/whisper/__main__.py +3 -0
  98. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  99. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  100. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  101. xinference/thirdparty/whisper/audio.py +157 -0
  102. xinference/thirdparty/whisper/decoding.py +826 -0
  103. xinference/thirdparty/whisper/model.py +314 -0
  104. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  105. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  106. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  107. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  108. xinference/thirdparty/whisper/timing.py +386 -0
  109. xinference/thirdparty/whisper/tokenizer.py +395 -0
  110. xinference/thirdparty/whisper/transcribe.py +605 -0
  111. xinference/thirdparty/whisper/triton_ops.py +109 -0
  112. xinference/thirdparty/whisper/utils.py +316 -0
  113. xinference/thirdparty/whisper/version.py +1 -0
  114. xinference/types.py +7 -49
  115. xinference/web/ui/build/asset-manifest.json +6 -6
  116. xinference/web/ui/build/index.html +1 -1
  117. xinference/web/ui/build/static/css/{main.4bafd904.css → main.632e9148.css} +2 -2
  118. xinference/web/ui/build/static/css/main.632e9148.css.map +1 -0
  119. xinference/web/ui/build/static/js/main.9cfafbd6.js +3 -0
  120. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.9cfafbd6.js.LICENSE.txt} +2 -0
  121. xinference/web/ui/build/static/js/main.9cfafbd6.js.map +1 -0
  122. xinference/web/ui/node_modules/.cache/babel-loader/01d6d198156bacbd436c51435edbd4b2cacd47a79db929105eba30f74b67d48d.json +1 -0
  123. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  124. xinference/web/ui/node_modules/.cache/babel-loader/59eb25f514afcc4fefd1b309d192b2455f1e0aec68a9de598ca4b2333fe2c774.json +1 -0
  125. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  126. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  127. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  128. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  129. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  130. xinference/web/ui/node_modules/.package-lock.json +37 -0
  131. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  132. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  133. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  134. xinference/web/ui/package-lock.json +38 -0
  135. xinference/web/ui/package.json +1 -0
  136. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/METADATA +8 -8
  137. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/RECORD +141 -87
  138. xinference/model/llm/transformers/llama_2.py +0 -108
  139. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  140. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  141. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  142. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  144. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  145. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  146. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/LICENSE +0 -0
  147. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/WHEEL +0 -0
  148. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/entry_points.txt +0 -0
  149. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/lmdeploy/core.py

@@ -12,25 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import time
 import uuid
 from typing import AsyncGenerator, Dict, Iterator, List, Optional, TypedDict, Union
 
 import torch
 
-from ....types import (
-    ChatCompletion,
-    ChatCompletionChunk,
-    ChatCompletionChunkChoice,
-    ChatCompletionMessage,
-    Completion,
-    CompletionChoice,
-    CompletionUsage,
-    LoRA,
-)
+from ....types import ChatCompletion, ChatCompletionChunk, Completion, LoRA
 from ..core import LLM
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import ChatModelMixin
+from ..utils import ChatModelMixin, generate_chat_completion, generate_completion_chunk
 
 logger = logging.getLogger(__name__)
 
@@ -74,8 +64,8 @@ class LMDeployGenerateConfig(TypedDict, total=False):
     repetition_penalty: Optional[float]
     ignore_eos: Optional[bool]
     random_seed: Optional[int]
-    stop_words: Optional[List[str]]
-    bad_words: Optional[List[str]]
+    stop_words: Optional[List[int]]
+    bad_words: Optional[List[int]]
     min_new_tokens: Optional[int]
     skip_special_tokens: Optional[bool]
     logprobs: Optional[int]
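
Note the type change: stop_words and bad_words are now lists of token ids rather than strings. A minimal, hypothetical config built against these fields (the ids are placeholders; real values come from the model's tokenizer or generation_config.json):

# Placeholder ids for illustration only; look up the model's real stop/EOS ids.
generate_config = {
    "stop_words": [2, 32000],  # token ids, no longer strings
    "bad_words": [],
    "ignore_eos": False,
}
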
@@ -164,9 +154,6 @@ class LMDeployChatModel(LMDeployModel, ChatModelMixin):
             raise ValueError(f"Can not find correct chat template.")
 
         chat_template_config = ChatTemplateConfig(chat_temp_name)
-        chat_template_config.meta_instruction = (
-            self.model_family.prompt_style.system_prompt
-        )
         count = torch.cuda.device_count()
         if count > 1:
             self._model_config.setdefault("tp", torch.cuda.device_count())
@@ -192,9 +179,7 @@ class LMDeployChatModel(LMDeployModel, ChatModelMixin):
 
     async def async_chat(
         self,
-        prompt: Union[str, List[Dict]],
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[Dict] = None,
     ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
         stream = (
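
With this signature change the LMDeploy chat entry point consumes OpenAI-style messages directly; the old prompt/system_prompt/chat_history triple is gone. A caller-side illustration (the variable names and the commented call are hypothetical, not part of the diff):

# System prompt and prior turns are now just entries in one list.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize what changed in this release."},
]
# completion = await model.async_chat(messages, generate_config={"stream": False})
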
@@ -213,75 +198,69 @@ class LMDeployChatModel(LMDeployModel, ChatModelMixin):
             else False
         )
 
-        chat_history = chat_history or []
-
         if stream:
-            chunk = self._chat_stream(prompt, chat_history, include_usage)
+            chunk = self._chat_stream(messages, include_usage)
             return self._async_to_chat_completion_chunks(chunk)
         else:
-            chunk = await self._chat(prompt, chat_history)
-            return self._to_chat_completion(chunk)
+            return await self._chat(messages)
 
-    async def _chat_stream(self, prompt, chat_history, include_usage):
+    async def _chat_stream(self, messages, include_usage):
         from lmdeploy.messages import Response
 
         prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
         completion_id = str(uuid.uuid1())
+        finish_reason = None
        async for output in self._generate(
-            prompt,
-            chat_history,
+            messages,
             session_id=-1,
             stream_response=True,
         ):
             new_text = output.text if isinstance(output, Response) else output.response
-
-            completion_choice = ChatCompletionChunkChoice(
-                text=new_text,
-                index=0,
-                logprobs=None,
-                finish_reason=output.finish_reason,
-            )
-            chunk = ChatCompletionChunk(
-                id=completion_id,
-                object="chat.completion",
-                created=int(time.time()),
-                model=self.model_uid,
-                choices=[completion_choice],
-            )
             prompt_tokens = output.input_token_len
             completion_tokens = output.generate_token_len
             total_tokens = prompt_tokens + completion_tokens
-            completion_usage = CompletionUsage(
+            finish_reason = output.finish_reason
+            yield generate_completion_chunk(
+                chunk_text=new_text,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
                 total_tokens=total_tokens,
             )
-            chunk["usage"] = completion_usage
-            print(chunk)
-            yield chunk
+
+        yield generate_completion_chunk(
+            chunk_text=None,
+            finish_reason=finish_reason,
+            chunk_id=completion_id,
+            model_uid=self.model_uid,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            has_choice=True,
+            has_content=False,
+        )
         if include_usage:
-            chunk = ChatCompletionChunk(
-                id=completion_id,
-                object="chat.completion",
-                created=int(time.time()),
-                model=self.model_uid,
-                choices=[],
-            )
-            chunk["usage"] = CompletionUsage(
+            yield generate_completion_chunk(
+                chunk_text=None,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
                 total_tokens=total_tokens,
+                has_choice=False,
+                has_content=False,
             )
-            yield chunk
 
-    async def _chat(self, prompt, chat_history):
+    async def _chat(self, messages) -> ChatCompletion:
         from lmdeploy.messages import Response
 
-        response, finish_reason = "", ""
+        response, finish_reason = "", None
         prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
         async for output in self._generate(
-            prompt,
-            chat_history,
+            messages,
             session_id=-1,
             stream_response=False,
         ):
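
Both streaming paths now delegate chunk construction to helpers imported from ..utils. Their implementation is not part of this diff; the sketch below only approximates what generate_completion_chunk assembles, inferred from the fields the removed inline code used (the name suffix, dict layout, and defaults are assumptions):

import time
from typing import Optional

def generate_completion_chunk_sketch(
    chunk_text: Optional[str],
    finish_reason: Optional[str],
    chunk_id: str,
    model_uid: str,
    prompt_tokens: int,
    completion_tokens: int,
    total_tokens: int,
    has_choice: bool = True,
    has_content: bool = True,
) -> dict:
    # Approximation only: the real helper lives in xinference/model/llm/utils.py
    # and returns a CompletionChunk; details may differ.
    choice = {"index": 0, "logprobs": None, "finish_reason": finish_reason}
    if has_content:
        choice["text"] = chunk_text
    return {
        "id": chunk_id,
        "object": "text_completion",
        "created": int(time.time()),
        "model": model_uid,
        "choices": [choice] if has_choice else [],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
        },
    }
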
@@ -291,30 +270,20 @@ class LMDeployChatModel(LMDeployModel, ChatModelMixin):
             total_tokens = output.input_token_len + output.generate_token_len
             finish_reason = output.finish_reason
 
-        chunk = ChatCompletion(
-            id=str(uuid.uuid1()),
-            object="chat.completion",
-            created=int(time.time()),
-            model=self.model_uid,
-            choices=[
-                CompletionChoice(
-                    index=0, text=response, finish_reason=finish_reason, logprobs=None
-                )
-            ],
-            usage=CompletionUsage(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-            ),
+        return generate_chat_completion(
+            self.model_uid,
+            response,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            finish_reason=finish_reason,
         )
-        return chunk
 
     # copy from lmdeploy
     # Reference: lmdeploy.serve.async_engine.py
     async def _generate(
         self,
-        prompt,
-        chat_history,
+        messages: List[Dict],
         session_id: int,
         generate_config: Optional[Dict] = None,
         tools: Optional[List[object]] = None,
@@ -332,6 +301,8 @@ class LMDeployChatModel(LMDeployModel, ChatModelMixin):
         from lmdeploy.serve.async_engine import GenOut
         from lmdeploy.tokenizer import DetokenizeState
 
+        from ..utils import get_stop_token_ids_from_config_file
+
         session_id = -1
 
         if str(session_id) not in self._model.id2step:
@@ -343,7 +314,9 @@ class LMDeployChatModel(LMDeployModel, ChatModelMixin):
             generate_config, self._model.tokenizer
         )
         if generate_config.stop_words is None:  # type: ignore
-            generate_config.stop_words = self._model.stop_words  # type: ignore
+            stop_token_ids = get_stop_token_ids_from_config_file(self.model_path)
+            if stop_token_ids is not None:
+                generate_config.stop_words = stop_token_ids  # type: ignore
         if generate_config.random_seed is None and sequence_start:  # type: ignore
             generate_config.random_seed = random.getrandbits(64)  # type: ignore
         if generate_config.n > 1:  # type: ignore
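
Default stop words now come from the model's own config files on disk rather than LMDeploy's built-ins. The helper's implementation is not shown in this diff; a hedged sketch of how such a lookup could work, assuming it reads eos_token_id from generation_config.json (file name and field are assumptions):

import json
import os
from typing import List, Optional

def stop_token_ids_from_config_sketch(model_path: str) -> Optional[List[int]]:
    # Illustrative only: the real get_stop_token_ids_from_config_file lives in
    # xinference/model/llm/utils.py and may read different files or fields.
    config_file = os.path.join(model_path, "generation_config.json")
    if not os.path.exists(config_file):
        return None
    with open(config_file) as f:
        eos = json.load(f).get("eos_token_id")
    if eos is None:
        return None
    # eos_token_id may be a single id or a list of ids
    return eos if isinstance(eos, list) else [eos]
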
@@ -353,7 +326,7 @@ class LMDeployChatModel(LMDeployModel, ChatModelMixin):
             )
             generate_config.n = 1  # type: ignore
 
-        prompt_input = await self._get_prompt_input(prompt, chat_history)
+        prompt_input = await self._get_prompt_input(messages)
         prompt = prompt_input["prompt"]
         input_ids = prompt_input["input_ids"]
         finish_reason = None
@@ -482,8 +455,7 @@ class LMDeployChatModel(LMDeployModel, ChatModelMixin):
     # Reference: lmdeploy.serve.vl_async_engine.py
     async def _get_prompt_input(
         self,
-        prompt: Union[str, List[Dict]],
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         sequence_start: bool = True,
         tools: Optional[List[object]] = None,
         **kwargs,
@@ -493,13 +465,9 @@ class LMDeployChatModel(LMDeployModel, ChatModelMixin):
         IMAGE_DUMMY_TOKEN_INDEX = 0
         import numpy as np
 
-        assert self.model_family.prompt_style is not None
-        prompt_style = self.model_family.prompt_style.copy()
-        chat_history = chat_history or []
-
-        decorated, _ = self.get_prompt(prompt, chat_history, prompt_style)  # type: ignore
-        chat_history.append(ChatCompletionMessage(role="user", content=prompt))  # type: ignore
-        prompt = chat_history  # type: ignore
+        model_family = self.model_family.model_family or self.model_family.model_name
+        decorated, _ = self.get_specific_prompt(model_family, messages)  # type: ignore
+        prompt = messages  # type: ignore
 
         decorated = decorated.replace("<image>", "<img><IMAGE_TOKEN></img>")
 
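The vision prompt is now decorated from the same messages list, with <image> placeholders rewritten to LMDeploy's <img><IMAGE_TOKEN></img> form. Image inputs are typically supplied as OpenAI-style content parts; the payload below is only an illustration with a placeholder URL, not a format taken from this diff:

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown in this picture?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }
]
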
xinference/model/llm/mlx/core.py

@@ -17,22 +17,20 @@ import platform
 import sys
 import time
 import uuid
-from typing import Dict, Iterable, Iterator, List, Optional, TypedDict, Union
+from typing import Dict, Iterator, List, Optional, TypedDict, Union
 
 from ....fields import max_tokens_field
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
-    ChatCompletionMessage,
     Completion,
-    CompletionChoice,
     CompletionChunk,
     CompletionUsage,
     LoRA,
 )
 from ..core import LLM
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import ChatModelMixin
+from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin, generate_completion_chunk
 
 logger = logging.getLogger(__name__)
 
@@ -54,6 +52,7 @@ class MLXGenerateConfig(TypedDict, total=False):
     stop_token_ids: Optional[Union[int, List[int]]]
     stream: bool
     stream_options: Optional[Union[dict, None]]
+    tools: Optional[List[Dict]]
 
 
 class MLXModel(LLM):
@@ -211,23 +210,21 @@ class MLXModel(LLM):
             else:
                 output += out
 
-            completion_choice = CompletionChoice(
-                text=output, index=0, logprobs=None, finish_reason=None
-            )
-            completion_chunk = CompletionChunk(
-                id=chunk_id,
-                object="text_completion",
-                created=int(time.time()),
-                model=model_uid,
-                choices=[completion_choice],
-            )
             completion_usage = CompletionUsage(
                 prompt_tokens=input_echo_len,
                 completion_tokens=i,
                 total_tokens=(input_echo_len + i),
             )
 
-            yield completion_chunk, completion_usage
+            yield generate_completion_chunk(
+                chunk_text=output,
+                finish_reason=None,
+                chunk_id=chunk_id,
+                model_uid=model_uid,
+                prompt_tokens=input_echo_len,
+                completion_tokens=i,
+                total_tokens=(input_echo_len + i),
+            ), completion_usage
 
         logger.info(
             f"Average generation speed: {i / (time.time() - start):.2f} tokens/s."
@@ -238,29 +235,31 @@ class MLXModel(LLM):
         else:
             finish_reason = "stop"
 
-        if stream:
-            completion_choice = CompletionChoice(
-                text="", index=0, logprobs=None, finish_reason=finish_reason
-            )
-        else:
-            completion_choice = CompletionChoice(
-                text=output, index=0, logprobs=None, finish_reason=finish_reason
-            )
-
-        completion_chunk = CompletionChunk(
-            id=chunk_id,
-            object="text_completion",
-            created=int(time.time()),
-            model=model_uid,
-            choices=[completion_choice],
-        )
         completion_usage = CompletionUsage(
             prompt_tokens=input_echo_len,
             completion_tokens=i,
             total_tokens=(input_echo_len + i),
         )
-
-        yield completion_chunk, completion_usage
+        if stream:
+            yield generate_completion_chunk(
+                "",
+                finish_reason=finish_reason,
+                chunk_id=chunk_id,
+                model_uid=model_uid,
+                prompt_tokens=input_echo_len,
+                completion_tokens=i,
+                total_tokens=(input_echo_len + i),
+            ), completion_usage
+        else:
+            yield generate_completion_chunk(
+                output,
+                finish_reason=finish_reason,
+                chunk_id=chunk_id,
+                model_uid=model_uid,
+                prompt_tokens=input_echo_len,
+                completion_tokens=i,
+                total_tokens=(input_echo_len + i),
+            ), completion_usage
 
         if include_usage:
             completion_chunk = CompletionChunk(
@@ -270,11 +269,6 @@ class MLXModel(LLM):
                 model=model_uid,
                 choices=[],
             )
-            completion_usage = CompletionUsage(
-                prompt_tokens=input_echo_len,
-                completion_tokens=i,
-                total_tokens=(input_echo_len + i),
-            )
             yield completion_chunk, completion_usage
 
     def generate(
@@ -345,20 +339,13 @@ class MLXChatModel(MLXModel, ChatModelMixin):
         generate_config: Optional[MLXGenerateConfig],
     ) -> MLXGenerateConfig:
         generate_config = super()._sanitize_generate_config(generate_config)
-        if (
-            (not generate_config.get("stop"))
-            and self.model_family.prompt_style
-            and self.model_family.prompt_style.stop
-        ):
-            generate_config["stop"] = self.model_family.prompt_style.stop.copy()
+        if (not generate_config.get("stop")) and self.model_family.stop:
+            generate_config["stop"] = self.model_family.stop.copy()
         if (
             generate_config.get("stop_token_ids", None) is None
-            and self.model_family.prompt_style
-            and self.model_family.prompt_style.stop_token_ids
+            and self.model_family.stop_token_ids
         ):
-            generate_config[
-                "stop_token_ids"
-            ] = self.model_family.prompt_style.stop_token_ids.copy()
+            generate_config["stop_token_ids"] = self.model_family.stop_token_ids.copy()
 
         return generate_config
 
@@ -377,28 +364,20 @@ class MLXChatModel(MLXModel, ChatModelMixin):
 
     def chat(
         self,
-        prompt: str,
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[MLXGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        tools = generate_config.pop("tools", []) if generate_config else None  # type: ignore
-        full_prompt = self.get_full_prompt(
-            self.model_family, prompt, system_prompt, chat_history, tools
+        model_family = self.model_family.model_family or self.model_family.model_name
+        tools = generate_config.pop("tools", []) if generate_config else None
+        full_context_kwargs = {}
+        if tools and model_family in QWEN_TOOL_CALL_FAMILY:
+            full_context_kwargs["tools"] = tools
+        assert self.model_family.chat_template is not None
+        full_prompt = self.get_full_context(
+            messages, self.model_family.chat_template, **full_context_kwargs
         )
 
         generate_config = self._sanitize_generate_config(generate_config)
-        # TODO(codingl2k1): qwen hacky to set stop for function call.
-        model_family = self.model_family.model_family or self.model_family.model_name
-        if tools and model_family in ["qwen-chat", "qwen1.5-chat"]:
-            stop = generate_config.get("stop")
-            if isinstance(stop, str):
-                generate_config["stop"] = [stop, "Observation:"]
-            elif isinstance(stop, Iterable):
-                assert not isinstance(stop, str)
-                generate_config["stop"] = list(stop) + ["Observation:"]
-            else:
-                generate_config["stop"] = "Observation:"
 
         stream = generate_config.get("stream", False)
         if stream:
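
Tool definitions are now forwarded into the chat template only for Qwen-family models (QWEN_TOOL_CALL_FAMILY), replacing the old "Observation:" stop-word hack. A hypothetical tool passed via generate_config in the OpenAI function-calling format (the tool name and schema are placeholders, not taken from this diff):

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]
# e.g. model.chat(messages, generate_config={"tools": tools, "stream": False})
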
@@ -409,7 +388,5 @@ class MLXChatModel(MLXModel, ChatModelMixin):
             c = self.generate(full_prompt, generate_config)
             assert not isinstance(c, Iterator)
             if tools:
-                return self._tool_calls_completion(
-                    self.model_family, self.model_uid, c, tools
-                )
+                return self._tool_calls_completion(self.model_family, self.model_uid, c)
             return self._to_chat_completion(c)

xinference/model/llm/sglang/core.py

@@ -21,7 +21,6 @@ from typing import AsyncGenerator, Dict, List, Optional, TypedDict, Union
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
-    ChatCompletionMessage,
     Completion,
     CompletionChoice,
     CompletionChunk,
@@ -29,7 +28,7 @@ from ....types import (
 )
 from .. import LLM, LLMFamilyV1, LLMSpecV1
 from ..llm_family import CustomLLMFamilyV1
-from ..utils import ChatModelMixin
+from ..utils import ChatModelMixin, generate_completion_chunk
 
 logger = logging.getLogger(__name__)
 
@@ -319,6 +318,7 @@ class SGLANGModel(LLM):
         self,
         prompt: str,
         generate_config: Optional[SGLANGGenerateConfig] = None,
+        request_id: Optional[str] = None,
     ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]:
         sanitized_generate_config = self._sanitize_generate_config(generate_config)
         logger.debug(
@@ -332,8 +332,8 @@ class SGLANGModel(LLM):
             if isinstance(stream_options, dict)
             else False
         )
-
-        request_id = str(uuid.uuid1())
+        if not request_id:
+            request_id = str(uuid.uuid1())
         if not stream:
             state = await self._non_stream_generate(prompt, **sanitized_generate_config)
             return self._convert_state_to_completion(
@@ -346,12 +346,14 @@ class SGLANGModel(LLM):
 
         async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
             prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+            finish_reason = None
             async for meta_info, out in self._stream_generate(
                 prompt, **sanitized_generate_config
             ):
                 chunk = self._convert_state_to_completion_chunk(
                     request_id, self.model_uid, output_text=out
                 )
+                finish_reason = meta_info["finish_reason"]
                 prompt_tokens = meta_info["prompt_tokens"]
                 completion_tokens = meta_info["completion_tokens"]
                 total_tokens = prompt_tokens + completion_tokens
@@ -361,6 +363,26 @@ class SGLANGModel(LLM):
                     total_tokens=total_tokens,
                 )
                 yield chunk
+
+            finish_reason = (
+                "stop"
+                if finish_reason is None
+                or (
+                    isinstance(finish_reason, str)
+                    and finish_reason.lower() == "none"
+                )
+                else finish_reason
+            )
+            yield generate_completion_chunk(
+                "",
+                finish_reason=finish_reason,
+                chunk_id=request_id,
+                model_uid=self.model_uid,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+
             if include_usage:
                 chunk = CompletionChunk(
                     id=request_id,
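
The include_usage flag in these hunks is derived from stream_options, following the OpenAI convention. A client-side sketch of a streaming request that exercises the trailing usage-only chunk (the consuming loop and the way generate_config is passed are illustrative assumptions; only the stream_options key is taken from the diff):

generate_config = {
    "stream": True,
    "stream_options": {"include_usage": True},
}
# async for chunk in await model.async_chat(messages, generate_config):
#     if chunk["choices"]:
#         ...  # incremental text
#     else:
#         print("usage:", chunk["usage"])  # final usage-only chunk
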
@@ -409,26 +431,19 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
     ) -> Dict:
         if not generate_config:
             generate_config = {}
-        if self.model_family.prompt_style:
-            if (
-                not generate_config.get("stop")
-            ) and self.model_family.prompt_style.stop:
-                generate_config["stop"] = self.model_family.prompt_style.stop.copy()
+        if self.model_family.stop:
+            if (not generate_config.get("stop")) and self.model_family.stop:
+                generate_config["stop"] = self.model_family.stop.copy()
         return generate_config
 
     async def async_chat(
         self,
-        prompt: str,
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[Dict] = None,
+        request_id: Optional[str] = None,
     ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
-        assert self.model_family.prompt_style is not None
-        prompt_style = self.model_family.prompt_style.copy()
-        if system_prompt:
-            prompt_style.system_prompt = system_prompt
-        chat_history = chat_history or []
-        full_prompt = self.get_prompt(prompt, chat_history, prompt_style)
+        assert self.model_family.chat_template is not None
+        full_prompt = self.get_full_context(messages, self.model_family.chat_template)
 
         generate_config = self._sanitize_chat_config(generate_config)
         stream = generate_config.get("stream", None)