xinference 0.11.3__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.


Files changed (75)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +143 -6
  3. xinference/client/restful/restful_client.py +144 -5
  4. xinference/constants.py +5 -0
  5. xinference/core/cache_tracker.py +48 -28
  6. xinference/core/model.py +160 -19
  7. xinference/core/scheduler.py +446 -0
  8. xinference/core/supervisor.py +99 -24
  9. xinference/core/worker.py +68 -2
  10. xinference/deploy/cmdline.py +86 -2
  11. xinference/deploy/test/test_cmdline.py +19 -10
  12. xinference/isolation.py +9 -2
  13. xinference/model/audio/__init__.py +14 -1
  14. xinference/model/audio/chattts.py +84 -0
  15. xinference/model/audio/core.py +22 -4
  16. xinference/model/audio/custom.py +6 -4
  17. xinference/model/audio/model_spec.json +20 -0
  18. xinference/model/audio/model_spec_modelscope.json +20 -0
  19. xinference/model/llm/__init__.py +38 -2
  20. xinference/model/llm/llm_family.json +509 -1
  21. xinference/model/llm/llm_family.py +86 -1
  22. xinference/model/llm/llm_family_csghub.json +66 -0
  23. xinference/model/llm/llm_family_modelscope.json +411 -2
  24. xinference/model/llm/pytorch/chatglm.py +20 -13
  25. xinference/model/llm/pytorch/cogvlm2.py +76 -17
  26. xinference/model/llm/pytorch/core.py +141 -6
  27. xinference/model/llm/pytorch/glm4v.py +268 -0
  28. xinference/model/llm/pytorch/minicpmv25.py +232 -0
  29. xinference/model/llm/pytorch/qwen_vl.py +1 -1
  30. xinference/model/llm/pytorch/utils.py +405 -8
  31. xinference/model/llm/utils.py +14 -13
  32. xinference/model/llm/vllm/core.py +16 -4
  33. xinference/model/utils.py +8 -2
  34. xinference/thirdparty/ChatTTS/__init__.py +1 -0
  35. xinference/thirdparty/ChatTTS/core.py +200 -0
  36. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  37. xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
  38. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  39. xinference/thirdparty/ChatTTS/infer/api.py +125 -0
  40. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  41. xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
  42. xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
  43. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  44. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
  45. xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
  46. xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
  47. xinference/types.py +3 -0
  48. xinference/web/ui/build/asset-manifest.json +6 -6
  49. xinference/web/ui/build/index.html +1 -1
  50. xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
  51. xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
  52. xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
  53. xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
  59. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/METADATA +26 -9
  60. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/RECORD +65 -47
  61. xinference/web/ui/build/static/css/main.54bca460.css +0 -2
  62. xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
  63. xinference/web/ui/build/static/js/main.551aa479.js +0 -3
  64. xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
  65. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
  66. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
  67. xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
  68. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
  69. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
  70. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
  71. /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
  72. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
  73. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
  74. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
  75. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/utils.py CHANGED
@@ -14,13 +14,15 @@
 
 import gc
 import logging
+import os
 import time
 import uuid
 from threading import Thread
-from typing import Iterable, Iterator, Tuple
+from typing import Dict, Iterable, Iterator, List, Optional, Tuple
 
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
+from transformers.cache_utils import DynamicCache
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
@@ -29,8 +31,10 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )
 
+from ....core.scheduler import InferenceRequest
 from ....device_utils import empty_cache
 from ....types import (
+    Completion,
     CompletionChoice,
     CompletionChunk,
     CompletionUsage,
@@ -54,7 +58,7 @@ def is_partial_stop(output: str, stop_str: str):
     return False
 
 
-def get_context_length(config):
+def get_context_length(config) -> int:
     """Get the context length of a model from a huggingface model config."""
     if (
         hasattr(config, "max_sequence_length")
@@ -122,6 +126,7 @@ def generate_stream(
     stop_str = generate_config.get("stop", None)
     stop_token_ids = generate_config.get("stop_token_ids", None) or []
     stop_token_ids.append(tokenizer.eos_token_id)
+    chunk_id = str(uuid.uuid4())
 
     logits_processor = prepare_logits_processor(
         temperature, repetition_penalty, top_p, top_k
@@ -285,7 +290,7 @@ def generate_stream(
                 text=output, index=0, logprobs=None, finish_reason=None
             )
             completion_chunk = CompletionChunk(
-                id=str(uuid.uuid1()),
+                id=chunk_id,
                 object="text_completion",
                 created=int(time.time()),
                 model=model_uid,
@@ -323,7 +328,7 @@
     )
 
     completion_chunk = CompletionChunk(
-        id=str(uuid.uuid1()),
+        id=chunk_id,
         object="text_completion",
         created=int(time.time()),
         model=model_uid,
@@ -339,7 +344,7 @@
 
     if include_usage:
         completion_chunk = CompletionChunk(
-            id=str(uuid.uuid1()),
+            id=chunk_id,
            object="text_completion",
             created=int(time.time()),
             model=model_uid,
@@ -386,6 +391,7 @@ def generate_stream_falcon(
     stop_str = generate_config.get("stop", None)
     stop_token_ids = generate_config.get("stop_token_ids", None) or []
     stop_token_ids.append(tokenizer.eos_token_id)
+    chunk_id = str(uuid.uuid4())
 
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     input_ids = inputs["input_ids"]
@@ -469,7 +475,7 @@ generate_stream_falcon(
                 text=output, index=0, logprobs=None, finish_reason=None
             )
             completion_chunk = CompletionChunk(
-                id=str(uuid.uuid1()),
+                id=chunk_id,
                 object="text_completion",
                 created=int(time.time()),
                 model=model_uid,
@@ -496,7 +502,7 @@
         text=output, index=0, logprobs=None, finish_reason=finish_reason
     )
     completion_chunk = CompletionChunk(
-        id=str(uuid.uuid1()),
+        id=chunk_id,
         object="text_completion",
         created=int(time.time()),
         model=model_uid,
@@ -512,7 +518,7 @@
 
     if include_usage:
         completion_chunk = CompletionChunk(
-            id=str(uuid.uuid1()),
+            id=chunk_id,
             object="text_completion",
             created=int(time.time()),
             model=model_uid,
@@ -528,3 +534,394 @@ def generate_stream_falcon(
     # clean
     gc.collect()
     empty_cache()
+
+
+def _get_token_from_logits(
+    req: InferenceRequest, i: int, logits, temperature, repetition_penalty, top_p, top_k
+):
+    logits_processor = prepare_logits_processor(
+        temperature, repetition_penalty, top_p, top_k
+    )
+
+    if logits_processor:
+        if repetition_penalty > 1.0:
+            tmp_output_ids = torch.as_tensor(
+                [req.prompt_tokens + req.new_tokens], device=logits.device
+            )
+        else:
+            tmp_output_ids = None
+        last_token_logits = logits_processor(tmp_output_ids, logits[i : i + 1, -1, :])[
+            0
+        ]
+    else:
+        last_token_logits = logits[i : i + 1, -1, :]
+
+    if temperature < 1e-5 or top_p < 1e-8:  # greedy
+        _, indices = torch.topk(last_token_logits, 2)
+    else:
+        probs = torch.softmax(last_token_logits, dim=-1)
+        indices = torch.multinomial(probs, num_samples=2)
+    token = indices[0].int().item()
+    return token
+
+
+def _pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
+    assert len(x) <= max_len
+    return [pad] * (max_len - len(x)) + x
+
+
+def _pad_seqs_inplace(seqs: List[List[int]], pad: int):
+    max_len = max(len(seq) for seq in seqs)
+    n = len(seqs)
+    i = 0
+    while i < n:
+        seqs[i] = _pad_to_max_length(seqs[i], max_len, pad)
+        i += 1
+
+
+def get_max_src_len(context_len: int, r: InferenceRequest) -> int:
+    max_new_tokens = int(
+        r.sanitized_generate_config.get("max_tokens", max_tokens_field.default)
+    )
+    return context_len - max_new_tokens - 8
+
+
+def _get_completion_chunk(
+    output: str,
+    chunk_id: str,
+    finish_reason: Optional[str],
+    model_uid: str,
+    r: InferenceRequest,
+    just_usage: bool,
+):
+    completion_choice = (
+        [
+            CompletionChoice(
+                text=output, index=0, logprobs=None, finish_reason=finish_reason
+            )
+        ]
+        if not just_usage
+        else []
+    )
+    completion_chunk = CompletionChunk(
+        id=chunk_id,
+        object="text_completion",
+        created=int(time.time()),
+        model=model_uid,
+        choices=completion_choice,
+    )
+    completion_usage = CompletionUsage(
+        prompt_tokens=len(r.prompt_tokens),
+        completion_tokens=len(r.new_tokens),
+        total_tokens=len(r.prompt_tokens) + len(r.new_tokens),
+    )
+    completion_chunk["usage"] = completion_usage
+    return completion_chunk
+
+
+def _get_completion(
+    output: str,
+    chunk_id: str,
+    finish_reason: Optional[str],
+    model_uid: str,
+    r: InferenceRequest,
+):
+    completion_choice = CompletionChoice(
+        text=output, index=0, logprobs=None, finish_reason=finish_reason
+    )
+
+    completion_chunk = CompletionChunk(
+        id=chunk_id,
+        object="text_completion",
+        created=int(time.time()),
+        model=model_uid,
+        choices=[completion_choice],
+    )
+    completion_usage = CompletionUsage(
+        prompt_tokens=len(r.prompt_tokens),
+        completion_tokens=len(r.new_tokens),
+        total_tokens=len(r.prompt_tokens) + len(r.new_tokens),
+    )
+    completion = Completion(
+        id=completion_chunk["id"],
+        object=completion_chunk["object"],
+        created=completion_chunk["created"],
+        model=completion_chunk["model"],
+        choices=completion_chunk["choices"],
+        usage=completion_usage,
+    )
+    return completion
+
+
+def _merge_kv_cache(
+    past_kv: Tuple[Tuple[torch.Tensor]], new_kv: Tuple[Tuple[torch.Tensor]]
+):
+    from torch.nn.functional import pad
+
+    past_cache = DynamicCache.from_legacy_cache(past_kv)
+    new_cache = DynamicCache.from_legacy_cache(new_kv)
+    past_seq_len = past_cache.get_seq_length()
+    new_seq_len = new_cache.get_seq_length()
+    if past_seq_len != new_seq_len:
+        padding_target = new_cache if past_seq_len > new_seq_len else past_cache
+        padding_len = abs(past_seq_len - new_seq_len)
+        for idx in range(len(padding_target)):
+            k = padding_target.key_cache[idx]
+            v = padding_target.value_cache[idx]
+            _k = pad(k, (0, 0, padding_len, 0))
+            _v = pad(v, (0, 0, padding_len, 0))
+            padding_target.key_cache[idx] = _k
+            padding_target.value_cache[idx] = _v
+
+    ret_kv = DynamicCache()
+    for idx in range(len(past_cache)):
+        k1, k2 = new_cache.key_cache[idx], past_cache.key_cache[idx]
+        v1, v2 = new_cache.value_cache[idx], past_cache.value_cache[idx]
+        ret_kv.update(torch.cat((k1, k2), 0), torch.cat((v1, v2), 0), idx)
+    return ret_kv.to_legacy_cache()
+
+
+@torch.inference_mode()
+def _batch_inference_one_step_internal(
+    req_list: List[InferenceRequest],
+    model_uid,
+    model,
+    tokenizer,
+    device,
+    context_len: int,
+    decode_round: int = 16,
+    bos_flag: str = "<bos_stream>",
+    eos_flag: str = "<eos_stream>",
+):
+    # need to judge stopped here,
+    # since some requests state may change to stopped due to invalid parameters, e.g. max_src_len
+    valid_req_list = [r for r in req_list if not r.stopped]
+    if not valid_req_list:
+        return
+    generate_config_mapping: Dict[InferenceRequest, Tuple] = {
+        r: r.get_generate_configs(tokenizer.eos_token_id) for r in valid_req_list
+    }
+    s_time = time.time()
+
+    prefill_reqs = []
+    prompts = []
+    decode_reqs = []
+    for r in valid_req_list:
+        if r.is_prefill:
+            prompts.append(r.full_prompt if r.full_prompt is not None else r.prompt)
+            prefill_reqs.append(r)
+        else:
+            decode_reqs.append(r)
+
+    if prompts:  # prefill first
+        input_ids: List[List[int]] = tokenizer(prompts, padding=False).input_ids
+        prompt_tokens = []
+        for i, input_id in enumerate(input_ids):
+            req = valid_req_list[i]
+            max_src_len = get_max_src_len(context_len, req)
+            req.prompt_tokens = input_id[-max_src_len:]
+            prompt_tokens.append(req.prompt_tokens)
+        _pad_seqs_inplace(prompt_tokens, 0)
+        out = model(torch.as_tensor(prompt_tokens, device=device), use_cache=True)
+
+        logits = out.logits
+        past_key_values = out.past_key_values
+
+        for i, r in enumerate(prefill_reqs):
+            (
+                max_new_tokens,
+                stream_interval,
+                include_usage,
+                stop_str,
+                stop_token_ids,
+                temperature,
+                repetition_penalty,
+                top_p,
+                top_k,
+            ) = generate_config_mapping[r]
+
+            token = _get_token_from_logits(
+                r, i, logits, temperature, repetition_penalty, top_p, top_k
+            )
+            r.is_prefill = False
+            r.append_new_token(token)
+
+        if decode_reqs:
+            decode_kv = decode_reqs[0].kv_cache
+            # prefill and decode kv cache need to be merged at `batch_size` and `seq_len` dimensions.
+            merged_kv_cache = _merge_kv_cache(decode_kv, past_key_values)
+            for r in valid_req_list:
+                r.kv_cache = merged_kv_cache
+            empty_cache()
+        else:
+            for r in valid_req_list:
+                r.kv_cache = past_key_values
+
+    past_key_values = valid_req_list[0].kv_cache
+    stop_token_mapping: Dict[InferenceRequest, int] = {}
+    output_mapping: Dict[InferenceRequest, str] = {}
+    # here, only decode phase, just run some rounds
+    for _i in range(decode_round):
+        decode_tokens: List[List[int]] = [[r.new_tokens[-1]] for r in valid_req_list]
+        out = model(
+            input_ids=torch.as_tensor(decode_tokens, device=device),
+            use_cache=True,
+            past_key_values=past_key_values,
+        )
+        logits = out.logits
+        past_key_values = out.past_key_values
+
+        for i, r in enumerate(valid_req_list):
+            (
+                max_new_tokens,
+                stream_interval,
+                include_usage,
+                stop_str,
+                stop_token_ids,
+                temperature,
+                repetition_penalty,
+                top_p,
+                top_k,
+            ) = generate_config_mapping[r]
+
+            token = _get_token_from_logits(
+                r, i, logits, temperature, repetition_penalty, top_p, top_k
+            )
+            r.kv_cache = past_key_values
+            r.append_new_token(token)
+
+            output = None
+            if not r.stopped:
+                stopped = token in stop_token_ids
+
+                if stopped:
+                    finish_reason = "stop"
+                elif len(r.new_tokens) == max_new_tokens:
+                    finish_reason = "length"
+                    stopped = True
+                else:
+                    finish_reason = None
+
+                # handle stop str
+                if stop_str and r not in output_mapping:
+                    output = tokenizer.decode(
+                        r.new_tokens,
+                        skip_special_tokens=True,
+                        spaces_between_special_tokens=False,
+                        clean_up_tokenization_spaces=True,
+                    )
+                    if isinstance(stop_str, str):
+                        stop_str = [stop_str]
+                    for stop in stop_str:
+                        pos = output.rfind(stop)
+                        if pos != -1:
+                            output = output[:pos]
+                            output_mapping[r] = output
+                            stopped = True
+                            finish_reason = "stop"
+                            break
+
+                r.stopped = stopped
+                r.finish_reason = finish_reason
+
+            if r.stopped and r not in stop_token_mapping and r not in output_mapping:
+                stop_token_mapping[r] = _i + 1
+
+            if r.stream:
+                """
+                Note that you can't just decode based on the newest r.new_tokens here,
+                which may destroy the integrity of the parsed characters,
+                and at the same time is not good at handling some special characters.
+                So the implementation here is to decode all the tokens that have been generated each time,
+                and then take the slice.
+                """
+                if r.stopped or len(r.new_tokens) % stream_interval == 0:
+                    if output is None:
+                        output = tokenizer.decode(
+                            r.new_tokens,
+                            skip_special_tokens=True,
+                            spaces_between_special_tokens=False,
+                            clean_up_tokenization_spaces=True,
+                        )
+
+                    if r.last_output_length == 0:
+                        r.completion.append(bos_flag)
+
+                    # this special character is mainly for qwen
+                    output = output.strip("�")
+                    output = output[r.last_output_length :]
+                    r.last_output_length += len(output)
+
+                    completion_chunk = _get_completion_chunk(
+                        output, r.chunk_id, r.finish_reason, model_uid, r, False
+                    )
+                    r.completion.append(completion_chunk)
+                    if r.stopped:
+                        r.completion.append(eos_flag)
+
+                # last round, handle stream result
+                # append usage information when enable `include_usage` for OPENAI API compatibility
+                # The reason for counting the usage in the last round of the iteration is that,
+                # these tokens are real generated and should be counted.
+                if r.stopped and _i == decode_round - 1 and include_usage:
+                    r.completion.append(
+                        _get_completion_chunk(
+                            "", r.chunk_id, r.finish_reason, model_uid, r, True
+                        )
+                    )
+            else:
+                # last round, handle non-stream result
+                if r.stopped and _i == decode_round - 1:
+                    invalid_token_num = decode_round - stop_token_mapping[r]
+                    outputs = (
+                        tokenizer.decode(
+                            r.new_tokens[: -(invalid_token_num + 1)]
+                            if r.finish_reason == "stop"
+                            else r.new_tokens[:-invalid_token_num],
+                            skip_special_tokens=True,
+                            spaces_between_special_tokens=False,
+                            clean_up_tokenization_spaces=True,
+                        )
+                        if r not in output_mapping
+                        else output_mapping[r]
+                    )
+                    completion = _get_completion(
+                        outputs, r.chunk_id, r.finish_reason, model_uid, r
+                    )
+                    r.completion = [completion]
+
+    e_time = time.time()
+    logger.debug(
+        f"Average throughput for a step: {(len(valid_req_list) * decode_round + len(prompts)) / (e_time - s_time)} token/s."
+    )
+
+
+def batch_inference_one_step(
+    req_list: List[InferenceRequest],
+    model_uid,
+    model,
+    tokenizer,
+    device,
+    context_len: int,
+):
+    from ....core.model import OutOfMemoryError
+
+    try:
+        _batch_inference_one_step_internal(
+            req_list, model_uid, model, tokenizer, device, context_len
+        )
+    except OutOfMemoryError:
+        logger.exception(
+            f"Batch inference out of memory. "
+            f"Xinference will restart the model: {model_uid}. "
+            f"Please be patient for a few moments."
+        )
+        # Just kill the process and let xinference auto-recover the model
+        os._exit(1)
+    except Exception as e:
+        logger.exception(f"Internal error for batch inference: {e}.")
+        # If internal error happens, just skip all the requests in this batch.
+        # If not handle here, the client will hang.
+        for r in req_list:
+            r.stopped = True
+            r.error_msg = str(e)
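
The `_merge_kv_cache` helper above aligns the prefill and decode KV caches before batched decoding continues: the shorter cache is left-padded along the sequence axis, then the two are stacked along the batch axis. A minimal standalone sketch of that pad-then-concat idea on raw tensors (the shapes below are illustrative assumptions, not values taken from the library):

import torch
from torch.nn.functional import pad

# Hypothetical KV tensors shaped (batch, num_heads, seq_len, head_dim).
past_k = torch.randn(2, 4, 10, 8)  # running decode requests, longer history
new_k = torch.randn(3, 4, 6, 8)    # freshly prefilled requests, shorter history

# Left-pad the shorter cache on the seq_len axis so both share seq_len=10.
pad_len = past_k.shape[2] - new_k.shape[2]
new_k = pad(new_k, (0, 0, pad_len, 0))  # pad spec: (head_dim L, R, seq_len L, R)

# Stack along the batch axis, new requests first, as _merge_kv_cache does.
merged_k = torch.cat((new_k, past_k), dim=0)
print(merged_k.shape)  # torch.Size([5, 4, 10, 8])

Left-padding keeps the real tokens aligned at the end of the sequence axis, which is what the subsequent single-token decode steps expect.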
xinference/model/llm/utils.py CHANGED
@@ -607,7 +607,7 @@ Begin!"""
         return arguments, None, None
 
     @staticmethod
-    def _eval_chatglm3_arguments(c, tools):
+    def _eval_glm_chat_arguments(c, tools):
         if isinstance(c[0], str):
             return c[0], None, None
         return None, c[0]["name"], c[0]["parameters"]
@@ -659,9 +659,9 @@ Begin!"""
         family = model_family.model_family or model_family.model_name
         if family in ["gorilla-openfunctions-v1", "gorilla-openfunctions-v2"]:
             content, func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
-        elif "chatglm3" == family:
-            content, func, args = cls._eval_chatglm3_arguments(c, tools)
-        elif family in ["qwen-chat", "qwen1.5-chat"]:
+        elif family in ["chatglm3", "glm4-chat"]:
+            content, func, args = cls._eval_glm_chat_arguments(c, tools)
+        elif family in ["qwen-chat", "qwen1.5-chat", "qwen2-instruct"]:
             content, func, args = cls._eval_qwen_chat_arguments(c, tools)
         else:
             raise Exception(
@@ -676,28 +676,29 @@ Begin!"""
         Generates a filter function for Qwen series models to retain outputs after "\nFinal Answer:".
 
         Returns:
-            A function that takes tokens (string output by the model so far) as input
-            returns True if current token is after "\nFinal Answer:", else False.
+            A function that takes tokens (string output by the model so far) and delta (new tokens added) as input,
+            returns the part after "\nFinal Answer:" if found, else returns delta.
         """
         family = model_family.model_family or model_family.model_name
         if family in ["qwen-chat", "qwen1.5-chat"]:
             # Encapsulating function to reset 'found' after each call
             found = False
 
-            def process_token(tokens: str):
+            def process_tokens(tokens: str, delta: str):
                 nonlocal found
                 # Once "Final Answer:" is found, future tokens are allowed.
                 if found:
-                    return True
+                    return delta
                 # Check if the token ends with "\nFinal Answer:" and update `found`.
-                if tokens.endswith("\nFinal Answer:"):
+                final_answer_idx = tokens.lower().rfind("\nfinal answer:")
+                if final_answer_idx != -1:
                     found = True
-                    return False
+                    return tokens[final_answer_idx + len("\nfinal answer:") :]
+                return ""
 
-            return process_token
+            return process_tokens
         else:
-            # For other families, allow all tokens.
-            return lambda tokens: True
+            return lambda tokens, delta: delta
 
     @classmethod
     def _tool_calls_completion(cls, model_family, model_uid, c, tools):
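
The reworked filter above now receives both the accumulated text and the newest delta, and returns the printable part instead of a boolean. A small hedged sketch of the same closure logic outside the class (names here are illustrative, not the library's):

def make_final_answer_filter():
    # Mirrors the behaviour described above: emit nothing until
    # "\nFinal Answer:" appears, then pass deltas through unchanged.
    found = False

    def process_tokens(tokens: str, delta: str) -> str:
        nonlocal found
        if found:
            return delta
        idx = tokens.lower().rfind("\nfinal answer:")
        if idx != -1:
            found = True
            return tokens[idx + len("\nfinal answer:"):]
        return ""

    return process_tokens

f = make_final_answer_filter()
print(repr(f("Thought: I need no tool", " tool")))         # '' -> still thinking
print(repr(f("Thought: ...\nFinal Answer: 42", " 42")))    # ' 42'
print(repr(f("Thought: ...\nFinal Answer: 42 is", " is"))) # ' is' -> pass-through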
xinference/model/llm/vllm/core.py CHANGED
@@ -93,6 +93,7 @@ VLLM_SUPPORTED_MODELS = [
     "baichuan",
     "internlm-16k",
     "mistral-v0.1",
+    "codestral-v0.1",
     "Yi",
     "Yi-1.5",
     "code-llama",
@@ -118,11 +119,14 @@ VLLM_SUPPORTED_CHAT_MODELS = [
     "code-llama-instruct",
     "mistral-instruct-v0.1",
     "mistral-instruct-v0.2",
+    "mistral-instruct-v0.3",
     "mixtral-instruct-v0.1",
     "mixtral-8x22B-instruct-v0.1",
     "chatglm3",
     "chatglm3-32k",
     "chatglm3-128k",
+    "glm4-chat",
+    "glm4-chat-1m",
     "deepseek-chat",
     "deepseek-coder-instruct",
 ]
@@ -130,6 +134,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat")
     VLLM_SUPPORTED_MODELS.append("codeqwen1.5")
     VLLM_SUPPORTED_CHAT_MODELS.append("codeqwen1.5-chat")
+    VLLM_SUPPORTED_CHAT_MODELS.append("qwen2-instruct")
 
 if VLLM_INSTALLED and vllm.__version__ >= "0.3.2":
     VLLM_SUPPORTED_CHAT_MODELS.append("gemma-it")
@@ -140,6 +145,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.3":
 
 if VLLM_INSTALLED and vllm.__version__ >= "0.4.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-moe-chat")
+    VLLM_SUPPORTED_CHAT_MODELS.append("qwen2-moe-instruct")
     VLLM_SUPPORTED_CHAT_MODELS.append("c4ai-command-r-v01")
 
 
@@ -438,7 +444,9 @@ class VLLMModel(LLM):
             _content, func, args = ChatModelMixin._eval_tool_arguments(
                 self.model_family, chunk, tools
            )
-            choice["text"] = choice_delta
+            choice["text"] = tools_token_filter(
+                tokens=previous_texts[0], delta=choice_delta
+            )
             if func is not None:
                 choice["text"] = None
                 choice["finish_reason"] = "tool_calls"
@@ -452,9 +460,13 @@ class VLLMModel(LLM):
                         ),
                     )
                 ]
-            # use a filter function to skip Qwen's react thought process
-            elif not tools_token_filter(previous_texts[0]):
-                continue
+            else:
+                # use a filter function to skip Qwen's react thought process
+                choice["text"] = tools_token_filter(
+                    tokens=previous_texts[0], delta=choice["text"]
+                )
+                if not choice["text"]:
+                    continue
             prompt_tokens = len(_request_output.prompt_token_ids)
             completion_tokens = sum(
                 len(output.token_ids) for output in _request_output.outputs
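
With the filter now returning text rather than a boolean, the streaming path above writes its result back into the choice and skips chunks that produce nothing. A rough, self-contained sketch of that consumption pattern; the chunk source and the pass-through filter are stand-ins, not the vLLM objects (pass-through is what the filter factory shown earlier returns for non-Qwen families):

from typing import Callable, Iterable, Iterator


def stream_visible_text(
    deltas: Iterable[str], token_filter: Callable[..., str]
) -> Iterator[str]:
    # Feed each delta through the filter and drop chunks with nothing to show,
    # mirroring the `if not choice["text"]: continue` pattern above.
    previous_text = ""
    for delta in deltas:
        previous_text += delta
        visible = token_filter(tokens=previous_text, delta=delta)
        if visible:
            yield visible


# A pass-through filter stands in for the real one; it never hides anything.
passthrough = lambda tokens, delta: delta

print(list(stream_visible_text(["Hello", "", " world"], passthrough)))
# ['Hello', ' world']  (the empty delta is skipped)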
xinference/model/utils.py CHANGED
@@ -42,14 +42,20 @@ def is_locale_chinese_simplified() -> bool:
 
 
 def download_from_modelscope() -> bool:
-    if os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "modelscope":
-        return True
+    if os.environ.get(XINFERENCE_ENV_MODEL_SRC):
+        return os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "modelscope"
     elif is_locale_chinese_simplified():
         return True
     else:
         return False
 
 
+def download_from_csghub() -> bool:
+    if os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "csghub":
+        return True
+    return False
+
+
 def symlink_local_file(path: str, local_dir: str, relpath: str) -> str:
     from huggingface_hub.file_download import _create_symlink
 
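
The change above makes an explicitly set model-source environment variable take precedence over the locale heuristic, and adds a separate check for csghub. A minimal sketch of that precedence; the constant's value and the locale helper are stubbed in as assumptions, not copied from the library:

import locale
import os

# Assumed to match the constant defined in xinference/constants.py.
XINFERENCE_ENV_MODEL_SRC = "XINFERENCE_MODEL_SRC"


def is_locale_chinese_simplified() -> bool:
    # Stand-in for the real helper defined earlier in xinference/model/utils.py.
    lang, _ = locale.getdefaultlocale()
    return lang == "zh_CN"


def download_from_modelscope() -> bool:
    src = os.environ.get(XINFERENCE_ENV_MODEL_SRC)
    if src:  # an explicit setting now always wins
        return src == "modelscope"
    return is_locale_chinese_simplified()


def download_from_csghub() -> bool:
    return os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "csghub"


os.environ[XINFERENCE_ENV_MODEL_SRC] = "huggingface"
print(download_from_modelscope(), download_from_csghub())  # False False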
xinference/thirdparty/ChatTTS/__init__.py ADDED
@@ -0,0 +1 @@
+from .core import Chat