xinference 0.15.3__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff shows the content of publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two versions.

Potentially problematic release: this version of xinference has been flagged as possibly problematic.

Files changed (65)
  1. xinference/__init__.py +0 -4
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +29 -2
  4. xinference/client/restful/restful_client.py +10 -0
  5. xinference/constants.py +7 -3
  6. xinference/core/image_interface.py +76 -23
  7. xinference/core/model.py +158 -46
  8. xinference/core/progress_tracker.py +187 -0
  9. xinference/core/scheduler.py +10 -7
  10. xinference/core/supervisor.py +11 -0
  11. xinference/core/utils.py +9 -0
  12. xinference/core/worker.py +1 -0
  13. xinference/deploy/supervisor.py +4 -0
  14. xinference/model/__init__.py +4 -0
  15. xinference/model/audio/chattts.py +2 -1
  16. xinference/model/audio/core.py +0 -2
  17. xinference/model/audio/model_spec.json +8 -0
  18. xinference/model/audio/model_spec_modelscope.json +9 -0
  19. xinference/model/image/core.py +6 -7
  20. xinference/model/image/scheduler/__init__.py +13 -0
  21. xinference/model/image/scheduler/flux.py +533 -0
  22. xinference/model/image/sdapi.py +35 -4
  23. xinference/model/image/stable_diffusion/core.py +215 -110
  24. xinference/model/image/utils.py +39 -3
  25. xinference/model/llm/__init__.py +2 -0
  26. xinference/model/llm/llm_family.json +185 -17
  27. xinference/model/llm/llm_family_modelscope.json +124 -12
  28. xinference/model/llm/transformers/chatglm.py +104 -0
  29. xinference/model/llm/transformers/cogvlm2.py +2 -1
  30. xinference/model/llm/transformers/cogvlm2_video.py +2 -0
  31. xinference/model/llm/transformers/core.py +43 -113
  32. xinference/model/llm/transformers/deepseek_v2.py +0 -226
  33. xinference/model/llm/transformers/deepseek_vl.py +2 -0
  34. xinference/model/llm/transformers/glm4v.py +2 -1
  35. xinference/model/llm/transformers/intern_vl.py +2 -0
  36. xinference/model/llm/transformers/internlm2.py +3 -95
  37. xinference/model/llm/transformers/minicpmv25.py +2 -0
  38. xinference/model/llm/transformers/minicpmv26.py +2 -0
  39. xinference/model/llm/transformers/omnilmm.py +2 -0
  40. xinference/model/llm/transformers/opt.py +68 -0
  41. xinference/model/llm/transformers/qwen2_audio.py +11 -4
  42. xinference/model/llm/transformers/qwen2_vl.py +2 -28
  43. xinference/model/llm/transformers/qwen_vl.py +2 -1
  44. xinference/model/llm/transformers/utils.py +36 -283
  45. xinference/model/llm/transformers/yi_vl.py +2 -0
  46. xinference/model/llm/utils.py +60 -16
  47. xinference/model/llm/vllm/core.py +68 -9
  48. xinference/model/llm/vllm/utils.py +0 -1
  49. xinference/model/utils.py +7 -4
  50. xinference/model/video/core.py +0 -2
  51. xinference/utils.py +2 -3
  52. xinference/web/ui/build/asset-manifest.json +3 -3
  53. xinference/web/ui/build/index.html +1 -1
  54. xinference/web/ui/build/static/js/{main.e51a356d.js → main.f7da0140.js} +3 -3
  55. xinference/web/ui/build/static/js/main.f7da0140.js.map +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
  57. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/METADATA +38 -6
  58. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/RECORD +63 -59
  59. xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
  61. /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.f7da0140.js.LICENSE.txt} +0 -0
  62. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/LICENSE +0 -0
  63. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/WHEEL +0 -0
  64. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/entry_points.txt +0 -0
  65. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/top_level.txt +0 -0
@@ -29,7 +29,6 @@ from ....device_utils import (
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
-    Completion,
     CompletionChoice,
     CompletionChunk,
     CreateCompletionTorch,
@@ -40,15 +39,13 @@ from ....types import (
 from ...utils import select_device
 from ..core import LLM
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin
+from ..utils import LLAMA3_TOOL_CALL_FAMILY, QWEN_TOOL_CALL_FAMILY, ChatModelMixin
 from .utils import get_context_length, get_max_src_len, pad_prefill_tokens

 logger = logging.getLogger(__name__)

 NON_DEFAULT_MODEL_LIST: List[str] = [
-    "chatglm3",
-    "chatglm3-32k",
-    "chatglm3-128k",
+    "opt",
     "glm4-chat",
     "glm4-chat-1m",
     "internlm2-chat",
@@ -345,69 +342,6 @@ class PytorchModel(LLM):
             return False
         return True

-    def generate(
-        self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
-        from .utils import generate_stream
-
-        def generator_wrapper(
-            prompt: str, generate_config: PytorchGenerateConfig
-        ) -> Iterator[CompletionChunk]:
-            for completion_chunk, completion_usage in generate_stream(
-                self.model_uid,
-                self._model,
-                self._tokenizer,
-                prompt,
-                self._device,
-                generate_config,
-            ):
-                completion_chunk["usage"] = completion_usage
-                yield completion_chunk
-
-        logger.debug(
-            "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
-        )
-
-        generate_config = self._sanitize_generate_config(generate_config)
-
-        assert self._model is not None
-        assert self._tokenizer is not None
-
-        lora_model = generate_config.pop("lora_name")
-
-        if lora_model is not None and self._peft_model is not None:
-            for lora in self._peft_model:
-                if lora_model == lora.lora_name:
-                    self._model.set_adapter(lora_model)
-                    logger.info(f"Set lora model to {lora_model}")
-                    break
-            else:
-                self._model.disable_adapter()
-                logger.info(f"No lora model {lora_model} found, skip setting")
-
-        stream = generate_config.get("stream", False)
-        if not stream:
-            for completion_chunk, completion_usage in generate_stream(
-                self.model_uid,
-                self._model,
-                self._tokenizer,
-                prompt,
-                self._device,
-                generate_config,
-            ):
-                pass
-            completion = Completion(
-                id=completion_chunk["id"],
-                object=completion_chunk["object"],
-                created=completion_chunk["created"],
-                model=completion_chunk["model"],
-                choices=completion_chunk["choices"],
-                usage=completion_usage,
-            )
-            return completion
-        else:
-            return generator_wrapper(prompt, generate_config)
-
     def build_prefill_attention_mask(
         self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
     ):
@@ -730,10 +664,19 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         messages: List[Dict],
         generate_config: Optional[PytorchGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        tools = generate_config.pop("tools", []) if generate_config else None
+        raise NotImplementedError
+
+    def load(self):
+        super().load()
+
+    def _get_full_prompt(self, messages: List[Dict], tools):
         model_family = self.model_family.model_family or self.model_family.model_name
         full_context_kwargs = {}
-        if tools and model_family in QWEN_TOOL_CALL_FAMILY:
+        if (
+            tools
+            and model_family in QWEN_TOOL_CALL_FAMILY
+            or model_family in LLAMA3_TOOL_CALL_FAMILY
+        ):
             full_context_kwargs["tools"] = tools
         assert self.model_family.chat_template is not None
         full_prompt = self.get_full_context(
@@ -742,29 +685,6 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
             tokenizer=self._tokenizer,
             **full_context_kwargs,
         )
-
-        generate_config = self._sanitize_generate_config(generate_config)
-
-        stream = generate_config.get("stream", False)
-        if stream:
-            it = self.generate(full_prompt, generate_config)
-            assert isinstance(it, Iterator)
-            return self._to_chat_completion_chunks(it)
-        else:
-            c = self.generate(full_prompt, generate_config)
-            assert not isinstance(c, Iterator)
-            if tools:
-                return self._tool_calls_completion(self.model_family, self.model_uid, c)
-            return self._to_chat_completion(c)
-
-    def load(self):
-        super().load()
-
-    def _get_full_prompt(self, messages: List[Dict], tools):
-        assert self.model_family.chat_template is not None
-        full_prompt = self.get_full_context(
-            messages, self.model_family.chat_template, tokenizer=self._tokenizer
-        )
         return full_prompt

     def prepare_batch_inference(self, req_list: List[InferenceRequest]):
@@ -772,12 +692,39 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         for r in req_list:
             try:
                 if not r.stopped and r.is_prefill:
-                    r.full_prompt = self._get_full_prompt(r.prompt, None)
+                    tools = r.generate_config.get("tools", None)
+                    r.full_prompt = self._get_full_prompt(r.prompt, tools)
+                    if tools:
+                        r.tools = tools
             except Exception as e:
                 logger.exception(f"prepare inference error with {e}")
                 r.stopped = True
                 r.error_msg = str(e)

+    def handle_chat_result_non_streaming(self, req: InferenceRequest):
+        if req.tools:
+            req.completion[0] = self._tool_calls_completion(
+                self.model_family, self.model_uid, req.completion[0]
+            )
+        else:
+            req.completion[0] = self._to_chat_completion(req.completion[0])
+
+    def handle_chat_result_streaming(self, req: InferenceRequest):
+        results = []
+        for i, c in enumerate(req.completion):
+            if c == "<bos_stream>":
+                results.append(
+                    self._get_first_chat_completion_chunk(req.completion[i + 1])
+                )
+            elif c == "<eos_stream>":
+                break
+            else:
+                results.append(self._to_chat_completion_chunk(c))
+
+        if req.stopped and req.include_usage:
+            results.append(self._get_final_chat_completion_chunk(req.completion[-1]))
+        req.completion = results
+
     def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
         for req in req_list:
             if req.error_msg is None and req.completion:
@@ -796,23 +743,6 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
                     continue

                 if req.stream:
-                    results = []
-                    for i, c in enumerate(req.completion):
-                        if c == "<bos_stream>":
-                            results.append(
-                                self._get_first_chat_completion_chunk(
-                                    req.completion[i + 1]
-                                )
-                            )
-                        elif c == "<eos_stream>":
-                            break
-                        else:
-                            results.append(self._to_chat_completion_chunk(c))
-
-                    if req.stopped and req.include_usage:
-                        results.append(
-                            self._get_final_chat_completion_chunk(req.completion[-1])
-                        )
-                    req.completion = results
+                    self.handle_chat_result_streaming(req)
                 else:
-                    req.completion[0] = self._to_chat_completion(req.completion[0])
+                    self.handle_chat_result_non_streaming(req)
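The core.py changes above move tool handling into the batched inference path: tool definitions now travel in `generate_config["tools"]`, and models in `LLAMA3_TOOL_CALL_FAMILY` join the Qwen families that may render tools into the chat template. Purely as an illustration of what this enables at the API level (not part of this diff), the sketch below sends an OpenAI-style tool definition to an xinference server through its OpenAI-compatible endpoint; the base URL, the model UID `llama-3-instruct`, and the `get_weather` tool schema are assumptions made for the example.

```python
# Illustrative sketch only: the endpoint, model UID, and tool schema are assumed,
# not taken from this diff. Requires a running xinference server with a
# Llama-3-instruct style chat model already launched.
from openai import OpenAI

client = OpenAI(api_key="not-used", base_url="http://127.0.0.1:9997/v1")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="llama-3-instruct",
    messages=[{"role": "user", "content": "What is the weather in Berlin?"}],
    tools=tools,
)
# When the model decides to call the tool, the call and its JSON-encoded
# arguments are returned on the message instead of plain text content.
print(response.choices[0].message.tool_calls)
```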
@@ -12,24 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import uuid
-from typing import Dict, Iterator, List, Optional, Union

 import torch

-from ....types import (
-    ChatCompletion,
-    ChatCompletionChunk,
-    Completion,
-    CompletionChunk,
-    PytorchGenerateConfig,
-)
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import (
-    generate_chat_completion,
-    generate_completion,
-    generate_completion_chunk,
-)
 from .core import PytorchChatModel, PytorchModel

 logger = logging.getLogger(__name__)
@@ -80,95 +66,6 @@ class DeepSeekV2PytorchModel(PytorchModel):
             return False
         return True

-    def generate(
-        self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
-        input_tensor = self._tokenizer(prompt, return_tensors="pt")
-        generate_config = self._sanitize_generate_config(generate_config)
-        default_generate_config = self._model.generation_config
-        generate_kwargs = {
-            "input_ids": input_tensor["input_ids"].cuda(),
-            "attention_mask": input_tensor["attention_mask"].cuda(),
-            "temperature": float(
-                generate_config.get("temperature", default_generate_config.temperature)
-            ),
-            "repetition_penalty": float(generate_config.get("repetition_penalty", 1.0)),
-            "top_p": float(generate_config.get("top_p", default_generate_config.top_p)),
-            "top_k": int(generate_config.get("top_k", -1)),
-            "max_new_tokens": generate_config.get("max_tokens", 512),
-            "bos_token_id": default_generate_config.bos_token_id,
-            "do_sample": default_generate_config.do_sample,
-            "eos_token_id": default_generate_config.eos_token_id,
-        }
-
-        stream = generate_config.get("stream", False)
-        if stream:
-            return self._generate_stream(generate_kwargs, input_tensor)
-        else:
-            return self._generate(generate_kwargs, input_tensor)
-
-    def _generate(self, generate_kwargs, input_ids) -> Completion:
-        prompt_tokens = len(input_ids[0])
-        logger.info(f"generate_kwargs:{generate_kwargs}")
-        generation_output = self._model.generate(**generate_kwargs)
-        completion_tokens = len(generation_output[0])
-        response = self._tokenizer.decode(
-            generation_output[0], skip_special_tokens=True
-        )
-        return generate_completion(
-            self.model_uid,
-            response,
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        )
-
-    def _generate_stream(self, generate_kwargs, input_ids):
-        from threading import Thread
-
-        from transformers import TextIteratorStreamer
-
-        # Initialize the streamer
-        streamer = TextIteratorStreamer(
-            self._tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
-        )
-        # Define the generation configuration
-        generate_kwargs["streamer"] = streamer
-        # Start the model chat in a separate thread
-        thread = Thread(
-            target=self._model.generate,
-            kwargs=generate_kwargs,
-        )
-        thread.start()
-
-        completion_id = str(uuid.uuid1())
-        prompt_tokens = len(input_ids[0])
-        total_tokens, completion_tokens = 0, 0
-        # Loop through the streamer to get the new text as it is generated
-        for i, new_text in enumerate(streamer):
-            completion_tokens = i
-            total_tokens = prompt_tokens + completion_tokens
-            yield generate_completion_chunk(
-                chunk_text=new_text,
-                finish_reason=None,
-                chunk_id=completion_id,
-                model_uid=self.model_uid,
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-            )
-        yield generate_completion_chunk(
-            chunk_text=None,
-            finish_reason="stop",
-            chunk_id=completion_id,
-            model_uid=self.model_uid,
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-            has_choice=True,
-            has_content=False,
-        )
-

 class DeepSeekV2PytorchChatModel(PytorchChatModel):
     def _load_model(self, **kwargs):
@@ -215,126 +112,3 @@ class DeepSeekV2PytorchChatModel(PytorchChatModel):
         if "chat" not in llm_family.model_ability:
             return False
         return True
-
-    def chat(
-        self,
-        messages: List[Dict],
-        generate_config: Optional[PytorchGenerateConfig] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        assert self.model_family.chat_template is not None
-        full_prompt = self.get_full_context(
-            messages,
-            self.model_family.chat_template,
-            tokenizer=self._tokenizer,
-        )
-        input_tensor = self._tokenizer.encode(
-            full_prompt,
-            padding=False,
-            truncation=False,
-            max_length=None,
-            add_special_tokens=False,
-            return_tensors="pt",
-        )
-
-        generate_config = self._sanitize_generate_config(generate_config)
-        default_generate_config = self._model.generation_config
-        generate_kwargs = {
-            "input_ids": input_tensor.cuda(),
-            "temperature": float(
-                generate_config.get("temperature", default_generate_config.temperature)
-            ),
-            "repetition_penalty": float(generate_config.get("repetition_penalty", 1.0)),
-            "top_p": float(generate_config.get("top_p", default_generate_config.top_p)),
-            "top_k": int(generate_config.get("top_k", -1)),
-            "max_new_tokens": generate_config.get("max_tokens", 512),
-            "bos_token_id": default_generate_config.bos_token_id,
-            "do_sample": default_generate_config.do_sample,
-            "eos_token_id": default_generate_config.eos_token_id,
-        }
-
-        stream = generate_config.get("stream", False)
-        stream_options = generate_config.get("stream_options", None)
-        include_usage = (
-            stream_options["include_usage"]
-            if isinstance(stream_options, dict)
-            else False
-        )
-        if stream:
-            chunk = self._generate_stream(generate_kwargs, input_tensor, include_usage)
-            return self._to_chat_completion_chunks(chunk)
-        else:
-            return self._generate(generate_kwargs, input_tensor)
-
-    def _generate(self, generate_kwargs, input_ids) -> ChatCompletion:
-        prompt_tokens = len(input_ids[0])
-        generation_output = self._model.generate(**generate_kwargs)
-        completion_tokens = len(generation_output[0])
-        response = self._tokenizer.decode(
-            generation_output[0][input_ids.shape[1] :], skip_special_tokens=True
-        )
-        return generate_chat_completion(
-            self.model_uid,
-            response,
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        )
-
-    def _generate_stream(self, generate_kwargs, input_ids, include_usage):
-        from threading import Thread
-
-        from transformers import TextIteratorStreamer
-
-        # Initialize the streamer
-        streamer = TextIteratorStreamer(
-            self._tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
-        )
-        # Define the generation configuration
-        generate_kwargs["streamer"] = streamer
-        # Start the model chat in a separate thread
-        thread = Thread(
-            target=self._model.generate,
-            kwargs=generate_kwargs,
-        )
-        thread.start()
-
-        completion_id = str(uuid.uuid1())
-        prompt_tokens = len(input_ids[0])
-        total_tokens, completion_tokens = 0, 0
-        # Loop through the streamer to get the new text as it is generated
-        for i, new_text in enumerate(streamer):
-            completion_tokens = max(completion_tokens, len(streamer.token_cache))
-            total_tokens = prompt_tokens + completion_tokens
-            yield generate_completion_chunk(
-                chunk_text=new_text,
-                finish_reason=None,
-                chunk_id=completion_id,
-                model_uid=self.model_uid,
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-            )
-        yield generate_completion_chunk(
-            chunk_text=None,
-            finish_reason="stop",
-            chunk_id=completion_id,
-            model_uid=self.model_uid,
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-            has_choice=True,
-            has_content=False,
-        )
-
-        if include_usage:
-            yield generate_completion_chunk(
-                chunk_text=None,
-                finish_reason=None,
-                chunk_id=completion_id,
-                model_uid=self.model_uid,
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-                has_choice=False,
-                has_content=False,
-            )
@@ -28,6 +28,7 @@ from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from ..utils import generate_chat_completion, generate_completion_chunk
 from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import cache_clean

 logger = logging.getLogger(__name__)

@@ -137,6 +138,7 @@ class DeepSeekVLChatModel(PytorchChatModel):
             return "".join(new_content), images
         return content, []

+    @cache_clean
     def chat(
         self,
         messages: List[Dict],
@@ -26,7 +26,7 @@ from ...utils import select_device
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from ..utils import _decode_image, generate_chat_completion, generate_completion_chunk
 from .core import PytorchChatModel, PytorchGenerateConfig
-from .utils import get_max_src_len
+from .utils import cache_clean, get_max_src_len

 logger = logging.getLogger(__name__)

@@ -129,6 +129,7 @@ class Glm4VModel(PytorchChatModel):
             res.append({"role": role, "content": text})
         return res

+    @cache_clean
     def chat(
         self,
         messages: List[Dict],
@@ -27,6 +27,7 @@ from ..utils import (
     parse_messages,
 )
 from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import cache_clean

 logger = logging.getLogger(__name__)

@@ -326,6 +327,7 @@ class InternVLChatModel(PytorchChatModel):
             use_fast=False,
         )

+    @cache_clean
     def chat(
         self,
         messages: List[Dict],
@@ -11,13 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import uuid
-from typing import Any, Dict, Iterator, List, Optional, Union
+
+from typing import List, Optional

 from ....core.scheduler import InferenceRequest
-from ....types import ChatCompletion, ChatCompletionChunk, LoRA, PytorchGenerateConfig
+from ....types import LoRA
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import generate_chat_completion, generate_completion_chunk, parse_messages
 from .core import PytorchChatModel, PytorchModelConfig


@@ -93,94 +92,3 @@ class Internlm2PytorchChatModel(PytorchChatModel):
         if top_p is None:
             raw_config["top_p"] = 0.8
         return raw_config
-
-    def chat(
-        self,
-        messages: List[Dict],
-        generate_config: Optional[PytorchGenerateConfig] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        kwargs: Dict[str, Any] = {}
-        generate_config = generate_config or {}
-        temperature = generate_config.get("temperature")
-        if temperature is not None:
-            kwargs["temperature"] = float(temperature)
-        top_p = generate_config.get("top_p")
-        if top_p is not None:
-            kwargs["top_p"] = float(top_p)
-        max_new_tokens = generate_config.get("max_tokens")
-        if max_new_tokens is not None:
-            kwargs["max_length"] = int(max_new_tokens)
-
-        stream = generate_config.get("stream", False)
-        stream_options = generate_config.pop("stream_options", None)
-        include_usage = (
-            stream_options["include_usage"]
-            if isinstance(stream_options, dict)
-            else False
-        )
-
-        prompt, system_prompt, chat_history = parse_messages(messages)
-        if chat_history:
-            input_history = [
-                (chat_history[i]["content"], (chat_history[i + 1]["content"]))
-                for i in range(0, len(chat_history), 2)
-            ]
-        else:
-            input_history = []
-        if system_prompt:
-            kwargs["meta_instruction"] = system_prompt
-        if stream:
-
-            def _stream_generator():
-                last_chunk_text_length = 0
-                chunk_id = "chat-" + str(uuid.uuid1())
-                prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
-                inputs = self._tokenizer([prompt], return_tensors="pt")
-                inputs = inputs.to(self._model.device)
-                prompt_tokens = len(inputs["input_ids"][0])
-                for chunk_text, _ in self._model.stream_chat(
-                    self._tokenizer, prompt, input_history, **kwargs
-                ):
-                    completion_tokens = completion_tokens + 1
-                    total_tokens = prompt_tokens + completion_tokens
-                    chunk_text = chunk_text[last_chunk_text_length:]
-                    last_chunk_text_length += len(chunk_text)
-
-                    yield generate_completion_chunk(
-                        chunk_text,
-                        finish_reason=None,
-                        chunk_id=chunk_id,
-                        model_uid=self.model_uid,
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=total_tokens,
-                    )
-                yield generate_completion_chunk(
-                    None,
-                    finish_reason="stop",
-                    chunk_id=chunk_id,
-                    model_uid=self.model_uid,
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=total_tokens,
-                    has_choice=True,
-                    has_content=False,
-                )
-                if include_usage:
-                    yield generate_completion_chunk(
-                        None,
-                        finish_reason=None,
-                        chunk_id=chunk_id,
-                        model_uid=self.model_uid,
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=total_tokens,
-                        has_choice=False,
-                    )
-
-            return self._to_chat_completion_chunks(_stream_generator())
-        else:
-            response, _ = self._model.chat(
-                self._tokenizer, prompt, input_history, **kwargs
-            )
-            return generate_chat_completion(self.model_uid, response)
@@ -29,6 +29,7 @@ from ..utils import (
     parse_messages,
 )
 from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import cache_clean

 logger = logging.getLogger(__name__)

@@ -119,6 +120,7 @@ class MiniCPMV25Model(PytorchChatModel):
                 raise RuntimeError("Only one image per message is supported")
         return content, []

+    @cache_clean
     def chat(
         self,
         messages: List[Dict],
@@ -30,6 +30,7 @@ from ..utils import (
     parse_messages,
 )
 from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import cache_clean

 logger = logging.getLogger(__name__)

@@ -198,6 +199,7 @@ class MiniCPMV26Model(PytorchChatModel):
         msgs.append({"role": "user", "content": images_chat + [content]})
         return msgs, video_existed

+    @cache_clean
     def chat(
         self,
         messages: List[Dict],
@@ -24,6 +24,7 @@ from ...utils import select_device
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from ..utils import generate_chat_completion, parse_messages
 from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import cache_clean

 logger = logging.getLogger(__name__)

@@ -87,6 +88,7 @@ class OmniLMMModel(PytorchChatModel):
             return images, other_content
         return [], [{"type": "text", "text": content}]

+    @cache_clean
     def chat(
         self,
         messages: List[Dict],
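Several of the multimodal chat models above gain a `@cache_clean` decorator on `chat`, imported from `.utils`. Its implementation is not shown in this diff; purely as a hypothetical sketch of what a decorator with this name commonly does (an assumption, not xinference's actual code), it could release cached accelerator memory after each call:

```python
# Hypothetical sketch of a cache-clean decorator; NOT the xinference implementation.
# It runs the wrapped method and then frees cached accelerator memory so that
# long-lived multimodal models do not hold on to allocator caches between requests.
import functools
import gc

import torch


def cache_clean(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        finally:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return wrapper
```

For streaming chat calls that return an iterator, the cleanup in a wrapper like this would run before the stream is consumed, so the real decorator may handle that case differently.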