xinference 0.13.2__py3-none-any.whl → 0.13.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (103)
  1. xinference/__init__.py +0 -1
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +30 -5
  4. xinference/client/restful/restful_client.py +18 -3
  5. xinference/constants.py +0 -4
  6. xinference/core/chat_interface.py +2 -2
  7. xinference/core/image_interface.py +6 -3
  8. xinference/core/model.py +9 -4
  9. xinference/core/scheduler.py +4 -4
  10. xinference/core/supervisor.py +2 -0
  11. xinference/core/worker.py +7 -0
  12. xinference/deploy/utils.py +6 -0
  13. xinference/model/audio/core.py +9 -4
  14. xinference/model/audio/cosyvoice.py +136 -0
  15. xinference/model/audio/model_spec.json +24 -0
  16. xinference/model/audio/model_spec_modelscope.json +27 -0
  17. xinference/model/core.py +25 -4
  18. xinference/model/embedding/core.py +88 -13
  19. xinference/model/embedding/model_spec.json +8 -0
  20. xinference/model/embedding/model_spec_modelscope.json +8 -0
  21. xinference/model/flexible/core.py +8 -2
  22. xinference/model/flexible/launchers/__init__.py +1 -0
  23. xinference/model/flexible/launchers/image_process_launcher.py +70 -0
  24. xinference/model/image/core.py +8 -5
  25. xinference/model/image/model_spec.json +36 -5
  26. xinference/model/image/model_spec_modelscope.json +21 -3
  27. xinference/model/image/stable_diffusion/core.py +36 -28
  28. xinference/model/llm/core.py +6 -4
  29. xinference/model/llm/ggml/llamacpp.py +7 -5
  30. xinference/model/llm/llm_family.json +802 -82
  31. xinference/model/llm/llm_family.py +6 -6
  32. xinference/model/llm/llm_family_csghub.json +39 -0
  33. xinference/model/llm/llm_family_modelscope.json +295 -47
  34. xinference/model/llm/mlx/core.py +7 -0
  35. xinference/model/llm/pytorch/chatglm.py +246 -5
  36. xinference/model/llm/pytorch/cogvlm2.py +1 -1
  37. xinference/model/llm/pytorch/deepseek_vl.py +2 -1
  38. xinference/model/llm/pytorch/falcon.py +2 -1
  39. xinference/model/llm/pytorch/llama_2.py +4 -2
  40. xinference/model/llm/pytorch/omnilmm.py +2 -1
  41. xinference/model/llm/pytorch/qwen_vl.py +2 -1
  42. xinference/model/llm/pytorch/vicuna.py +2 -1
  43. xinference/model/llm/pytorch/yi_vl.py +2 -1
  44. xinference/model/llm/sglang/core.py +12 -6
  45. xinference/model/llm/utils.py +78 -1
  46. xinference/model/llm/vllm/core.py +9 -5
  47. xinference/model/rerank/core.py +4 -3
  48. xinference/thirdparty/cosyvoice/__init__.py +0 -0
  49. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  50. xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
  51. xinference/thirdparty/cosyvoice/bin/train.py +136 -0
  52. xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
  53. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
  54. xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
  55. xinference/thirdparty/cosyvoice/cli/model.py +60 -0
  56. xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
  57. xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
  58. xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
  59. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  60. xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
  61. xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
  62. xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
  63. xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
  64. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  65. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
  66. xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
  67. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  68. xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
  69. xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
  70. xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
  71. xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
  72. xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
  73. xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
  74. xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
  77. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
  78. xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
  79. xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  80. xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
  81. xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
  82. xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
  83. xinference/thirdparty/cosyvoice/utils/common.py +103 -0
  84. xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
  85. xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
  86. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
  87. xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
  88. xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
  89. xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
  90. xinference/web/ui/build/asset-manifest.json +3 -3
  91. xinference/web/ui/build/index.html +1 -1
  92. xinference/web/ui/build/static/js/{main.95c1d652.js → main.af906659.js} +3 -3
  93. xinference/web/ui/build/static/js/main.af906659.js.map +1 -0
  94. xinference/web/ui/node_modules/.cache/babel-loader/2cd5e4279ad7e13a1f41d486e9fca7756295bfad5bd77d90992f4ac3e10b496d.json +1 -0
  95. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/METADATA +39 -11
  96. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/RECORD +101 -57
  97. xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
  99. /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.af906659.js.LICENSE.txt} +0 -0
  100. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/LICENSE +0 -0
  101. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/WHEEL +0 -0
  102. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/entry_points.txt +0 -0
  103. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/top_level.txt +0 -0

xinference/model/llm/pytorch/chatglm.py
@@ -11,10 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
+import json
+import threading
 import time
 import uuid
 from typing import Any, Dict, Iterator, List, Optional, Union
 
+import torch
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList
+
 from ....core.scheduler import InferenceRequest
 from ....types import (
     SPECIAL_TOOL_PROMPT,
@@ -33,6 +40,16 @@ from ..utils import GLM4_TOOL_CALL_FAMILY
 from .core import PytorchChatModel, PytorchModelConfig
 
 
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+    ) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., 198] = 5e4
+        return scores
+
+
 class ChatglmPytorchChatModel(PytorchChatModel):
     def __init__(
         self,
@@ -103,9 +120,11 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         tools = generate_config.pop("tools", None)
         if tools is None:
             return False
+        # Convert a iterable to a list
+        tools = list(tools)
         tool_choice = generate_config.pop("tool_choice", "none")
         if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
-            chat_history[:] = self.process_messages(
+            chat_history[:] = self._process_messages(
                 chat_history, tools=tools, tool_choice=tool_choice
             )
         return True
@@ -124,7 +143,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         return True
 
     @staticmethod
-    def process_messages(messages, tools=None, tool_choice="none"):
+    def _process_messages(messages, tools=None, tool_choice="none"):
         # This method is adapted from https://github.com/THUDM/GLM-4/blob/main/basic_demo/openai_api_server.py
         _messages = messages
         processed_messages = []
@@ -210,6 +229,212 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                     break
         return processed_messages
 
+    @staticmethod
+    def _process_response(output, history, tools, end=False):
+        # Copy from https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/modeling_chatglm.py
+        content = ""
+        history = copy.deepcopy(history)
+        if not tools and end:
+            return None, None
+        for response in output.split("<|assistant|>"):
+            if "\n" in response:
+                metadata, content = response.split("\n", maxsplit=1)
+            else:
+                metadata, content = "", response
+            if not metadata.strip():
+                if tools and any(t.startswith(response) for t in tools) and not end:
+                    # Waiting for tool call complete.
+                    return None, None
+                content = content.strip()
+                history.append(
+                    {"role": "assistant", "metadata": metadata, "content": content}
+                )
+                content = content.replace("[[训练时间]]", "2023年")
+            else:
+                if tools and metadata in tools and not end:
+                    return None, None
+                history.append(
+                    {"role": "assistant", "metadata": metadata, "content": content}
+                )
+                metadata = metadata.strip()
+                if tools and metadata in tools and end:
+                    try:
+                        parameters = json.loads(content)
+                        content = {"name": metadata.strip(), "parameters": parameters}
+                    except json.JSONDecodeError:
+                        content = {"name": metadata.strip(), "content": content}
+                else:
+                    content = {"name": metadata.strip(), "content": content}
+        return content, history
+
+    def _get_generate_args(
+        self,
+        tokenizer,
+        query: str,
+        history: Optional[List[Dict]] = None,
+        role: str = "user",
+        past_key_values=None,
+        max_length: int = 8192,
+        do_sample=True,
+        top_p=0.8,
+        temperature=0.8,
+        logits_processor=None,
+        **kwargs,
+    ):
+        # Copy from https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/modeling_chatglm.py
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        eos_token_id = [
+            tokenizer.eos_token_id,
+            tokenizer.convert_tokens_to_ids("<|user|>"),
+            tokenizer.convert_tokens_to_ids("<|observation|>"),
+        ]
+        gen_kwargs = {
+            "max_length": max_length,
+            "do_sample": do_sample,
+            "top_p": top_p,
+            "temperature": temperature,
+            "logits_processor": logits_processor,
+            **kwargs,
+        }
+        if past_key_values is None:
+            inputs = tokenizer.apply_chat_template(
+                history + [{"role": role, "content": query}],
+                add_generation_prompt=True,
+                tokenize=True,
+                return_tensors="pt",
+                return_dict=True,
+            )
+        else:
+            inputs = tokenizer.apply_chat_template(
+                [{"role": role, "content": query}],
+                add_special_tokens=False,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_tensors="pt",
+                return_dict=True,
+            )
+        inputs = inputs.to(self._model.device)
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+            inputs.position_ids += past_length
+            attention_mask = inputs.attention_mask
+            attention_mask = torch.cat(
+                (attention_mask.new_ones(1, past_length), attention_mask), dim=1
+            )
+            inputs["attention_mask"] = attention_mask
+        history.append({"role": role, "content": query})
+        tools = history[0]["role"] == "system" and history[0].get("tools")
+        tools = (
+            [
+                t.get("function", {}).get("name", "")
+                for t in tools
+                if isinstance(t, dict)
+            ]
+            if tools
+            else []
+        )
+        kwargs = dict(inputs)
+        kwargs["past_key_values"] = past_key_values
+        kwargs["eos_token_id"] = eos_token_id
+        kwargs.update(gen_kwargs)
+        return kwargs, tools
+
+    @torch.inference_mode()
+    def stream_chat(
+        self,
+        tokenizer,
+        query: str,
+        history: Optional[List[Dict]] = None,
+        role: str = "user",
+        past_key_values=None,
+        max_length: int = 8192,
+        do_sample=True,
+        top_p=0.8,
+        temperature=0.8,
+        logits_processor=None,
+        **kwargs,
+    ):
+        from transformers import TextIteratorStreamer
+
+        kwargs, tools = self._get_generate_args(
+            tokenizer=tokenizer,
+            query=query,
+            history=history,
+            role=role,
+            past_key_values=past_key_values,
+            max_length=max_length,
+            do_sample=do_sample,
+            top_p=top_p,
+            temperature=temperature,
+            logits_processor=logits_processor,
+            **kwargs,
+        )
+
+        streamer = TextIteratorStreamer(
+            tokenizer, skip_prompt=True, skip_special_tokens=True
+        )
+        kwargs["streamer"] = streamer
+        thread = threading.Thread(target=self._model.generate, kwargs=kwargs)
+        thread.start()
+
+        response = ""
+        for token in streamer:
+            response += token
+            if response and response[-1] != "�":
+                new_response, new_history = self._process_response(
+                    response, history, tools, end=False
+                )
+                if new_response is None:
+                    continue
+                yield new_response, new_history
+        if tools:
+            new_response, new_history = self._process_response(
+                response, history, tools, end=True
+            )
+            if new_response:
+                yield new_response, new_history
+
+    @torch.inference_mode()
+    def non_stream_chat(
+        self,
+        tokenizer,
+        query: str,
+        history: Optional[List[Dict]] = None,
+        role: str = "user",
+        past_key_values=None,
+        max_length: int = 8192,
+        do_sample=True,
+        top_p=0.8,
+        temperature=0.8,
+        logits_processor=None,
+        **kwargs,
+    ):
+        kwargs, tools = self._get_generate_args(
+            tokenizer=tokenizer,
+            query=query,
+            history=history,
+            role=role,
+            past_key_values=past_key_values,
+            max_length=max_length,
+            do_sample=do_sample,
+            top_p=top_p,
+            temperature=temperature,
+            logits_processor=logits_processor,
+            **kwargs,
+        )
+
+        outputs = self._model.generate(**kwargs)
+        outputs = outputs[:, kwargs["input_ids"].shape[1] :]
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        if tools:
+            return self._process_response(response, history, tools, end=True)
+        else:
+            return self._process_response(response, history, tools)
+
     def chat(
         self,
         prompt: str,
@@ -247,7 +472,13 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             if isinstance(stream_options, dict)
             else False
         )
-        if stream and not tools:
+        if stream and (
+            not tools or self.model_family.model_name in GLM4_TOOL_CALL_FAMILY
+        ):
+            if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
+                stream_chat = self.stream_chat
+            else:
+                stream_chat = self._model.stream_chat
 
             def _stream_generator():
                 last_chunk_text_length = 0
@@ -256,9 +487,14 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                 inputs = self._tokenizer([prompt], return_tensors="pt")
                 inputs = inputs.to(self._model.device)
                 prompt_tokens = len(inputs["input_ids"][0])
-                for chunk_text, _ in self._model.stream_chat(
+                for chunk_text, _ in stream_chat(
                     self._tokenizer, prompt, chat_history, **kwargs
                 ):
+                    if tools and isinstance(chunk_text, dict):
+                        yield self._tool_calls_completion_chunk(
+                            self.model_family, self.model_uid, [chunk_text, _], tools
+                        )
+                        return
                     completion_tokens = completion_tokens + 1
                     total_tokens = prompt_tokens + completion_tokens
                     chunk_text = chunk_text[last_chunk_text_length:]
@@ -312,7 +548,12 @@ class ChatglmPytorchChatModel(PytorchChatModel):
 
             return self._to_chat_completion_chunks(_stream_generator())
         else:
-            response = self._model.chat(self._tokenizer, prompt, chat_history, **kwargs)
+            if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
+                chat = self.non_stream_chat
+            else:
+                chat = self._model.chat
+
+            response = chat(self._tokenizer, prompt, chat_history, **kwargs)
             if tools:
                 return self._tool_calls_completion(
                     self.model_family, self.model_uid, response, tools
                 )
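
Taken together, the chatglm.py changes let GLM-4 family models combine stream=True with tools: chat() now routes GLM-4 models through the new stream_chat/non_stream_chat helpers, and a completed tool call is emitted as a single chat.completion.chunk carrying tool_calls. A minimal usage sketch via the Python client (endpoint, engine name and tool schema are illustrative assumptions, not taken from this diff):

# A hedged sketch: exercising the new GLM-4 tool-call streaming path.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")
uid = client.launch_model(model_name="glm4-chat", model_engine="Transformers")
model = client.get_model(uid)

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

# Before 0.13.4 the PyTorch chatglm backend streamed only when no tools were
# given; now GLM-4 models stream and finish with one chunk containing tool_calls.
for chunk in model.chat(
    "What is the weather like in Beijing?",
    generate_config={"stream": True, "tools": tools},
):
    print(chunk["choices"][0]["delta"])
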

xinference/model/llm/pytorch/cogvlm2.py
@@ -387,7 +387,7 @@ class CogVLM2Model(PytorchChatModel):
            prompt, system_prompt=system_prompt, chat_history=chat_history
        )
 
-        input_by_model: dict = self._model.build_conversation_input_ids(
+        input_by_model: dict = self._model.build_conversation_input_ids(  # type: ignore
            self._tokenizer,
            query=query,
            history=history,

xinference/model/llm/pytorch/deepseek_vl.py
@@ -52,7 +52,8 @@ class DeepSeekVLChatModel(PytorchChatModel):
     def match(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if "deepseek" in model_family.model_name:
+        llm_family = model_family.model_family or model_family.model_name
+        if "deepseek-vl" in llm_family:
             return True
         return False
 

xinference/model/llm/pytorch/falcon.py
@@ -71,7 +71,8 @@ class FalconPytorchModel(PytorchModel):
     ) -> bool:
         if llm_spec.model_format != "pytorch":
             return False
-        if "falcon" not in llm_family.model_name:
+        model_family = llm_family.model_family or llm_family.model_name
+        if "falcon" not in model_family:
             return False
         if "generate" not in llm_family.model_ability:
             return False

xinference/model/llm/pytorch/llama_2.py
@@ -55,7 +55,8 @@ class LlamaPytorchModel(PytorchModel):
     ) -> bool:
         if llm_spec.model_format != "pytorch":
             return False
-        if "llama-2" not in llm_family.model_name:
+        model_family = llm_family.model_family or llm_family.model_name
+        if "llama-2" not in model_family:
             return False
         if "generate" not in llm_family.model_ability:
             return False
@@ -99,7 +100,8 @@ class LlamaPytorchChatModel(PytorchChatModel):
     ) -> bool:
         if llm_spec.model_format != "pytorch":
             return False
-        if "llama-2" not in llm_family.model_name:
+        model_family = llm_family.model_family or llm_family.model_name
+        if "llama-2" not in model_family:
             return False
         if "chat" not in llm_family.model_ability:
             return False

xinference/model/llm/pytorch/omnilmm.py
@@ -44,7 +44,8 @@ class OmniLMMModel(PytorchChatModel):
     def match(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if "OmniLMM" in model_family.model_name:
+        llm_family = model_family.model_family or model_family.model_name
+        if "OmniLMM" in llm_family:
             return True
         return False
 

xinference/model/llm/pytorch/qwen_vl.py
@@ -52,7 +52,8 @@ class QwenVLChatModel(PytorchChatModel):
     def match(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if "qwen" in model_family.model_name and "vision" in model_family.model_ability:
+        llm_family = model_family.model_family or model_family.model_name
+        if "qwen" in llm_family and "vision" in model_family.model_ability:
             return True
         return False
 

xinference/model/llm/pytorch/vicuna.py
@@ -61,7 +61,8 @@ class VicunaPytorchChatModel(PytorchChatModel):
     ) -> bool:
         if llm_spec.model_format != "pytorch":
             return False
-        if "vicuna" not in llm_family.model_name:
+        model_family = llm_family.model_family or llm_family.model_name
+        if "vicuna" not in model_family:
             return False
         if "chat" not in llm_family.model_ability:
             return False

xinference/model/llm/pytorch/yi_vl.py
@@ -51,7 +51,8 @@ class YiVLChatModel(PytorchChatModel):
     def match(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if "yi" in model_family.model_name:
+        llm_family = model_family.model_family or model_family.model_name
+        if "yi-vl" in llm_family:
             return True
         return False
 

xinference/model/llm/sglang/core.py
@@ -17,7 +17,6 @@ import time
 import uuid
 from typing import AsyncGenerator, Dict, List, Optional, TypedDict, Union
 
-from ....constants import XINFERENCE_ENABLE_SGLANG
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
@@ -63,15 +62,26 @@ try:
 except ImportError:
     SGLANG_INSTALLED = False
 
-SGLANG_SUPPORTED_MODELS = ["llama-2", "mistral-v0.1", "mixtral-v0.1"]
+SGLANG_SUPPORTED_MODELS = [
+    "llama-2",
+    "llama-3",
+    "llama-3.1",
+    "mistral-v0.1",
+    "mixtral-v0.1",
+]
 SGLANG_SUPPORTED_CHAT_MODELS = [
     "llama-2-chat",
+    "llama-3-instruct",
+    "llama-3.1-instruct",
     "qwen-chat",
     "qwen1.5-chat",
+    "qwen2-instruct",
+    "qwen2-moe-instruct",
     "mistral-instruct-v0.1",
     "mistral-instruct-v0.2",
     "mixtral-instruct-v0.1",
     "gemma-it",
+    "gemma-2-it",
 ]
 
 
@@ -168,8 +178,6 @@ class SGLANGModel(LLM):
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if not XINFERENCE_ENABLE_SGLANG:
-            return False
         if not cls._has_cuda_device():
             return False
         if not cls._is_linux():
@@ -332,8 +340,6 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if not XINFERENCE_ENABLE_SGLANG:
-            return False
         if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
             return False
         if llm_spec.model_format == "pytorch":

xinference/model/llm/utils.py
@@ -483,11 +483,40 @@ Begin!"""
                 else:
                     ret += role
             return ret
+        elif prompt_style.style_name == "mistral-nemo":
+            seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
+            ret = "<s>"
+            for i, message in enumerate(chat_history):
+                role = get_role(message["role"])
+                content = message["content"]
+                if content:
+                    if i == len(chat_history) - 2 and prompt_style.system_prompt:
+                        ret += (
+                            role
+                            + " "
+                            + prompt_style.system_prompt
+                            + "\n\n"
+                            + content
+                            + seps[i % 2]
+                        )
+                    else:
+                        ret += role + " " + content + seps[i % 2]
+                else:
+                    ret += role
+            return ret
         else:
             raise ValueError(f"Invalid prompt style: {prompt_style.style_name}")
 
     @classmethod
     def _to_chat_completion_chunk(cls, chunk: CompletionChunk) -> ChatCompletionChunk:
+        choices = chunk.get("choices")
+        if (
+            chunk.get("object") == "chat.completion.chunk"
+            and choices
+            and "delta" in choices[0]
+        ):
+            # Already a ChatCompletionChunk, we don't need to convert chunk.
+            return cast(ChatCompletionChunk, chunk)
         chat_chunk = {
             "id": "chat" + chunk["id"],
             "model": chunk["model"],
@@ -497,7 +526,7 @@ Begin!"""
                 {
                     "index": i,
                     "delta": {
-                        "content": choice["text"],
+                        "content": choice.get("text"),
                         **(
                             {"tool_calls": choice["tool_calls"]}
                             if "tool_calls" in choice
@@ -718,6 +747,54 @@ Begin!"""
         else:
             return lambda tokens, delta: delta
 
+    @classmethod
+    def _tool_calls_completion_chunk(cls, model_family, model_uid, c, tools):
+        _id = str(uuid.uuid4())
+        content, func, args = cls._eval_tool_arguments(model_family, c, tools)
+        if func:
+            d = {
+                "role": "assistant",
+                "content": content,
+                "tool_calls": [
+                    {
+                        "id": f"call_{_id}",
+                        "type": "function",
+                        "function": {
+                            "name": func,
+                            "arguments": json.dumps(args),
+                        },
+                    }
+                ],
+            }
+            finish_reason = "tool_calls"
+        else:
+            d = {"role": "assistant", "content": content, "tool_calls": []}
+            finish_reason = "stop"
+        try:
+            usage = c.get("usage")
+            assert "prompt_tokens" in usage
+        except Exception:
+            usage = {
+                "prompt_tokens": -1,
+                "completion_tokens": -1,
+                "total_tokens": -1,
+            }
+        return {
+            "id": "chat" + f"cmpl-{_id}",
+            "model": model_uid,
+            "object": "chat.completion.chunk",
+            "created": int(time.time()),
+            "choices": [
+                {
+                    "index": 0,
+                    "delta": d,
+                    "logprobs": None,
+                    "finish_reason": finish_reason,
+                }
+            ],
+            "usage": usage,
+        }
+
     @classmethod
     def _tool_calls_completion(cls, model_family, model_uid, c, tools):
         _id = str(uuid.uuid4())

xinference/model/llm/vllm/core.py
@@ -28,7 +28,6 @@ from typing import (
     Union,
 )
 
-from ....constants import XINFERENCE_DISABLE_VLLM
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
@@ -151,6 +150,15 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.4.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen2-moe-instruct")
     VLLM_SUPPORTED_CHAT_MODELS.append("c4ai-command-r-v01")
 
+if VLLM_INSTALLED and vllm.__version__ >= "0.5.3":
+    VLLM_SUPPORTED_CHAT_MODELS.append("gemma-2-it")
+    VLLM_SUPPORTED_CHAT_MODELS.append("mistral-nemo-instruct")
+    VLLM_SUPPORTED_CHAT_MODELS.append("mistral-large-instruct")
+
+if VLLM_INSTALLED and vllm.__version__ > "0.5.3":
+    VLLM_SUPPORTED_MODELS.append("llama-3.1")
+    VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.1-instruct")
+
 
 class VLLMModel(LLM):
     def __init__(
@@ -288,8 +296,6 @@ class VLLMModel(LLM):
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if XINFERENCE_DISABLE_VLLM:
-            return False
         if not cls._has_cuda_device():
             return False
         if not cls._is_linux():
@@ -514,8 +520,6 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
     def match(
        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if XINFERENCE_DISABLE_VLLM:
-            return False
         if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
             return False
         if llm_spec.model_format == "pytorch":

xinference/model/rerank/core.py
@@ -107,7 +107,7 @@ class RerankModel:
         self,
         model_spec: RerankModelSpec,
         model_uid: str,
-        model_path: str,
+        model_path: Optional[str] = None,
         device: Optional[str] = None,
         use_fp16: bool = False,
         model_config: Optional[Dict] = None,
@@ -290,6 +290,7 @@ def create_rerank_model_instance(
     model_uid: str,
     model_name: str,
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[RerankModel, RerankModelDescription]:
     from ..utils import download_from_modelscope
@@ -321,8 +322,8 @@ def create_rerank_model_instance(
             f"Huggingface: {BUILTIN_RERANK_MODELS.keys()}"
             f"ModelScope: {MODELSCOPE_RERANK_MODELS.keys()}"
         )
-
-    model_path = cache(model_spec)
+    if not model_path:
+        model_path = cache(model_spec)
     use_fp16 = kwargs.pop("use_fp16", False)
     model = RerankModel(
         model_spec, model_uid, model_path, use_fp16=use_fp16, model_config=kwargs
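
The rerank changes make model_path optional at construction time and thread it through create_rerank_model_instance, so a locally downloaded rerank model can be served without going through the cache/download path. A minimal sketch, assuming the RESTful API/client changes in this release forward model_path through launch_model kwargs (path and model name are illustrative):

# A hedged sketch: serving a rerank model from a local directory.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")
uid = client.launch_model(
    model_name="bge-reranker-base",
    model_type="rerank",
    model_path="/data/models/bge-reranker-base",  # skips cache(model_spec)
)
rerank = client.get_model(uid)
result = rerank.rerank(
    documents=["Gestation lasts about 11 months in whales.", "Apples are red."],
    query="How long is a whale pregnant?",
)
print(result["results"][0])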