xinference 0.15.4__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (67)
  1. xinference/__init__.py +0 -4
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +48 -0
  4. xinference/client/restful/restful_client.py +19 -0
  5. xinference/constants.py +4 -4
  6. xinference/core/chat_interface.py +5 -1
  7. xinference/core/image_interface.py +5 -1
  8. xinference/core/model.py +195 -34
  9. xinference/core/scheduler.py +10 -7
  10. xinference/core/utils.py +9 -0
  11. xinference/model/__init__.py +4 -0
  12. xinference/model/audio/chattts.py +25 -14
  13. xinference/model/audio/model_spec.json +1 -1
  14. xinference/model/audio/model_spec_modelscope.json +1 -1
  15. xinference/model/embedding/model_spec.json +1 -1
  16. xinference/model/image/core.py +59 -4
  17. xinference/model/image/model_spec.json +24 -3
  18. xinference/model/image/model_spec_modelscope.json +25 -3
  19. xinference/model/image/ocr/__init__.py +13 -0
  20. xinference/model/image/ocr/got_ocr2.py +76 -0
  21. xinference/model/image/scheduler/__init__.py +13 -0
  22. xinference/model/image/scheduler/flux.py +533 -0
  23. xinference/model/image/stable_diffusion/core.py +8 -34
  24. xinference/model/image/stable_diffusion/mlx.py +221 -0
  25. xinference/model/image/utils.py +39 -3
  26. xinference/model/llm/__init__.py +2 -0
  27. xinference/model/llm/llm_family.json +178 -1
  28. xinference/model/llm/llm_family_modelscope.json +119 -0
  29. xinference/model/llm/transformers/chatglm.py +104 -0
  30. xinference/model/llm/transformers/core.py +37 -111
  31. xinference/model/llm/transformers/deepseek_v2.py +0 -226
  32. xinference/model/llm/transformers/internlm2.py +3 -95
  33. xinference/model/llm/transformers/opt.py +68 -0
  34. xinference/model/llm/transformers/utils.py +4 -284
  35. xinference/model/llm/utils.py +2 -2
  36. xinference/model/llm/vllm/core.py +16 -1
  37. xinference/thirdparty/mlx/__init__.py +13 -0
  38. xinference/thirdparty/mlx/flux/__init__.py +15 -0
  39. xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
  40. xinference/thirdparty/mlx/flux/clip.py +154 -0
  41. xinference/thirdparty/mlx/flux/datasets.py +75 -0
  42. xinference/thirdparty/mlx/flux/flux.py +247 -0
  43. xinference/thirdparty/mlx/flux/layers.py +302 -0
  44. xinference/thirdparty/mlx/flux/lora.py +76 -0
  45. xinference/thirdparty/mlx/flux/model.py +134 -0
  46. xinference/thirdparty/mlx/flux/sampler.py +56 -0
  47. xinference/thirdparty/mlx/flux/t5.py +244 -0
  48. xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
  49. xinference/thirdparty/mlx/flux/trainer.py +98 -0
  50. xinference/thirdparty/mlx/flux/utils.py +179 -0
  51. xinference/utils.py +2 -3
  52. xinference/web/ui/build/asset-manifest.json +3 -3
  53. xinference/web/ui/build/index.html +1 -1
  54. xinference/web/ui/build/static/js/{main.e51a356d.js → main.b76aeeb7.js} +3 -3
  55. xinference/web/ui/build/static/js/main.b76aeeb7.js.map +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/32ea2c04cf0bba2761b4883d2c40cc259952c94d2d6bb774e510963ca37aac0a.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
  58. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/METADATA +49 -10
  59. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/RECORD +64 -44
  60. xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
  63. /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.b76aeeb7.js.LICENSE.txt} +0 -0
  64. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/LICENSE +0 -0
  65. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/WHEEL +0 -0
  66. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/entry_points.txt +0 -0
  67. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/chatglm.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import logging
 import typing
 import uuid
 from threading import Thread
@@ -29,6 +30,8 @@ from ..utils import (
 )
 from .core import PytorchChatModel, PytorchModelConfig
 
+logger = logging.getLogger(__name__)
+
 
 class ChatglmPytorchChatModel(PytorchChatModel):
     def __init__(
@@ -445,3 +448,104 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             raw_config["top_p"] = 0.8
 
         return raw_config
+
+    def prepare_batch_inference(self, req_list: List[InferenceRequest]):
+        super(PytorchChatModel, self).prepare_batch_inference(req_list)
+        for r in req_list:
+            try:
+                if not r.stopped and r.is_prefill:
+                    tools = r.generate_config.get("tools", None)
+                    tools = list(tools) if tools is not None else None
+                    tool_choice = r.generate_config.get("tool_choice", "none")
+
+                    r.prompt = self._process_messages(
+                        r.prompt, tools=tools, tool_choice=tool_choice
+                    )
+                    r.full_prompt = self.get_full_context(
+                        r.prompt,
+                        self.model_family.chat_template,  # type: ignore
+                        tokenizer=self._tokenizer,
+                    )
+                    if tools:
+                        r.tools = tools
+            except Exception as e:
+                logger.exception(f"prepare inference error with {e}")
+                r.stopped = True
+                r.error_msg = str(e)
+
+    def handle_chat_result_non_streaming(self, req: InferenceRequest):
+        if req.tools:
+            response = req.completion[0]["choices"][0]["text"]
+            usage = req.completion[0]["usage"]
+            function_call = self._process_response_non_streaming(
+                response, req.tools, use_tool=True
+            )
+            req.completion[0] = self._tool_calls_completion(
+                self.model_family, self.model_uid, function_call
+            )
+            req.completion[0]["usage"] = usage
+        else:
+            req.completion[0] = self._to_chat_completion(req.completion[0])
+
+    def handle_chat_result_streaming(self, req: InferenceRequest):
+        results = []
+        tools = {tool["function"]["name"] for tool in req.tools} if req.tools else {}
+        response = "".join(req.outputs)
+        eos_pos = response.find("<eos_stream>")
+        if eos_pos != -1:
+            response = response[:eos_pos]
+
+        if "<bos_stream>" in req.completion:
+            bos_pos = req.completion.index("<bos_stream>")
+            results.append(
+                self._get_first_chat_completion_chunk(req.completion[bos_pos + 1])
+            )
+
+        if req.stopped:
+            if tools:
+                new_response = self._process_response_streaming(
+                    response, tools, end=True
+                )
+                if new_response:
+                    if isinstance(new_response, dict):  # tool call case
+                        chunk_id = [
+                            c for c in req.completion if not isinstance(c, str)
+                        ][0]["id"]
+                        results.append(
+                            self._tool_calls_completion_chunk(
+                                self.model_family,
+                                self.model_uid,
+                                new_response,
+                                chunk_id=chunk_id,
+                            )
+                        )
+                    else:  # normal case
+                        for c in req.completion:
+                            if c == "<bos_stream>":
+                                continue
+                            elif c == "<eos_stream>":
+                                break
+                            else:
+                                results.append(self._to_chat_completion_chunk(c))
+            else:
+                for c in req.completion:
+                    if c == "<bos_stream>":
+                        continue
+                    elif c == "<eos_stream>":
+                        break
+                    else:
+                        results.append(self._to_chat_completion_chunk(c))
+        else:
+            if response and response[-1] != "�":
+                new_response = self._process_response_streaming(
+                    response, tools, end=False
+                )
+                if new_response is not None:  # normal case
+                    for c in req.completion:
+                        if c == "<bos_stream>":
+                            continue
+                        results.append(self._to_chat_completion_chunk(c))
+
+        if req.stopped and req.include_usage:
+            results.append(self._get_final_chat_completion_chunk(req.completion[-1]))
+        req.completion = results
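Note: the new ChatGLM batch-inference path above rebuilds the chat result from the raw streamed output by scanning for the "<bos_stream>"/"<eos_stream>" sentinels before converting chunks. The snippet below is only an illustrative, standalone sketch of that marker-filtering pattern; the function name and the plain list input are hypothetical, and the real methods operate on InferenceRequest objects, give the first post-"<bos_stream>" chunk special treatment, and also handle tool calls.

from typing import List, Union

BOS, EOS = "<bos_stream>", "<eos_stream>"


def strip_stream_markers(completion: List[Union[str, dict]]) -> List[dict]:
    """Drop the <bos_stream>/<eos_stream> sentinels and keep only real chunks."""
    results: List[dict] = []
    for item in completion:
        if item == BOS:
            continue          # marker signalling that the first real chunk follows
        if item == EOS:
            break             # nothing after the end-of-stream marker is emitted
        results.append(item)  # a regular chat-completion chunk
    return results


# strip_stream_markers(["<bos_stream>", {"id": "c1"}, {"id": "c2"}, "<eos_stream>"])
# -> [{"id": "c1"}, {"id": "c2"}]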
xinference/model/llm/transformers/core.py

@@ -29,7 +29,6 @@ from ....device_utils import (
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
-    Completion,
     CompletionChoice,
     CompletionChunk,
     CreateCompletionTorch,
@@ -46,9 +45,7 @@ from .utils import get_context_length, get_max_src_len, pad_prefill_tokens
 logger = logging.getLogger(__name__)
 
 NON_DEFAULT_MODEL_LIST: List[str] = [
-    "chatglm3",
-    "chatglm3-32k",
-    "chatglm3-128k",
+    "opt",
     "glm4-chat",
     "glm4-chat-1m",
     "internlm2-chat",
@@ -345,69 +342,6 @@ class PytorchModel(LLM):
             return False
         return True
 
-    def generate(
-        self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
-        from .utils import generate_stream
-
-        def generator_wrapper(
-            prompt: str, generate_config: PytorchGenerateConfig
-        ) -> Iterator[CompletionChunk]:
-            for completion_chunk, completion_usage in generate_stream(
-                self.model_uid,
-                self._model,
-                self._tokenizer,
-                prompt,
-                self._device,
-                generate_config,
-            ):
-                completion_chunk["usage"] = completion_usage
-                yield completion_chunk
-
-        logger.debug(
-            "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
-        )
-
-        generate_config = self._sanitize_generate_config(generate_config)
-
-        assert self._model is not None
-        assert self._tokenizer is not None
-
-        lora_model = generate_config.pop("lora_name")
-
-        if lora_model is not None and self._peft_model is not None:
-            for lora in self._peft_model:
-                if lora_model == lora.lora_name:
-                    self._model.set_adapter(lora_model)
-                    logger.info(f"Set lora model to {lora_model}")
-                    break
-            else:
-                self._model.disable_adapter()
-                logger.info(f"No lora model {lora_model} found, skip setting")
-
-        stream = generate_config.get("stream", False)
-        if not stream:
-            for completion_chunk, completion_usage in generate_stream(
-                self.model_uid,
-                self._model,
-                self._tokenizer,
-                prompt,
-                self._device,
-                generate_config,
-            ):
-                pass
-            completion = Completion(
-                id=completion_chunk["id"],
-                object=completion_chunk["object"],
-                created=completion_chunk["created"],
-                model=completion_chunk["model"],
-                choices=completion_chunk["choices"],
-                usage=completion_usage,
-            )
-            return completion
-        else:
-            return generator_wrapper(prompt, generate_config)
-
     def build_prefill_attention_mask(
         self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
     ):
@@ -730,7 +664,12 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         messages: List[Dict],
         generate_config: Optional[PytorchGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        tools = generate_config.pop("tools", []) if generate_config else None
+        raise NotImplementedError
+
+    def load(self):
+        super().load()
+
+    def _get_full_prompt(self, messages: List[Dict], tools):
         model_family = self.model_family.model_family or self.model_family.model_name
         full_context_kwargs = {}
         if (
@@ -746,29 +685,6 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
             tokenizer=self._tokenizer,
             **full_context_kwargs,
         )
-
-        generate_config = self._sanitize_generate_config(generate_config)
-
-        stream = generate_config.get("stream", False)
-        if stream:
-            it = self.generate(full_prompt, generate_config)
-            assert isinstance(it, Iterator)
-            return self._to_chat_completion_chunks(it)
-        else:
-            c = self.generate(full_prompt, generate_config)
-            assert not isinstance(c, Iterator)
-            if tools:
-                return self._tool_calls_completion(self.model_family, self.model_uid, c)
-            return self._to_chat_completion(c)
-
-    def load(self):
-        super().load()
-
-    def _get_full_prompt(self, messages: List[Dict], tools):
-        assert self.model_family.chat_template is not None
-        full_prompt = self.get_full_context(
-            messages, self.model_family.chat_template, tokenizer=self._tokenizer
-        )
         return full_prompt
 
     def prepare_batch_inference(self, req_list: List[InferenceRequest]):
@@ -776,12 +692,39 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         for r in req_list:
            try:
                 if not r.stopped and r.is_prefill:
-                    r.full_prompt = self._get_full_prompt(r.prompt, None)
+                    tools = r.generate_config.get("tools", None)
+                    r.full_prompt = self._get_full_prompt(r.prompt, tools)
+                    if tools:
+                        r.tools = tools
             except Exception as e:
                 logger.exception(f"prepare inference error with {e}")
                 r.stopped = True
                 r.error_msg = str(e)
 
+    def handle_chat_result_non_streaming(self, req: InferenceRequest):
+        if req.tools:
+            req.completion[0] = self._tool_calls_completion(
+                self.model_family, self.model_uid, req.completion[0]
+            )
+        else:
+            req.completion[0] = self._to_chat_completion(req.completion[0])
+
+    def handle_chat_result_streaming(self, req: InferenceRequest):
+        results = []
+        for i, c in enumerate(req.completion):
+            if c == "<bos_stream>":
+                results.append(
+                    self._get_first_chat_completion_chunk(req.completion[i + 1])
+                )
+            elif c == "<eos_stream>":
+                break
+            else:
+                results.append(self._to_chat_completion_chunk(c))
+
+        if req.stopped and req.include_usage:
+            results.append(self._get_final_chat_completion_chunk(req.completion[-1]))
+        req.completion = results
+
     def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
         for req in req_list:
             if req.error_msg is None and req.completion:
@@ -800,23 +743,6 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
                     continue
 
                 if req.stream:
-                    results = []
-                    for i, c in enumerate(req.completion):
-                        if c == "<bos_stream>":
-                            results.append(
-                                self._get_first_chat_completion_chunk(
-                                    req.completion[i + 1]
-                                )
-                            )
-                        elif c == "<eos_stream>":
-                            break
-                        else:
-                            results.append(self._to_chat_completion_chunk(c))
-
-                    if req.stopped and req.include_usage:
-                        results.append(
-                            self._get_final_chat_completion_chunk(req.completion[-1])
-                        )
-                    req.completion = results
+                    self.handle_chat_result_streaming(req)
                 else:
-                    req.completion[0] = self._to_chat_completion(req.completion[0])
+                    self.handle_chat_result_non_streaming(req)
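Taken together, these core.py hunks replace the old inline generate()/chat() paths with overridable result hooks: handle_batch_inference_results now routes each finished request to handle_chat_result_streaming or handle_chat_result_non_streaming, which subclasses such as the ChatGLM model override for tool-call handling. Below is a minimal sketch of that dispatch shape using stand-in classes, not the real xinference types.

from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class Request:
    """Stand-in for xinference.core.scheduler.InferenceRequest (hypothetical)."""
    stream: bool
    completion: List[Any] = field(default_factory=list)


class BaseChatModel:
    """Shared batching loop; per-request result handling is delegated to hooks."""

    def handle_chat_result_streaming(self, req: Request) -> None:
        # drop sentinel marker strings, keep the real chunks
        req.completion = [c for c in req.completion if not isinstance(c, str)]

    def handle_chat_result_non_streaming(self, req: Request) -> None:
        req.completion[0] = {"object": "chat.completion", "choices": req.completion[0]}

    def handle_batch_inference_results(self, req_list: List[Request]) -> None:
        for req in req_list:
            if req.stream:
                self.handle_chat_result_streaming(req)
            else:
                self.handle_chat_result_non_streaming(req)


class ToolCallChatModel(BaseChatModel):
    """A subclass overrides only the hooks, e.g. to emit tool_calls responses."""

    def handle_chat_result_non_streaming(self, req: Request) -> None:
        req.completion[0] = {"object": "chat.completion", "tool_calls": req.completion[0]}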
xinference/model/llm/transformers/deepseek_v2.py

@@ -12,24 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import uuid
-from typing import Dict, Iterator, List, Optional, Union
 
 import torch
 
-from ....types import (
-    ChatCompletion,
-    ChatCompletionChunk,
-    Completion,
-    CompletionChunk,
-    PytorchGenerateConfig,
-)
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import (
-    generate_chat_completion,
-    generate_completion,
-    generate_completion_chunk,
-)
 from .core import PytorchChatModel, PytorchModel
 
 logger = logging.getLogger(__name__)
@@ -80,95 +66,6 @@ class DeepSeekV2PytorchModel(PytorchModel):
             return False
         return True
 
-    def generate(
-        self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
-        input_tensor = self._tokenizer(prompt, return_tensors="pt")
-        generate_config = self._sanitize_generate_config(generate_config)
-        default_generate_config = self._model.generation_config
-        generate_kwargs = {
-            "input_ids": input_tensor["input_ids"].cuda(),
-            "attention_mask": input_tensor["attention_mask"].cuda(),
-            "temperature": float(
-                generate_config.get("temperature", default_generate_config.temperature)
-            ),
-            "repetition_penalty": float(generate_config.get("repetition_penalty", 1.0)),
-            "top_p": float(generate_config.get("top_p", default_generate_config.top_p)),
-            "top_k": int(generate_config.get("top_k", -1)),
-            "max_new_tokens": generate_config.get("max_tokens", 512),
-            "bos_token_id": default_generate_config.bos_token_id,
-            "do_sample": default_generate_config.do_sample,
-            "eos_token_id": default_generate_config.eos_token_id,
-        }
-
-        stream = generate_config.get("stream", False)
-        if stream:
-            return self._generate_stream(generate_kwargs, input_tensor)
-        else:
-            return self._generate(generate_kwargs, input_tensor)
-
-    def _generate(self, generate_kwargs, input_ids) -> Completion:
-        prompt_tokens = len(input_ids[0])
-        logger.info(f"generate_kwargs:{generate_kwargs}")
-        generation_output = self._model.generate(**generate_kwargs)
-        completion_tokens = len(generation_output[0])
-        response = self._tokenizer.decode(
-            generation_output[0], skip_special_tokens=True
-        )
-        return generate_completion(
-            self.model_uid,
-            response,
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        )
-
-    def _generate_stream(self, generate_kwargs, input_ids):
-        from threading import Thread
-
-        from transformers import TextIteratorStreamer
-
-        # Initialize the streamer
-        streamer = TextIteratorStreamer(
-            self._tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
-        )
-        # Define the generation configuration
-        generate_kwargs["streamer"] = streamer
-        # Start the model chat in a separate thread
-        thread = Thread(
-            target=self._model.generate,
-            kwargs=generate_kwargs,
-        )
-        thread.start()
-
-        completion_id = str(uuid.uuid1())
-        prompt_tokens = len(input_ids[0])
-        total_tokens, completion_tokens = 0, 0
-        # Loop through the streamer to get the new text as it is generated
-        for i, new_text in enumerate(streamer):
-            completion_tokens = i
-            total_tokens = prompt_tokens + completion_tokens
-            yield generate_completion_chunk(
-                chunk_text=new_text,
-                finish_reason=None,
-                chunk_id=completion_id,
-                model_uid=self.model_uid,
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-            )
-        yield generate_completion_chunk(
-            chunk_text=None,
-            finish_reason="stop",
-            chunk_id=completion_id,
-            model_uid=self.model_uid,
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-            has_choice=True,
-            has_content=False,
-        )
-
 
 class DeepSeekV2PytorchChatModel(PytorchChatModel):
     def _load_model(self, **kwargs):
@@ -215,126 +112,3 @@ class DeepSeekV2PytorchChatModel(PytorchChatModel):
         if "chat" not in llm_family.model_ability:
             return False
         return True
-
-    def chat(
-        self,
-        messages: List[Dict],
-        generate_config: Optional[PytorchGenerateConfig] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        assert self.model_family.chat_template is not None
-        full_prompt = self.get_full_context(
-            messages,
-            self.model_family.chat_template,
-            tokenizer=self._tokenizer,
-        )
-        input_tensor = self._tokenizer.encode(
-            full_prompt,
-            padding=False,
-            truncation=False,
-            max_length=None,
-            add_special_tokens=False,
-            return_tensors="pt",
-        )
-
-        generate_config = self._sanitize_generate_config(generate_config)
-        default_generate_config = self._model.generation_config
-        generate_kwargs = {
-            "input_ids": input_tensor.cuda(),
-            "temperature": float(
-                generate_config.get("temperature", default_generate_config.temperature)
-            ),
-            "repetition_penalty": float(generate_config.get("repetition_penalty", 1.0)),
-            "top_p": float(generate_config.get("top_p", default_generate_config.top_p)),
-            "top_k": int(generate_config.get("top_k", -1)),
-            "max_new_tokens": generate_config.get("max_tokens", 512),
-            "bos_token_id": default_generate_config.bos_token_id,
-            "do_sample": default_generate_config.do_sample,
-            "eos_token_id": default_generate_config.eos_token_id,
-        }
-
-        stream = generate_config.get("stream", False)
-        stream_options = generate_config.get("stream_options", None)
-        include_usage = (
-            stream_options["include_usage"]
-            if isinstance(stream_options, dict)
-            else False
-        )
-        if stream:
-            chunk = self._generate_stream(generate_kwargs, input_tensor, include_usage)
-            return self._to_chat_completion_chunks(chunk)
-        else:
-            return self._generate(generate_kwargs, input_tensor)
-
-    def _generate(self, generate_kwargs, input_ids) -> ChatCompletion:
-        prompt_tokens = len(input_ids[0])
-        generation_output = self._model.generate(**generate_kwargs)
-        completion_tokens = len(generation_output[0])
-        response = self._tokenizer.decode(
-            generation_output[0][input_ids.shape[1] :], skip_special_tokens=True
-        )
-        return generate_chat_completion(
-            self.model_uid,
-            response,
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        )
-
-    def _generate_stream(self, generate_kwargs, input_ids, include_usage):
-        from threading import Thread
-
-        from transformers import TextIteratorStreamer
-
-        # Initialize the streamer
-        streamer = TextIteratorStreamer(
-            self._tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
-        )
-        # Define the generation configuration
-        generate_kwargs["streamer"] = streamer
-        # Start the model chat in a separate thread
-        thread = Thread(
-            target=self._model.generate,
-            kwargs=generate_kwargs,
-        )
-        thread.start()
-
-        completion_id = str(uuid.uuid1())
-        prompt_tokens = len(input_ids[0])
-        total_tokens, completion_tokens = 0, 0
-        # Loop through the streamer to get the new text as it is generated
-        for i, new_text in enumerate(streamer):
-            completion_tokens = max(completion_tokens, len(streamer.token_cache))
-            total_tokens = prompt_tokens + completion_tokens
-            yield generate_completion_chunk(
-                chunk_text=new_text,
-                finish_reason=None,
-                chunk_id=completion_id,
-                model_uid=self.model_uid,
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-            )
-        yield generate_completion_chunk(
-            chunk_text=None,
-            finish_reason="stop",
-            chunk_id=completion_id,
-            model_uid=self.model_uid,
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-            has_choice=True,
-            has_content=False,
-        )
-
-        if include_usage:
-            yield generate_completion_chunk(
-                chunk_text=None,
-                finish_reason=None,
-                chunk_id=completion_id,
-                model_uid=self.model_uid,
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-                has_choice=False,
-                has_content=False,
-            )
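For reference, the deleted DeepSeek-V2 generate/chat code above followed the standard Hugging Face streaming recipe: run model.generate() in a background thread with a TextIteratorStreamer and iterate over the streamer for decoded text. A generic sketch of that recipe follows; the model, tokenizer, and prompt are placeholders, not xinference objects.

from threading import Thread

from transformers import TextIteratorStreamer


def stream_generate(model, tokenizer, prompt: str, max_new_tokens: int = 128):
    """Yield decoded text pieces as the model produces them."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
    )
    # generate() blocks until completion, so it runs in a worker thread
    # while the caller consumes the streamer in the foreground.
    thread = Thread(
        target=model.generate,
        kwargs={**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens},
    )
    thread.start()
    for new_text in streamer:
        yield new_text
    thread.join()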
xinference/model/llm/transformers/internlm2.py

@@ -11,13 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import uuid
-from typing import Any, Dict, Iterator, List, Optional, Union
+
+from typing import List, Optional
 
 from ....core.scheduler import InferenceRequest
-from ....types import ChatCompletion, ChatCompletionChunk, LoRA, PytorchGenerateConfig
+from ....types import LoRA
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import generate_chat_completion, generate_completion_chunk, parse_messages
 from .core import PytorchChatModel, PytorchModelConfig
 
 
@@ -93,94 +92,3 @@ class Internlm2PytorchChatModel(PytorchChatModel):
         if top_p is None:
             raw_config["top_p"] = 0.8
         return raw_config
-
-    def chat(
-        self,
-        messages: List[Dict],
-        generate_config: Optional[PytorchGenerateConfig] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        kwargs: Dict[str, Any] = {}
-        generate_config = generate_config or {}
-        temperature = generate_config.get("temperature")
-        if temperature is not None:
-            kwargs["temperature"] = float(temperature)
-        top_p = generate_config.get("top_p")
-        if top_p is not None:
-            kwargs["top_p"] = float(top_p)
-        max_new_tokens = generate_config.get("max_tokens")
-        if max_new_tokens is not None:
-            kwargs["max_length"] = int(max_new_tokens)
-
-        stream = generate_config.get("stream", False)
-        stream_options = generate_config.pop("stream_options", None)
-        include_usage = (
-            stream_options["include_usage"]
-            if isinstance(stream_options, dict)
-            else False
-        )
-
-        prompt, system_prompt, chat_history = parse_messages(messages)
-        if chat_history:
-            input_history = [
-                (chat_history[i]["content"], (chat_history[i + 1]["content"]))
-                for i in range(0, len(chat_history), 2)
-            ]
-        else:
-            input_history = []
-        if system_prompt:
-            kwargs["meta_instruction"] = system_prompt
-        if stream:
-
-            def _stream_generator():
-                last_chunk_text_length = 0
-                chunk_id = "chat-" + str(uuid.uuid1())
-                prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
-                inputs = self._tokenizer([prompt], return_tensors="pt")
-                inputs = inputs.to(self._model.device)
-                prompt_tokens = len(inputs["input_ids"][0])
-                for chunk_text, _ in self._model.stream_chat(
-                    self._tokenizer, prompt, input_history, **kwargs
-                ):
-                    completion_tokens = completion_tokens + 1
-                    total_tokens = prompt_tokens + completion_tokens
-                    chunk_text = chunk_text[last_chunk_text_length:]
-                    last_chunk_text_length += len(chunk_text)
-
-                    yield generate_completion_chunk(
-                        chunk_text,
-                        finish_reason=None,
-                        chunk_id=chunk_id,
-                        model_uid=self.model_uid,
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=total_tokens,
-                    )
-                yield generate_completion_chunk(
-                    None,
-                    finish_reason="stop",
-                    chunk_id=chunk_id,
-                    model_uid=self.model_uid,
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=total_tokens,
-                    has_choice=True,
-                    has_content=False,
-                )
-                if include_usage:
-                    yield generate_completion_chunk(
-                        None,
-                        finish_reason=None,
-                        chunk_id=chunk_id,
-                        model_uid=self.model_uid,
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=total_tokens,
-                        has_choice=False,
-                    )
-
-            return self._to_chat_completion_chunks(_stream_generator())
-        else:
-            response, _ = self._model.chat(
-                self._tokenizer, prompt, input_history, **kwargs
-            )
-            return generate_chat_completion(self.model_uid, response)
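The removed InternLM2 chat() converted the flat, OpenAI-style history returned by parse_messages into the (user, assistant) tuples expected by model.stream_chat(). A small standalone sketch of that pairing step follows; the function name is hypothetical, and unlike the original list comprehension it tolerates a trailing unpaired message.

from typing import Dict, List, Tuple


def pair_history(chat_history: List[Dict[str, str]]) -> List[Tuple[str, str]]:
    """Group an alternating user/assistant message list into (user, assistant) pairs."""
    # len(chat_history) - 1 guards against a trailing unpaired message; the
    # original assumed an even-length, strictly alternating history.
    return [
        (chat_history[i]["content"], chat_history[i + 1]["content"])
        for i in range(0, len(chat_history) - 1, 2)
    ]


# pair_history([{"content": "hi"}, {"content": "hello!"}]) -> [("hi", "hello!")]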