xinference 0.12.0__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.

Potentially problematic release: this version of xinference might be problematic.
Files changed (67)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +74 -6
  3. xinference/client/restful/restful_client.py +74 -5
  4. xinference/constants.py +1 -0
  5. xinference/core/cache_tracker.py +48 -28
  6. xinference/core/model.py +54 -42
  7. xinference/core/scheduler.py +34 -16
  8. xinference/core/supervisor.py +73 -24
  9. xinference/core/worker.py +68 -2
  10. xinference/deploy/cmdline.py +86 -2
  11. xinference/deploy/test/test_cmdline.py +19 -10
  12. xinference/model/audio/__init__.py +14 -1
  13. xinference/model/audio/core.py +12 -1
  14. xinference/model/audio/custom.py +6 -4
  15. xinference/model/audio/model_spec_modelscope.json +20 -0
  16. xinference/model/llm/__init__.py +34 -2
  17. xinference/model/llm/llm_family.json +2 -0
  18. xinference/model/llm/llm_family.py +86 -1
  19. xinference/model/llm/llm_family_csghub.json +66 -0
  20. xinference/model/llm/llm_family_modelscope.json +2 -0
  21. xinference/model/llm/pytorch/chatglm.py +18 -12
  22. xinference/model/llm/pytorch/core.py +92 -42
  23. xinference/model/llm/pytorch/glm4v.py +13 -3
  24. xinference/model/llm/pytorch/qwen_vl.py +1 -1
  25. xinference/model/llm/pytorch/utils.py +27 -14
  26. xinference/model/llm/utils.py +14 -13
  27. xinference/model/llm/vllm/core.py +10 -4
  28. xinference/model/utils.py +8 -2
  29. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  30. xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
  31. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  32. xinference/thirdparty/ChatTTS/infer/api.py +125 -0
  33. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  34. xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
  35. xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
  36. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  37. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
  38. xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
  39. xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
  40. xinference/web/ui/build/asset-manifest.json +6 -6
  41. xinference/web/ui/build/index.html +1 -1
  42. xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
  43. xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
  44. xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
  45. xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
  46. xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
  47. xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
  49. xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
  51. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/METADATA +1 -1
  52. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/RECORD +57 -45
  53. xinference/web/ui/build/static/css/main.54bca460.css +0 -2
  54. xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
  55. xinference/web/ui/build/static/js/main.551aa479.js +0 -3
  56. xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
  57. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
  58. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
  63. /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
  64. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
  65. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
  66. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
  67. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0

xinference/model/llm/pytorch/core.py CHANGED
@@ -15,6 +15,7 @@
  import json
  import logging
  import os
+ from functools import lru_cache
  from typing import Iterable, Iterator, List, Optional, Union
 
  from ....core.scheduler import InferenceRequest
@@ -28,6 +29,7 @@ from ....types import (
      ChatCompletionChunk,
      ChatCompletionMessage,
      Completion,
+     CompletionChoice,
      CompletionChunk,
      CreateCompletionTorch,
      Embedding,
@@ -366,6 +368,90 @@ class PytorchModel(LLM):
          else:
              return generator_wrapper(prompt, generate_config)
 
+     @lru_cache
+     def get_context_len(self):
+         return get_context_length(self._model.config)
+
+     def get_max_num_seqs(self) -> int:
+         return self._pytorch_model_config.get("max_num_seqs")  # type: ignore
+
+     def prepare_batch_inference(self, req_list: List[InferenceRequest]):
+         # check some parameters
+         for r in req_list:
+             if r.sanitized_generate_config is None:
+                 r.sanitized_generate_config = self._sanitize_generate_config(
+                     r.generate_config
+                 )
+             if r.is_prefill:
+                 # check some generate params
+                 max_src_len = get_max_src_len(self.get_context_len(), r)  # type: ignore
+                 if max_src_len < 0:
+                     r.stopped = True
+                     r.error_msg = "Max tokens exceeds model's max length"
+                     continue
+                 if r.stream_interval <= 0:
+                     r.stopped = True
+                     r.error_msg = "`stream_interval` must be greater than 0"
+                     continue
+                 stop_str = r.sanitized_generate_config.get("stop", None)
+                 if stop_str and (
+                     not (isinstance(stop_str, str) or isinstance(stop_str, Iterable))
+                 ):
+                     r.stopped = True
+                     r.error_msg = "Invalid `stop` field type"
+                     continue
+
+     def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
+         for req in req_list:
+             if req.error_msg is None:
+                 # nothing need handle for non-stream case
+                 if req.stream:
+                     results = []
+                     for i, c in enumerate(req.completion):
+                         if c == "<bos_stream>":
+                             chunk = req.completion[i + 1]
+                             results.append(
+                                 CompletionChunk(
+                                     id=chunk["id"],
+                                     object=chunk["object"],
+                                     created=chunk["created"],
+                                     model=chunk["model"],
+                                     choices=[
+                                         CompletionChoice(
+                                             text="",
+                                             index=0,
+                                             logprobs=None,
+                                             finish_reason=None,
+                                         )
+                                     ],
+                                 )
+                             )
+                             continue
+                         elif c == "<eos_stream>":
+                             break
+                         else:
+                             results.append(c)
+
+                     if req.stopped and req.include_usage:
+                         results.append(req.completion[-1])
+                     req.completion = results
+
+     def batch_inference(self, req_list: List[InferenceRequest]):
+         from .utils import batch_inference_one_step
+
+         self.prepare_batch_inference(req_list)
+         context_len = self.get_context_len()
+         assert isinstance(context_len, int)
+         batch_inference_one_step(
+             req_list,
+             self.model_uid,
+             self._model,
+             self._tokenizer,
+             self._device,
+             context_len,
+         )
+         self.handle_batch_inference_results(req_list)
+
      def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
          try:
              import torch
@@ -464,7 +550,6 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
              pytorch_model_config,
              peft_model,
          )
-         self._context_len = None
 
      def _sanitize_generate_config(
          self,
@@ -540,7 +625,6 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
 
      def load(self):
          super().load()
-         self._context_len = get_context_length(self._model.config)
 
      def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
          assert self.model_family.prompt_style is not None
@@ -553,48 +637,14 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
          )
          return full_prompt
 
-     def get_max_num_seqs(self) -> int:
-         return self._pytorch_model_config.get("max_num_seqs")  # type: ignore
-
-     def batch_inference(self, req_list: List[InferenceRequest]):
-         from .utils import batch_inference_one_step
-
+     def prepare_batch_inference(self, req_list: List[InferenceRequest]):
+         super().prepare_batch_inference(req_list)
          for r in req_list:
-             if r.sanitized_generate_config is None:
-                 r.sanitized_generate_config = self._sanitize_generate_config(
-                     r.generate_config
-                 )
-             if r.is_prefill:
-                 # check some generate params
-                 max_src_len = get_max_src_len(self._context_len, r)  # type: ignore
-                 if max_src_len < 0:
-                     r.stopped = True
-                     r.error_msg = "Max tokens exceeds model's max length"
-                     continue
-                 if r.stream_interval <= 0:
-                     r.stopped = True
-                     r.error_msg = "`stream_interval` must be greater than 0"
-                     continue
-                 stop_str = r.sanitized_generate_config.get("stop", None)
-                 if stop_str and (
-                     not (isinstance(stop_str, str) or isinstance(stop_str, Iterable))
-                 ):
-                     r.stopped = True
-                     r.error_msg = "Invalid `stop` field type"
-                     continue
-                 r.full_prompt = self._get_full_prompt(
-                     r.prompt, r.system_prompt, r.chat_history, None
-                 )
+             r.full_prompt = self._get_full_prompt(
+                 r.prompt, r.system_prompt, r.chat_history, None
+             )
 
-         assert isinstance(self._context_len, int)
-         batch_inference_one_step(
-             req_list,
-             self.model_uid,
-             self._model,
-             self._tokenizer,
-             self._device,
-             self._context_len,
-         )
+     def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
          for req in req_list:
              if req.stream and req.error_msg is None:
                  if req.completion:
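
Note on the refactor above: the batching hooks move from PytorchChatModel up into PytorchModel, and the context length is no longer stored in a _context_len attribute during load() but computed lazily through functools.lru_cache. A minimal standalone sketch of that caching pattern (illustrative, not xinference code); self is part of the cache key, so each instance memoizes its own value:

    from functools import lru_cache

    class FakeModel:
        # Toy stand-in: _config plays the role of the transformers model config.
        def __init__(self, config):
            self._config = config

        @lru_cache  # per-instance memoization: self is included in the cache key
        def get_context_len(self):
            print("computed once")
            return self._config["max_position_embeddings"]

    m = FakeModel({"max_position_embeddings": 8192})
    assert m.get_context_len() == 8192  # prints "computed once"
    assert m.get_context_len() == 8192  # served from the cache, nothing printed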

xinference/model/llm/pytorch/glm4v.py CHANGED
@@ -56,19 +56,29 @@ class Glm4VModel(PytorchChatModel):
              return True
          return False
 
-     def load(self, **kwargs):
+     def load(self):
          from transformers import AutoModelForCausalLM, AutoTokenizer
 
          device = self._pytorch_model_config.get("device", "auto")
          self._device = select_device(device)
-         self._device = "auto" if self._device == "cuda" else self._device
+
+         kwargs = {"device_map": self._device}
+         quantization = self.quantization
+         if quantization != "none":
+             if self._device == "cuda" and self._is_linux():
+                 kwargs["device_map"] = "auto"
+                 self._device = "auto"
+             if quantization == "4-bit":
+                 kwargs["load_in_4bit"] = True
+             elif quantization == "8-bit":
+                 kwargs["load_in_8bit"] = True
 
          model = AutoModelForCausalLM.from_pretrained(
              self.model_path,
              low_cpu_mem_usage=True,
              trust_remote_code=True,
              torch_dtype=torch.float16,
-             device_map=self._device,
+             **kwargs,
          )
          self._model = model.eval()
 
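For reference, the keyword arguments the new load() would assemble before calling from_pretrained under an assumed scenario (Linux host, CUDA device, quantization set to 4-bit); with quantization set to none the dict keeps only the originally selected device_map, matching the previous behavior:

    # Assumed scenario: Linux + CUDA + quantization="4-bit"
    kwargs = {"device_map": "auto", "load_in_4bit": True}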

xinference/model/llm/pytorch/qwen_vl.py CHANGED
@@ -45,7 +45,7 @@ class QwenVLChatModel(PytorchChatModel):
      def match(
          cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
      ) -> bool:
-         if "qwen" in model_family.model_name:
+         if "qwen" in model_family.model_name and "vision" in model_family.model_ability:
              return True
          return False
 
xinference/model/llm/pytorch/utils.py CHANGED
@@ -126,6 +126,7 @@ def generate_stream(
      stop_str = generate_config.get("stop", None)
      stop_token_ids = generate_config.get("stop_token_ids", None) or []
      stop_token_ids.append(tokenizer.eos_token_id)
+     chunk_id = str(uuid.uuid4())
 
      logits_processor = prepare_logits_processor(
          temperature, repetition_penalty, top_p, top_k
@@ -289,7 +290,7 @@ def generate_stream(
                  text=output, index=0, logprobs=None, finish_reason=None
              )
              completion_chunk = CompletionChunk(
-                 id=str(uuid.uuid1()),
+                 id=chunk_id,
                  object="text_completion",
                  created=int(time.time()),
                  model=model_uid,
@@ -327,7 +328,7 @@ def generate_stream(
      )
 
      completion_chunk = CompletionChunk(
-         id=str(uuid.uuid1()),
+         id=chunk_id,
          object="text_completion",
          created=int(time.time()),
          model=model_uid,
@@ -343,7 +344,7 @@ def generate_stream(
 
      if include_usage:
          completion_chunk = CompletionChunk(
-             id=str(uuid.uuid1()),
+             id=chunk_id,
              object="text_completion",
              created=int(time.time()),
              model=model_uid,
@@ -390,6 +391,7 @@ def generate_stream_falcon(
      stop_str = generate_config.get("stop", None)
      stop_token_ids = generate_config.get("stop_token_ids", None) or []
      stop_token_ids.append(tokenizer.eos_token_id)
+     chunk_id = str(uuid.uuid4())
 
      inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
      input_ids = inputs["input_ids"]
@@ -473,7 +475,7 @@ def generate_stream_falcon(
              text=output, index=0, logprobs=None, finish_reason=None
          )
          completion_chunk = CompletionChunk(
-             id=str(uuid.uuid1()),
+             id=chunk_id,
              object="text_completion",
              created=int(time.time()),
              model=model_uid,
@@ -500,7 +502,7 @@ def generate_stream_falcon(
          text=output, index=0, logprobs=None, finish_reason=finish_reason
      )
      completion_chunk = CompletionChunk(
-         id=str(uuid.uuid1()),
+         id=chunk_id,
          object="text_completion",
          created=int(time.time()),
          model=model_uid,
@@ -516,7 +518,7 @@ def generate_stream_falcon(
 
      if include_usage:
          completion_chunk = CompletionChunk(
-             id=str(uuid.uuid1()),
+             id=chunk_id,
              object="text_completion",
              created=int(time.time()),
              model=model_uid,
@@ -586,6 +588,7 @@ def get_max_src_len(context_len: int, r: InferenceRequest) -> int:
 
  def _get_completion_chunk(
      output: str,
+     chunk_id: str,
      finish_reason: Optional[str],
      model_uid: str,
      r: InferenceRequest,
@@ -601,7 +604,7 @@ def _get_completion_chunk(
          else []
      )
      completion_chunk = CompletionChunk(
-         id=str(uuid.uuid1()),
+         id=chunk_id,
          object="text_completion",
          created=int(time.time()),
          model=model_uid,
@@ -617,14 +620,18 @@ def _get_completion_chunk(
 
 
  def _get_completion(
-     output: str, finish_reason: Optional[str], model_uid: str, r: InferenceRequest
+     output: str,
+     chunk_id: str,
+     finish_reason: Optional[str],
+     model_uid: str,
+     r: InferenceRequest,
  ):
      completion_choice = CompletionChoice(
          text=output, index=0, logprobs=None, finish_reason=finish_reason
      )
 
      completion_chunk = CompletionChunk(
-         id=str(uuid.uuid1()),
+         id=chunk_id,
          object="text_completion",
          created=int(time.time()),
          model=model_uid,
@@ -701,7 +708,7 @@ def _batch_inference_one_step_internal(
      decode_reqs = []
      for r in valid_req_list:
          if r.is_prefill:
-             prompts.append(r.full_prompt)
+             prompts.append(r.full_prompt if r.full_prompt is not None else r.prompt)
              prefill_reqs.append(r)
          else:
              decode_reqs.append(r)
@@ -846,7 +853,7 @@ def _batch_inference_one_step_internal(
                      r.last_output_length += len(output)
 
                      completion_chunk = _get_completion_chunk(
-                         output, r.finish_reason, model_uid, r, False
+                         output, r.chunk_id, r.finish_reason, model_uid, r, False
                      )
                      r.completion.append(completion_chunk)
                      if r.stopped:
@@ -859,7 +866,7 @@ def _batch_inference_one_step_internal(
                  if r.stopped and _i == decode_round - 1 and include_usage:
                      r.completion.append(
                          _get_completion_chunk(
-                             "", r.finish_reason, model_uid, r, True
+                             "", r.chunk_id, r.finish_reason, model_uid, r, True
                          )
                      )
              else:
@@ -878,7 +885,9 @@ def _batch_inference_one_step_internal(
                      if r not in output_mapping
                      else output_mapping[r]
                  )
-                 completion = _get_completion(outputs, r.finish_reason, model_uid, r)
+                 completion = _get_completion(
+                     outputs, r.chunk_id, r.finish_reason, model_uid, r
+                 )
                  r.completion = [completion]
 
      e_time = time.time()
@@ -911,4 +920,8 @@ def batch_inference_one_step(
          os._exit(1)
      except Exception as e:
          logger.exception(f"Internal error for batch inference: {e}.")
-         # TODO: handle this
+         # If internal error happens, just skip all the requests in this batch.
+         # If not handle here, the client will hang.
+         for r in req_list:
+             r.stopped = True
+             r.error_msg = str(e)
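
Two behavioral notes on the hunks above: streamed chunks now share a single uuid4 id per request instead of a fresh uuid1 per chunk, and an exception inside batch inference now marks every request in the batch as stopped so clients no longer hang. A small illustrative sketch (not library code) of the shared-id behavior:

    import time
    import uuid

    chunk_id = str(uuid.uuid4())  # created once per request, reused for every chunk

    def make_chunk(text, finish_reason=None):
        # Shape mirrors the CompletionChunk fields used above, simplified to a dict.
        return {
            "id": chunk_id,
            "object": "text_completion",
            "created": int(time.time()),
            "choices": [
                {"text": text, "index": 0, "logprobs": None, "finish_reason": finish_reason}
            ],
        }

    chunks = [make_chunk("Hel"), make_chunk("lo"), make_chunk("", "stop")]
    assert len({c["id"] for c in chunks}) == 1  # one id across the whole stream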

xinference/model/llm/utils.py CHANGED
@@ -607,7 +607,7 @@ Begin!"""
          return arguments, None, None
 
      @staticmethod
-     def _eval_chatglm3_arguments(c, tools):
+     def _eval_glm_chat_arguments(c, tools):
          if isinstance(c[0], str):
              return c[0], None, None
          return None, c[0]["name"], c[0]["parameters"]
@@ -659,9 +659,9 @@ Begin!"""
          family = model_family.model_family or model_family.model_name
          if family in ["gorilla-openfunctions-v1", "gorilla-openfunctions-v2"]:
              content, func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
-         elif "chatglm3" == family:
-             content, func, args = cls._eval_chatglm3_arguments(c, tools)
-         elif family in ["qwen-chat", "qwen1.5-chat"]:
+         elif family in ["chatglm3", "glm4-chat"]:
+             content, func, args = cls._eval_glm_chat_arguments(c, tools)
+         elif family in ["qwen-chat", "qwen1.5-chat", "qwen2-instruct"]:
              content, func, args = cls._eval_qwen_chat_arguments(c, tools)
          else:
              raise Exception(
@@ -676,28 +676,29 @@ Begin!"""
          Generates a filter function for Qwen series models to retain outputs after "\nFinal Answer:".
 
          Returns:
-             A function that takes tokens (string output by the model so far) as input
-             returns True if current token is after "\nFinal Answer:", else False.
+             A function that takes tokens (string output by the model so far) and delta (new tokens added) as input,
+             returns the part after "\nFinal Answer:" if found, else returns delta.
          """
          family = model_family.model_family or model_family.model_name
          if family in ["qwen-chat", "qwen1.5-chat"]:
              # Encapsulating function to reset 'found' after each call
              found = False
 
-             def process_token(tokens: str):
+             def process_tokens(tokens: str, delta: str):
                  nonlocal found
                  # Once "Final Answer:" is found, future tokens are allowed.
                  if found:
-                     return True
+                     return delta
                  # Check if the token ends with "\nFinal Answer:" and update `found`.
-                 if tokens.endswith("\nFinal Answer:"):
+                 final_answer_idx = tokens.lower().rfind("\nfinal answer:")
+                 if final_answer_idx != -1:
                      found = True
-                     return False
+                     return tokens[final_answer_idx + len("\nfinal answer:") :]
+                 return ""
 
-             return process_token
+             return process_tokens
          else:
-             # For other families, allow all tokens.
-             return lambda tokens: True
+             return lambda tokens, delta: delta
 
      @classmethod
      def _tool_calls_completion(cls, model_family, model_uid, c, tools):
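
The reworked Qwen filter above moves from a boolean gate over the accumulated text to a function of (tokens, delta) that returns the text to emit. A standalone sketch of that contract, mirroring the diff rather than calling the library's internal helper:

    def make_final_answer_filter():
        """Suppress Qwen's ReAct thought process until '\\nFinal Answer:' appears."""
        found = False

        def process_tokens(tokens: str, delta: str) -> str:
            nonlocal found
            if found:
                return delta  # everything after the marker streams through unchanged
            idx = tokens.lower().rfind("\nfinal answer:")
            if idx != -1:
                found = True
                return tokens[idx + len("\nfinal answer:"):]  # emit what follows the marker
            return ""  # still inside the thought process: emit nothing

        return process_tokens

    f = make_final_answer_filter()
    assert f("Thought: I should look this up", " up") == ""
    assert f("Thought: done\nFinal Answer: Paris", " Paris") == " Paris"
    assert f("later tokens", " stream through") == " stream through"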

xinference/model/llm/vllm/core.py CHANGED
@@ -444,7 +444,9 @@ class VLLMModel(LLM):
                  _content, func, args = ChatModelMixin._eval_tool_arguments(
                      self.model_family, chunk, tools
                  )
-                 choice["text"] = choice_delta
+                 choice["text"] = tools_token_filter(
+                     tokens=previous_texts[0], delta=choice_delta
+                 )
                  if func is not None:
                      choice["text"] = None
                      choice["finish_reason"] = "tool_calls"
@@ -458,9 +460,13 @@ class VLLMModel(LLM):
                              ),
                          )
                      ]
-                 # use a filter function to skip Qwen's react thought process
-                 elif not tools_token_filter(previous_texts[0]):
-                     continue
+                 else:
+                     # use a filter function to skip Qwen's react thought process
+                     choice["text"] = tools_token_filter(
+                         tokens=previous_texts[0], delta=choice["text"]
+                     )
+                     if not choice["text"]:
+                         continue
              prompt_tokens = len(_request_output.prompt_token_ids)
              completion_tokens = sum(
                  len(output.token_ids) for output in _request_output.outputs
             )

xinference/model/utils.py CHANGED
@@ -42,14 +42,20 @@ def is_locale_chinese_simplified() -> bool:
 
 
  def download_from_modelscope() -> bool:
-     if os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "modelscope":
-         return True
+     if os.environ.get(XINFERENCE_ENV_MODEL_SRC):
+         return os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "modelscope"
      elif is_locale_chinese_simplified():
          return True
      else:
          return False
 
 
+ def download_from_csghub() -> bool:
+     if os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "csghub":
+         return True
+     return False
+
+
  def symlink_local_file(path: str, local_dir: str, relpath: str) -> str:
      from huggingface_hub.file_download import _create_symlink
 
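With this change an explicitly set model source always takes precedence over the locale heuristic, and csghub becomes a recognized source. A quick behavioral sketch, assuming the constant and both helpers are importable as in the module above:

    import os

    from xinference.constants import XINFERENCE_ENV_MODEL_SRC
    from xinference.model.utils import download_from_csghub, download_from_modelscope

    os.environ[XINFERENCE_ENV_MODEL_SRC] = "csghub"
    assert download_from_modelscope() is False  # explicit source other than "modelscope" wins
    assert download_from_csghub() is True

    os.environ[XINFERENCE_ENV_MODEL_SRC] = "modelscope"
    assert download_from_modelscope() is True
    assert download_from_csghub() is False

    del os.environ[XINFERENCE_ENV_MODEL_SRC]  # unset: fall back to the locale check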

xinference/thirdparty/ChatTTS/experimental/__init__.py ADDED
File without changes

xinference/thirdparty/ChatTTS/experimental/llm.py ADDED
@@ -0,0 +1,40 @@
+ 
+ from openai import OpenAI
+ 
+ prompt_dict = {
+     'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"},
+         {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
+         {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
+     'deepseek': [
+         {"role": "system", "content": "You are a helpful assistant"},
+         {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
+         {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
+     'deepseek_TN': [
+         {"role": "system", "content": "You are a helpful assistant"},
+         {"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"},
+         {"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"},
+         {"role": "user", "content": "We paid $123 for this desk."},
+         {"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."},
+         {"role": "user", "content": "详询请拨打010-724654"},
+         {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"},
+         {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"},
+         {"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"},
+     ],
+ }
+ 
+ class llm_api:
+     def __init__(self, api_key, base_url, model):
+         self.client = OpenAI(
+             api_key = api_key,
+             base_url = base_url,
+         )
+         self.model = model
+     def call(self, user_question, temperature = 0.3, prompt_version='kimi', **kwargs):
+ 
+         completion = self.client.chat.completions.create(
+             model = self.model,
+             messages = prompt_dict[prompt_version]+[{"role": "user", "content": user_question},],
+             temperature = temperature,
+             **kwargs
+         )
+         return completion.choices[0].message.content
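
A hypothetical usage sketch for the vendored helper above; the endpoint, API key, and model name are placeholders, not values shipped with the package:

    # Placeholders only; any OpenAI-compatible chat endpoint would do.
    api = llm_api(
        api_key="sk-...",
        base_url="https://api.example.com/v1",
        model="some-chat-model",
    )
    # Returns a short, TTS-friendly reply shaped by the 'kimi' prompt preset above.
    reply = api.call("Please introduce yourself in one sentence.", prompt_version="kimi")
    print(reply)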

xinference/thirdparty/ChatTTS/infer/__init__.py ADDED
File without changes

xinference/thirdparty/ChatTTS/infer/api.py ADDED
@@ -0,0 +1,125 @@
+ 
+ import torch
+ import torch.nn.functional as F
+ from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
+ from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
+ 
+ def infer_code(
+     models,
+     text,
+     spk_emb = None,
+     top_P = 0.7,
+     top_K = 20,
+     temperature = 0.3,
+     repetition_penalty = 1.05,
+     max_new_token = 2048,
+     **kwargs
+ ):
+ 
+     device = next(models['gpt'].parameters()).device
+ 
+     if not isinstance(text, list):
+         text = [text]
+ 
+     if not isinstance(temperature, list):
+         temperature = [temperature] * models['gpt'].num_vq
+ 
+     if spk_emb is not None:
+         text = [f'[Stts][spk_emb]{i}[Ptts]' for i in text]
+     else:
+         text = [f'[Stts][empty_spk]{i}[Ptts]' for i in text]
+ 
+     text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
+     input_ids = text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq)
+     text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
+ 
+     inputs = {
+         'input_ids': input_ids,
+         'text_mask': text_mask,
+         'attention_mask': text_token['attention_mask'],
+     }
+ 
+     emb = models['gpt'].get_emb(**inputs)
+     if spk_emb is not None:
+         emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = \
+             F.normalize(spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12)
+ 
+     num_code = models['gpt'].emb_code[0].num_embeddings - 1
+ 
+     LogitsWarpers = []
+     if top_P is not None:
+         LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
+     if top_K is not None:
+         LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
+ 
+     LogitsProcessors = []
+     if repetition_penalty is not None and repetition_penalty != 1:
+         LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(\
+             repetition_penalty, num_code, 16))
+ 
+     result = models['gpt'].generate(
+         emb, inputs['input_ids'],
+         temperature = torch.tensor(temperature, device=device),
+         attention_mask = inputs['attention_mask'],
+         LogitsWarpers = LogitsWarpers,
+         LogitsProcessors = LogitsProcessors,
+         eos_token = num_code,
+         max_new_token = max_new_token,
+         infer_text = False,
+         **kwargs
+     )
+ 
+     return result
+ 
+ 
+ def refine_text(
+     models,
+     text,
+     top_P = 0.7,
+     top_K = 20,
+     temperature = 0.7,
+     repetition_penalty = 1.0,
+     max_new_token = 384,
+     prompt = '',
+     **kwargs
+ ):
+ 
+     device = next(models['gpt'].parameters()).device
+ 
+     if not isinstance(text, list):
+         text = [text]
+ 
+     assert len(text), 'text should not be empty'
+ 
+     text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
+     text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
+     text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
+ 
+     inputs = {
+         'input_ids': text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq),
+         'text_mask': text_mask,
+         'attention_mask': text_token['attention_mask'],
+     }
+ 
+     LogitsWarpers = []
+     if top_P is not None:
+         LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
+     if top_K is not None:
+         LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
+ 
+     LogitsProcessors = []
+     if repetition_penalty is not None and repetition_penalty != 1:
+         LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16))
+ 
+     result = models['gpt'].generate(
+         models['gpt'].get_emb(**inputs), inputs['input_ids'],
+         temperature = torch.tensor([temperature,], device=device),
+         attention_mask = inputs['attention_mask'],
+         LogitsWarpers = LogitsWarpers,
+         LogitsProcessors = LogitsProcessors,
+         eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
+         max_new_token = max_new_token,
+         infer_text = True,
+         **kwargs
+     )
+     return result
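
For orientation, the expected call pattern for the two helpers above, assuming models is the dict ChatTTS builds elsewhere with at least 'gpt' and 'tokenizer' entries:

    # `models` is assumed to come from ChatTTS's own loader; it is not constructed here.
    refined = refine_text(models, "chat tts is a text to speech model", temperature=0.7)
    codes = infer_code(models, "chat tts is a text to speech model", spk_emb=None, temperature=0.3)
    # Both return the generate() result of the vendored GPT wrapper; other vendored
    # modules (dvae.py, gpt.py) handle turning the generated codes into audio.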

xinference/thirdparty/ChatTTS/model/__init__.py ADDED
File without changes