xinference 0.15.3__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff shows the content changes between the two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.

Files changed (65)
  1. xinference/__init__.py +0 -4
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +29 -2
  4. xinference/client/restful/restful_client.py +10 -0
  5. xinference/constants.py +7 -3
  6. xinference/core/image_interface.py +76 -23
  7. xinference/core/model.py +158 -46
  8. xinference/core/progress_tracker.py +187 -0
  9. xinference/core/scheduler.py +10 -7
  10. xinference/core/supervisor.py +11 -0
  11. xinference/core/utils.py +9 -0
  12. xinference/core/worker.py +1 -0
  13. xinference/deploy/supervisor.py +4 -0
  14. xinference/model/__init__.py +4 -0
  15. xinference/model/audio/chattts.py +2 -1
  16. xinference/model/audio/core.py +0 -2
  17. xinference/model/audio/model_spec.json +8 -0
  18. xinference/model/audio/model_spec_modelscope.json +9 -0
  19. xinference/model/image/core.py +6 -7
  20. xinference/model/image/scheduler/__init__.py +13 -0
  21. xinference/model/image/scheduler/flux.py +533 -0
  22. xinference/model/image/sdapi.py +35 -4
  23. xinference/model/image/stable_diffusion/core.py +215 -110
  24. xinference/model/image/utils.py +39 -3
  25. xinference/model/llm/__init__.py +2 -0
  26. xinference/model/llm/llm_family.json +185 -17
  27. xinference/model/llm/llm_family_modelscope.json +124 -12
  28. xinference/model/llm/transformers/chatglm.py +104 -0
  29. xinference/model/llm/transformers/cogvlm2.py +2 -1
  30. xinference/model/llm/transformers/cogvlm2_video.py +2 -0
  31. xinference/model/llm/transformers/core.py +43 -113
  32. xinference/model/llm/transformers/deepseek_v2.py +0 -226
  33. xinference/model/llm/transformers/deepseek_vl.py +2 -0
  34. xinference/model/llm/transformers/glm4v.py +2 -1
  35. xinference/model/llm/transformers/intern_vl.py +2 -0
  36. xinference/model/llm/transformers/internlm2.py +3 -95
  37. xinference/model/llm/transformers/minicpmv25.py +2 -0
  38. xinference/model/llm/transformers/minicpmv26.py +2 -0
  39. xinference/model/llm/transformers/omnilmm.py +2 -0
  40. xinference/model/llm/transformers/opt.py +68 -0
  41. xinference/model/llm/transformers/qwen2_audio.py +11 -4
  42. xinference/model/llm/transformers/qwen2_vl.py +2 -28
  43. xinference/model/llm/transformers/qwen_vl.py +2 -1
  44. xinference/model/llm/transformers/utils.py +36 -283
  45. xinference/model/llm/transformers/yi_vl.py +2 -0
  46. xinference/model/llm/utils.py +60 -16
  47. xinference/model/llm/vllm/core.py +68 -9
  48. xinference/model/llm/vllm/utils.py +0 -1
  49. xinference/model/utils.py +7 -4
  50. xinference/model/video/core.py +0 -2
  51. xinference/utils.py +2 -3
  52. xinference/web/ui/build/asset-manifest.json +3 -3
  53. xinference/web/ui/build/index.html +1 -1
  54. xinference/web/ui/build/static/js/{main.e51a356d.js → main.f7da0140.js} +3 -3
  55. xinference/web/ui/build/static/js/main.f7da0140.js.map +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
  57. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/METADATA +38 -6
  58. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/RECORD +63 -59
  59. xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
  61. /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.f7da0140.js.LICENSE.txt} +0 -0
  62. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/LICENSE +0 -0
  63. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/WHEEL +0 -0
  64. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/entry_points.txt +0 -0
  65. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/top_level.txt +0 -0

xinference/model/llm/transformers/opt.py (new file)

@@ -0,0 +1,68 @@
+ # Copyright 2022-2024 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from builtins import classmethod
+ from typing import List, Optional
+
+ from ....core.scheduler import InferenceRequest
+ from ....types import LoRA
+ from ..llm_family import LLMFamilyV1, LLMSpecV1
+ from .core import PytorchModel, PytorchModelConfig
+
+
+ class OptPytorchModel(PytorchModel):
+     def __init__(
+         self,
+         model_uid: str,
+         model_family: "LLMFamilyV1",
+         model_spec: "LLMSpecV1",
+         quantization: str,
+         model_path: str,
+         pytorch_model_config: Optional[PytorchModelConfig] = None,
+         peft_model: Optional[List[LoRA]] = None,
+     ):
+         super().__init__(
+             model_uid,
+             model_family,
+             model_spec,
+             quantization,
+             model_path,
+             pytorch_model_config=pytorch_model_config,
+             peft_model=peft_model,
+         )
+
+     @classmethod
+     def match(
+         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+     ) -> bool:
+         if llm_spec.model_format != "pytorch":
+             return False
+         model_family = llm_family.model_family or llm_family.model_name
+         if model_family != "opt":
+             return False
+         return True
+
+     def build_prefill_position_ids(
+         self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+     ):
+         """
+         Mainly for UT.
+         Transformers code in `main` branch supports `position_ids` parameter (https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L1076),
+         while in release branch, it doesn't (https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/opt/modeling_opt.py#L886).
+         """
+         return None
+
+     def build_decode_position_ids(
+         self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+     ):
+         return None
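
The new `OptPytorchModel.match` classmethod only claims specs whose `model_format` is "pytorch" and whose family resolves to "opt", which presumably is how the transformers backend routes OPT checkpoints to this class. A minimal, self-contained sketch of that selection rule follows; `FakeFamily` and `FakeSpec` are hypothetical stand-ins that mirror only the two attributes the method reads, not xinference's real `LLMFamilyV1`/`LLMSpecV1` types.

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeSpec:  # hypothetical stand-in for LLMSpecV1
    model_format: str


@dataclass
class FakeFamily:  # hypothetical stand-in for LLMFamilyV1
    model_name: str
    model_family: Optional[str] = None


def match_opt(family: FakeFamily, spec: FakeSpec) -> bool:
    """Mirrors the match() rule from the hunk above."""
    if spec.model_format != "pytorch":
        return False
    resolved = family.model_family or family.model_name
    return resolved == "opt"


print(match_opt(FakeFamily("opt"), FakeSpec("pytorch")))      # True
print(match_opt(FakeFamily("opt"), FakeSpec("ggufv2")))       # False: not a pytorch checkpoint
print(match_opt(FakeFamily("llama-2"), FakeSpec("pytorch")))  # False: wrong model family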

xinference/model/llm/transformers/qwen2_audio.py

@@ -14,16 +14,22 @@
  import logging
  import uuid
  from io import BytesIO
- from typing import Dict, Iterator, List, Optional, Union
+ from typing import Iterator, List, Optional, Union
  from urllib.request import urlopen

  import numpy as np

  from ....model.utils import select_device
- from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
+ from ....types import (
+     ChatCompletion,
+     ChatCompletionChunk,
+     ChatCompletionMessage,
+     CompletionChunk,
+ )
  from ..llm_family import LLMFamilyV1, LLMSpecV1
  from ..utils import generate_chat_completion, generate_completion_chunk
  from .core import PytorchChatModel, PytorchGenerateConfig
+ from .utils import cache_clean

  logger = logging.getLogger(__name__)


@@ -68,7 +74,7 @@ class Qwen2AudioChatModel(PytorchChatModel):

      def _transform_messages(
          self,
-         messages: List[Dict],
+         messages: List[ChatCompletionMessage],
      ):
          import librosa


@@ -89,9 +95,10 @@ class Qwen2AudioChatModel(PytorchChatModel):

          return text, audios

+     @cache_clean
      def chat(
          self,
-         messages: List[Dict],
+         messages: List[ChatCompletionMessage],
          generate_config: Optional[PytorchGenerateConfig] = None,
      ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
          text, audios = self._transform_messages(messages)

xinference/model/llm/transformers/qwen2_vl.py

@@ -27,6 +27,7 @@ from ....types (
  from ..llm_family import LLMFamilyV1, LLMSpecV1
  from ..utils import generate_chat_completion, generate_completion_chunk
  from .core import PytorchChatModel, PytorchGenerateConfig
+ from .utils import cache_clean

  logger = logging.getLogger(__name__)


@@ -75,34 +76,7 @@ class Qwen2VLChatModel(PytorchChatModel):
              self.model_path, device_map=device, trust_remote_code=True
          ).eval()

-     def _transform_messages(
-         self,
-         messages: List[ChatCompletionMessage],
-     ):
-         transformed_messages = []
-         for msg in messages:
-             new_content = []
-             role = msg["role"]
-             content = msg["content"]
-             if isinstance(content, str):
-                 new_content.append({"type": "text", "text": content})
-             elif isinstance(content, List):
-                 for item in content:  # type: ignore
-                     if "text" in item:
-                         new_content.append({"type": "text", "text": item["text"]})
-                     elif "image_url" in item:
-                         new_content.append(
-                             {"type": "image", "image": item["image_url"]["url"]}
-                         )
-                     elif "video_url" in item:
-                         new_content.append(
-                             {"type": "video", "video": item["video_url"]["url"]}
-                         )
-             new_message = {"role": role, "content": new_content}
-             transformed_messages.append(new_message)
-
-         return transformed_messages
-
+     @cache_clean
      def chat(
          self,
          messages: List[ChatCompletionMessage],  # type: ignore

xinference/model/llm/transformers/qwen_vl.py

@@ -28,7 +28,7 @@ from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
  from ..llm_family import LLMFamilyV1, LLMSpecV1
  from ..utils import generate_chat_completion, generate_completion_chunk
  from .core import PytorchChatModel, PytorchGenerateConfig
- from .utils import pad_prefill_tokens
+ from .utils import cache_clean, pad_prefill_tokens

  logger = logging.getLogger(__name__)


@@ -137,6 +137,7 @@ class QwenVLChatModel(PytorchChatModel):
          prompt = self._message_content_to_qwen(messages[-1]["content"])
          return prompt, qwen_history

+     @cache_clean
      def chat(
          self,
          messages: List[Dict],

xinference/model/llm/transformers/utils.py

@@ -12,12 +12,12 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- import gc
+ import asyncio
+ import functools
  import logging
  import os
  import time
- import uuid
- from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Tuple
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

  import torch
  from transformers.cache_utils import DynamicCache

@@ -45,20 +45,6 @@ if TYPE_CHECKING:
  logger = logging.getLogger(__name__)


- def is_sentence_complete(output: str):
-     """Check whether the output is a complete sentence."""
-     end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”")
-     return output.endswith(end_symbols)
-
-
- def is_partial_stop(output: str, stop_str: str):
-     """Check whether the output contains a partial stop str."""
-     for i in range(0, min(len(output), len(stop_str))):
-         if stop_str.startswith(output[-i:]):
-             return True
-     return False
-
-
  def get_context_length(config) -> int:
      """Get the context length of a model from a huggingface model config."""
      if (

@@ -98,272 +84,6 @@ def prepare_logits_processor(
      return processor_list


- @torch.inference_mode()
- def generate_stream(
-     model_uid,
-     model,
-     tokenizer,
-     prompt,
-     device,
-     generate_config,
-     judge_sent_end=False,
- ) -> Iterator[Tuple[CompletionChunk, CompletionUsage]]:
-     context_len = get_context_length(model.config)
-     stream_interval = generate_config.get("stream_interval", 2)
-     stream = generate_config.get("stream", False)
-     stream_options = generate_config.pop("stream_options", None)
-     include_usage = (
-         stream_options["include_usage"] if isinstance(stream_options, dict) else False
-     )
-
-     len_prompt = len(prompt)
-
-     temperature = float(generate_config.get("temperature", 1.0))
-     repetition_penalty = float(generate_config.get("repetition_penalty", 1.0))
-     top_p = float(generate_config.get("top_p", 1.0))
-     top_k = int(generate_config.get("top_k", -1))  # -1 means disable
-     max_new_tokens = int(generate_config.get("max_tokens", max_tokens_field.default))
-     echo = bool(generate_config.get("echo", False))
-     stop_str = generate_config.get("stop", None)
-     stop_token_ids = generate_config.get("stop_token_ids", None) or []
-     stop_token_ids.append(tokenizer.eos_token_id)
-     chunk_id = str(uuid.uuid4())
-
-     logits_processor = prepare_logits_processor(
-         temperature, repetition_penalty, top_p, top_k
-     )
-
-     if ".modeling_qwen." in str(type(model)).lower():
-         # TODO: hacky
-         input_ids = tokenizer(prompt, allowed_special="all").input_ids
-     else:
-         input_ids = tokenizer(prompt).input_ids
-     output_ids = list(input_ids)
-
-     if model.config.is_encoder_decoder:
-         max_src_len = context_len
-     else:
-         max_src_len = context_len - max_new_tokens - 8
-         if max_src_len < 0:
-             raise ValueError("Max tokens exceeds model's max length")
-
-     input_ids = input_ids[-max_src_len:]
-     input_echo_len = len(input_ids)
-
-     if model.config.is_encoder_decoder:
-         encoder_output = model.encoder(
-             input_ids=torch.as_tensor([input_ids], device=device)
-         )[0]
-         start_ids = torch.as_tensor(
-             [[model.generation_config.decoder_start_token_id]],
-             dtype=torch.int64,
-             device=device,
-         )
-
-     start = time.time()
-     past_key_values = out = None
-     sent_interrupt = False
-     token = None
-     last_output_length = 0
-     for i in range(max_new_tokens):
-         if i == 0:
-             if model.config.is_encoder_decoder:
-                 out = model.decoder(
-                     input_ids=start_ids,
-                     encoder_hidden_states=encoder_output,
-                     use_cache=True,
-                 )
-                 logits = model.lm_head(out[0])
-             else:
-                 out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
-                 logits = out.logits
-             past_key_values = out.past_key_values
-         else:
-             if model.config.is_encoder_decoder:
-                 out = model.decoder(
-                     input_ids=torch.as_tensor(
-                         [[token] if not sent_interrupt else output_ids], device=device
-                     ),
-                     encoder_hidden_states=encoder_output,
-                     use_cache=True,
-                     past_key_values=past_key_values if not sent_interrupt else None,
-                 )
-                 sent_interrupt = False
-
-                 logits = model.lm_head(out[0])
-             else:
-                 out = model(
-                     input_ids=torch.as_tensor(
-                         [[token] if not sent_interrupt else output_ids], device=device
-                     ),
-                     use_cache=True,
-                     past_key_values=past_key_values if not sent_interrupt else None,
-                 )
-                 sent_interrupt = False
-                 logits = out.logits
-             past_key_values = out.past_key_values
-
-         if logits_processor:
-             if repetition_penalty > 1.0:
-                 tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
-             else:
-                 tmp_output_ids = None
-             last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
-         else:
-             last_token_logits = logits[0, -1, :]
-
-         if device == "mps":
-             # Switch to CPU by avoiding some bugs in mps backend.
-             last_token_logits = last_token_logits.float().to("cpu")
-
-         if temperature < 1e-5 or top_p < 1e-8:  # greedy
-             _, indices = torch.topk(last_token_logits, 2)
-             tokens = [int(index) for index in indices.tolist()]
-         else:
-             probs = torch.softmax(last_token_logits, dim=-1)
-             indices = torch.multinomial(probs, num_samples=2)
-             tokens = [int(token) for token in indices.tolist()]
-         token = tokens[0]
-         output_ids.append(token)
-
-         if token in stop_token_ids:
-             stopped = True
-         else:
-             stopped = False
-
-         if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
-             if echo:
-                 tmp_output_ids = output_ids
-                 rfind_start = len_prompt
-             else:
-                 tmp_output_ids = output_ids[input_echo_len:]
-                 rfind_start = 0
-
-             output = tokenizer.decode(
-                 tmp_output_ids,
-                 skip_special_tokens=True,
-                 spaces_between_special_tokens=False,
-                 clean_up_tokenization_spaces=True,
-             )
-
-             # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
-             if judge_sent_end and stopped and not is_sentence_complete(output):
-                 if len(tokens) > 1:
-                     token = tokens[1]
-                     output_ids[-1] = token
-                 else:
-                     output_ids.pop()
-                 stopped = False
-                 sent_interrupt = True
-
-             partially_stopped = False
-             if stop_str:
-                 if isinstance(stop_str, str):
-                     pos = output.rfind(stop_str, rfind_start)
-                     if pos != -1:
-                         output = output[:pos]
-                         stopped = True
-                     else:
-                         partially_stopped = is_partial_stop(output, stop_str)
-                 elif isinstance(stop_str, Iterable):
-                     for each_stop in stop_str:
-                         pos = output.rfind(each_stop, rfind_start)
-                         if pos != -1:
-                             output = output[:pos]
-                             stopped = True
-                             break
-                         else:
-                             partially_stopped = is_partial_stop(output, each_stop)
-                             if partially_stopped:
-                                 break
-                 else:
-                     raise ValueError("Invalid stop field type.")
-
-             if stream:
-                 output = output.strip("�")
-                 tmp_output_length = len(output)
-                 output = output[last_output_length:]
-                 last_output_length = tmp_output_length
-
-             # prevent yielding partial stop sequence
-             if not partially_stopped:
-                 completion_choice = CompletionChoice(
-                     text=output, index=0, logprobs=None, finish_reason=None
-                 )
-                 completion_chunk = CompletionChunk(
-                     id=chunk_id,
-                     object="text_completion",
-                     created=int(time.time()),
-                     model=model_uid,
-                     choices=[completion_choice],
-                 )
-                 completion_usage = CompletionUsage(
-                     prompt_tokens=input_echo_len,
-                     completion_tokens=i,
-                     total_tokens=(input_echo_len + i),
-                 )
-
-                 yield completion_chunk, completion_usage
-
-         if stopped:
-             break
-
-     elapsed_time = time.time() - start
-     logger.info(f"Average generation speed: {i / elapsed_time:.2f} tokens/s.")
-
-     # finish stream event, which contains finish reason
-     if stopped:
-         finish_reason = "stop"
-     elif i == max_new_tokens - 1:
-         finish_reason = "length"
-     else:
-         finish_reason = None
-
-     if stream:
-         completion_choice = CompletionChoice(
-             text=output, index=0, logprobs=None, finish_reason=finish_reason
-         )
-     else:
-         completion_choice = CompletionChoice(
-             text=output, index=0, logprobs=None, finish_reason=finish_reason
-         )
-
-     completion_chunk = CompletionChunk(
-         id=chunk_id,
-         object="text_completion",
-         created=int(time.time()),
-         model=model_uid,
-         choices=[completion_choice],
-     )
-     completion_usage = CompletionUsage(
-         prompt_tokens=input_echo_len,
-         completion_tokens=i,
-         total_tokens=(input_echo_len + i),
-     )
-
-     yield completion_chunk, completion_usage
-
-     if include_usage:
-         completion_chunk = CompletionChunk(
-             id=chunk_id,
-             object="text_completion",
-             created=int(time.time()),
-             model=model_uid,
-             choices=[],
-         )
-         completion_usage = CompletionUsage(
-             prompt_tokens=input_echo_len,
-             completion_tokens=i,
-             total_tokens=(input_echo_len + i),
-         )
-         yield completion_chunk, completion_usage
-
-     # clean
-     del past_key_values, out
-     gc.collect()
-     empty_cache()
-
-
  def _get_token_from_logits(
      req: InferenceRequest, i: int, logits, temperature, repetition_penalty, top_p, top_k
  ):

@@ -678,6 +398,7 @@ def _batch_inference_one_step_internal(
  output = output.strip("�")
  output = output[r.last_output_length :]
  r.last_output_length += len(output)
+ r.outputs.append(output)

  completion_chunk = generate_completion_chunk(
      chunk_text=output,

@@ -702,6 +423,7 @@
  )
  r.completion.append(completion_chunk)
  r.completion.append(eos_flag)
+ r.outputs.append(eos_flag)

  # last round, handle stream result
  # append usage information when enable `include_usage` for OPENAI API compatibility

@@ -776,3 +498,34 @@ def batch_inference_one_step(
  for r in req_list:
      r.stopped = True
      r.error_msg = str(e)
+
+
+ def cache_clean(fn):
+     @functools.wraps(fn)
+     async def _async_wrapper(self, *args, **kwargs):
+         import gc
+
+         from ....device_utils import empty_cache
+
+         result = await fn(self, *args, **kwargs)
+
+         gc.collect()
+         empty_cache()
+         return result
+
+     @functools.wraps(fn)
+     def _wrapper(self, *args, **kwargs):
+         import gc
+
+         from ....device_utils import empty_cache
+
+         result = fn(self, *args, **kwargs)
+
+         gc.collect()
+         empty_cache()
+         return result
+
+     if asyncio.iscoroutinefunction(fn):
+         return _async_wrapper
+     else:
+         return _wrapper
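
The `cache_clean` helper added above checks `asyncio.iscoroutinefunction(fn)` so a single decorator can wrap both synchronous and asynchronous `chat` methods, running `gc.collect()` and a device-cache flush after every call; the remaining hunks in this release simply apply it to the various chat models. A minimal, self-contained sketch of the same pattern follows; `free_device_cache` is a hypothetical stand-in for xinference's `empty_cache` helper, not part of its API.

import asyncio
import functools
import gc


def free_device_cache() -> None:
    # Hypothetical stand-in for xinference's device_utils.empty_cache().
    pass


def cache_clean(fn):
    # Dispatch once at decoration time: async callables get an async wrapper.
    if asyncio.iscoroutinefunction(fn):

        @functools.wraps(fn)
        async def _async_wrapper(self, *args, **kwargs):
            result = await fn(self, *args, **kwargs)
            gc.collect()          # drop Python-level references first
            free_device_cache()   # then release cached device memory
            return result

        return _async_wrapper

    @functools.wraps(fn)
    def _wrapper(self, *args, **kwargs):
        result = fn(self, *args, **kwargs)
        gc.collect()
        free_device_cache()
        return result

    return _wrapper


class DemoChatModel:
    @cache_clean
    def chat(self, messages):
        # Stand-in for a real model call.
        return {"role": "assistant", "content": f"echo: {messages[-1]['content']}"}


print(DemoChatModel().chat([{"role": "user", "content": "hi"}]))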

xinference/model/llm/transformers/yi_vl.py

@@ -29,6 +29,7 @@ from ..utils import (
      parse_messages,
  )
  from .core import PytorchChatModel, PytorchGenerateConfig
+ from .utils import cache_clean

  logger = logging.getLogger(__name__)


@@ -99,6 +100,7 @@ class YiVLChatModel(PytorchChatModel):
              raise RuntimeError("Only one image per message is supported by Yi VL.")
          return content

+     @cache_clean
      def chat(
          self,
          messages: List[Dict],