xinference 0.11.0__py3-none-any.whl → 0.11.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (56)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +30 -0
  3. xinference/client/restful/restful_client.py +29 -0
  4. xinference/core/cache_tracker.py +12 -1
  5. xinference/core/chat_interface.py +10 -4
  6. xinference/core/model.py +2 -2
  7. xinference/core/supervisor.py +30 -2
  8. xinference/core/utils.py +12 -0
  9. xinference/core/worker.py +4 -1
  10. xinference/deploy/cmdline.py +126 -0
  11. xinference/deploy/test/test_cmdline.py +24 -0
  12. xinference/fields.py +3 -1
  13. xinference/model/llm/__init__.py +2 -0
  14. xinference/model/llm/ggml/chatglm.py +98 -13
  15. xinference/model/llm/ggml/llamacpp.py +49 -2
  16. xinference/model/llm/llm_family.json +633 -9
  17. xinference/model/llm/llm_family.py +84 -10
  18. xinference/model/llm/llm_family_modelscope.json +337 -10
  19. xinference/model/llm/memory.py +332 -0
  20. xinference/model/llm/pytorch/chatglm.py +48 -0
  21. xinference/model/llm/pytorch/core.py +25 -6
  22. xinference/model/llm/pytorch/deepseek_vl.py +35 -9
  23. xinference/model/llm/pytorch/intern_vl.py +387 -0
  24. xinference/model/llm/pytorch/internlm2.py +32 -1
  25. xinference/model/llm/pytorch/qwen_vl.py +38 -11
  26. xinference/model/llm/pytorch/utils.py +38 -1
  27. xinference/model/llm/pytorch/yi_vl.py +42 -14
  28. xinference/model/llm/sglang/core.py +31 -9
  29. xinference/model/llm/utils.py +38 -5
  30. xinference/model/llm/vllm/core.py +87 -5
  31. xinference/model/rerank/core.py +23 -1
  32. xinference/model/utils.py +17 -7
  33. xinference/thirdparty/deepseek_vl/models/processing_vlm.py +1 -1
  34. xinference/thirdparty/deepseek_vl/models/siglip_vit.py +2 -2
  35. xinference/thirdparty/llava/mm_utils.py +3 -2
  36. xinference/thirdparty/llava/model/llava_arch.py +1 -1
  37. xinference/thirdparty/omnilmm/chat.py +6 -5
  38. xinference/types.py +10 -1
  39. xinference/web/ui/build/asset-manifest.json +3 -3
  40. xinference/web/ui/build/index.html +1 -1
  41. xinference/web/ui/build/static/js/{main.8e44da4b.js → main.551aa479.js} +3 -3
  42. xinference/web/ui/build/static/js/main.551aa479.js.map +1 -0
  43. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +1 -0
  44. xinference/web/ui/node_modules/.cache/babel-loader/23caf6f1e52c43e983ca3bfd4189f41dbd645fa78f2dfdcd7f6b69bc41678665.json +1 -0
  45. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +1 -0
  46. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/METADATA +10 -8
  47. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/RECORD +52 -50
  48. xinference/web/ui/build/static/js/main.8e44da4b.js.map +0 -1
  49. xinference/web/ui/node_modules/.cache/babel-loader/1870cd6f7054d04e049e363c0a85526584fe25519378609d2838e28d7492bbf1.json +0 -1
  50. xinference/web/ui/node_modules/.cache/babel-loader/5393569d846332075b93b55656716a34f50e0a8c970be789502d7e6c49755fd7.json +0 -1
  51. xinference/web/ui/node_modules/.cache/babel-loader/ddaec68b88e5eff792df1e39a4b4b8b737bfc832293c015660c3c69334e3cf5c.json +0 -1
  52. /xinference/web/ui/build/static/js/{main.8e44da4b.js.LICENSE.txt → main.551aa479.js.LICENSE.txt} +0 -0
  53. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/LICENSE +0 -0
  54. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/WHEEL +0 -0
  55. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/entry_points.txt +0 -0
  56. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/top_level.txt +0 -0
xinference/model/llm/memory.py (new file)
@@ -0,0 +1,332 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# NOTE:
+#
+# The algorithm is ported from https://github.com/RahulSChand/gpu_poor
+#
+# Improvement:
+#
+# The original JS code always sizes the KV cache as float32, whereas in most cases we run models with float16.
+#
+# Known Issues:
+#
+# * On vllm, some MHA models use less memory than calculated (qwen1.5-7B-chat-gptq-int4,
+#   qwen1.5-14B-chat-gptq-int4 with large activation_mem).
+#
+# * On vllm, the gemma-it-7B pytorch-format model uses more GPU memory than calculated.
+
+import json
+import math
+from dataclasses import dataclass
+from logging import getLogger
+from math import ceil
+from typing import Any, Optional, Union
+
+from .llm_family import convert_model_size_to_float
+
+logger = getLogger(__name__)
+
+
+@dataclass
+class ModelLayersInfo:
+    vocab_size: int
+    heads: int  # num_attention_heads, num_heads or n_head
+    hidden_dim: int  # hidden_size, d_model, or n_embd
+    inter_dim: int  # intermediate_size, n_inner or d_ff
+    num_layers: int  # num_layers, num_hidden_layers or n_layer
+
+
+@dataclass
+class ModelMemInfo:
+    """Memory required by model, unit in MB"""
+
+    model_mem: int
+    kv_cache_mem: int
+    activation_mem: int
+    overhead: int
+    total: int
+
+
+QUANT_NORMALIZE = {"int4": "4-bit", "int8": "8-bit", "4-bit": "4-bit", "8-bit": "8-bit"}
+
+GGML_MULTI_FACTOR_DICT = {
+    "q4_0": 18,
+    "q4_1": 20,
+    "q5_0": 22,
+    "q5_1": 24,
+    "q8_0": 34,
+    "q8_1": 40,
+}
+
+GGML_MULTI_FACTOR_DICT_64 = {
+    "q6_K": 54.0,
+    "q3": 26.0,
+    "q4": 38.0,
+    "q5": 46.0,
+}
+
+GGML_MULTI_FACTOR_DICT_COMBINE = {
+    "q3_K_L": [38.0, 26.0],
+    "q3_K_M": [46.0, 26.0],
+    "q4_K_S": [46.0, 38.0],
+    "q4_K_M": [54.0, 38.0],
+    "q5_K_M": [54.0, 46.0],
+    "q2_K": [26.0, 22.0],
+}
+
+
+# Return gpu memory in MB
+def estimate_llm_gpu_memory(
+    model_size_in_billions: Union[str, int],
+    quantization: Optional[str],
+    context_length: int,  # input+output
+    model_format: str,
+    model_name: Optional[str] = None,
+    kv_cache_dtype: int = 16,
+) -> Optional[ModelMemInfo]:
+    """
+    model_size_in_billions: must be str like 1_8 or 46_7, to match llm.
+    """
+    info = get_model_layers_info(
+        model_size_in_billions,
+        model_name,
+        model_format,
+        quantization,
+    )
+    if info is None:
+        return None
+    size_in_billions = convert_model_size_to_float(model_size_in_billions)
+    return estimate_llm_gpu_memory_details(
+        info,
+        size_in_billions,
+        quantization,
+        context_length,
+        model_format,
+        kv_cache_dtype,
+    )
+
+
+def estimate_llm_gpu_memory_details(
+    info: ModelLayersInfo,
+    size_in_billions: float,
+    quantization: Optional[str],
+    context_length: int,  # input+output
+    model_format: str,
+    kv_cache_dtype: int = 16,
+) -> ModelMemInfo:
+    """return model_mem, kv_cache, overhead, activation_mem"""
+    if kv_cache_dtype not in [8, 16, 32]:
+        raise ValueError(f"Invalid kv_cache_dtype {kv_cache_dtype}")
+    if kv_cache_dtype == 8:
+        kv_dtype_size = 1
+    elif kv_cache_dtype == 16:
+        kv_dtype_size = 2
+    else:
+        kv_dtype_size = 4
+    overhead = 650.0
+    if model_format == "ggmlv3":
+        assert quantization is not None and quantization != "none"
+        model_size_in_mb = _compute_model_size_ggml(info, quantization)
+        inference_mem = float(
+            context_length * kv_dtype_size * info.hidden_dim * info.num_layers
+        )
+        inference_mem = inference_mem / 1024.0 / 1024.0
+        activation_mem = _compute_inference_only_activation_memory(context_length, info)
+        overhead = overhead + context_length * 0.1
+    else:
+        if quantization is not None:
+            assert isinstance(quantization, str)
+            quantization = QUANT_NORMALIZE[quantization.lower()]
+            assert quantization is not None
+
+        model_size = size_in_billions * 1000000000.0
+        model_size_in_mb = _convert_to_mb_model_size(model_size, quantization)
+        # KV cache
+        inference_mem = float(
+            context_length * 2 * kv_dtype_size * info.hidden_dim * info.num_layers
+        )
+        inference_mem = inference_mem / 1024.0 / 1024.0
+        activation_mem = _compute_inference_only_activation_memory(context_length, info)
+
+    total_mem = ceil(inference_mem + model_size_in_mb + overhead + activation_mem)
+    return ModelMemInfo(
+        model_mem=ceil(model_size_in_mb),
+        kv_cache_mem=ceil(inference_mem),
+        activation_mem=ceil(activation_mem),
+        overhead=ceil(overhead),
+        total=total_mem,
+    )
+
+
+def _load_item_from_json(config_data: Any, *keys: str) -> str:
+    assert len(keys) > 0
+    for key in keys:
+        v = config_data.get(key)
+        if v is not None:
+            return v
+    raise ValueError("load ModelLayersInfo: missing %s" % (keys[0]))
+
+
+def load_model_config_json(config_path: str) -> ModelLayersInfo:
+    with open(config_path, "r") as f:
+        config_data = json.load(f)
+    return ModelLayersInfo(
+        vocab_size=int(_load_item_from_json(config_data, "vocab_size")),
+        heads=int(
+            _load_item_from_json(
+                config_data, "num_key_value_heads", "num_attention_heads"
+            )
+        ),
+        hidden_dim=int(
+            _load_item_from_json(config_data, "hidden_size", "d_model", "n_embd")
+        ),
+        inter_dim=int(_load_item_from_json(config_data, "intermediate_size")),
+        num_layers=int(
+            _load_item_from_json(
+                config_data, "num_hidden_layers", "num_layers", "n_layer"
+            )
+        ),
+    )
+
+
+def get_model_layers_info(
+    model_size_in_billions: Union[str, int],
+    model_name: Optional[str],
+    model_format: Optional[str],
+    quantization: Optional[str],
+) -> Optional[ModelLayersInfo]:
+    from . import match_llm
+    from .llm_family import cache_model_config
+
+    if not model_name:
+        logger.debug("get_model_layers_info by default size=%s", model_size_in_billions)
+        size_in_billions = convert_model_size_to_float(model_size_in_billions)
+        return _get_default_layers_from_size(size_in_billions)
+    match_result = match_llm(
+        model_name=model_name,
+        model_format=model_format,
+        model_size_in_billions=model_size_in_billions,
+        quantization=quantization,
+    )
+    if not match_result:
+        return None
+    llm_family, llm_spec, _quant = match_result
+    config_path = cache_model_config(llm_family, llm_spec)
+    return load_model_config_json(config_path)
+
+
+def _get_default_layers_from_size(size_in_billion: float) -> ModelLayersInfo:
+    if size_in_billion < 5:
+        vocab_size = 32000
+        heads = 32
+        num_layers = 24
+    elif size_in_billion < 10:
+        vocab_size = 32000
+        heads = 32
+        num_layers = 32
+    elif size_in_billion < 24:
+        vocab_size = 32000
+        heads = 40
+        num_layers = 40
+    elif size_in_billion < 55:
+        vocab_size = 32000
+        heads = 60
+        num_layers = 48
+    else:
+        vocab_size = 32000
+        heads = 64
+        num_layers = 80
+
+    model_size = int(size_in_billion * 1000000000)
+    A = num_layers * 4 + 3 * 4 * num_layers
+    B = 2 * vocab_size
+    C = -1 * model_size
+    h = (-B + math.sqrt(B**2 - 4 * A * C)) / (2 * A)
+    h = math.ceil(h)
+    return ModelLayersInfo(
+        vocab_size=vocab_size,
+        heads=heads,
+        hidden_dim=h,
+        inter_dim=4 * h,
+        num_layers=num_layers,
+    )
+
+
+def _convert_to_mb_model_size(model_size: float, quantization: Optional[str]) -> float:
+    extra = 0.0
+    fB = 2.0
+    size = (model_size * fB) / (1024.0 * 1024.0)
+    # bnb_q4 == 4-bit ?
+    if quantization == "8-bit" or quantization == "4-bit":
+        extra = 0.06 * size
+    if quantization == "8-bit":
+        size = size / 2
+    if quantization == "4-bit":
+        size = size / 4
+    return size + extra
+
+
+def _compute_inference_only_activation_memory(
+    context_length: int, info: ModelLayersInfo
+) -> float:
+    hidden_dim = info.hidden_dim
+    heads = info.heads
+    ret = (
+        (context_length * hidden_dim * 5 * 2 + (context_length**2) * heads * 2)
+        / 1024
+        / 1024
+    )
+    return ret
+
+
+def _compute_model_size_ggml(info: ModelLayersInfo, quantization: str) -> float:
+    assert quantization is not None
+    vocab_size = info.vocab_size
+    num_layers = info.num_layers
+    hidden_dim = info.hidden_dim
+    inter_dim = info.inter_dim
+    total_params = int(
+        vocab_size * hidden_dim * 2
+        + num_layers * 4 * (hidden_dim**2)
+        + num_layers * 3 * inter_dim * hidden_dim
+    )
+    other_v_down_params = (
+        num_layers * (hidden_dim**2) + num_layers * hidden_dim * inter_dim
+    )
+    other_param_q2k = (
+        total_params - (hidden_dim**2) * num_layers * 2 + 2 * vocab_size * hidden_dim
+    )
+
+    total = 0.0
+    v1 = GGML_MULTI_FACTOR_DICT.get(quantization)
+    if v1 is not None:
+        total = (v1 * total_params) / (32 * 1024 * 1024)
+    v2 = GGML_MULTI_FACTOR_DICT_64.get(quantization)
+    if v2 is not None:
+        total = (v2 * total_params) / (64 * 1024 * 1024)
+    v3 = GGML_MULTI_FACTOR_DICT_COMBINE.get(quantization)
+    if v3 is not None:
+        factors = v3
+        if quantization == "q2_K":
+            total = (
+                (total_params - other_param_q2k) * factors[1]
+                + other_param_q2k * factors[0]
+            ) / (64 * 1024 * 1024)
+        else:
+            total = (
+                (total_params - other_v_down_params) * factors[1]
+                + other_v_down_params * factors[0]
+            ) / (64 * 1024 * 1024)
+    return total
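The entry point of the new module is estimate_llm_gpu_memory. A minimal usage sketch, assuming the module is importable as xinference.model.llm.memory (its path in the file list above); the argument names follow the signature shown in the hunk, while the concrete values are only illustrative:

from xinference.model.llm.memory import estimate_llm_gpu_memory

# Estimate VRAM for a 7B pytorch-format model with 4-bit quantization and a
# 4096-token context (input + output), using the default float16 KV cache.
mem = estimate_llm_gpu_memory(
    model_size_in_billions=7,
    quantization="4-bit",
    context_length=4096,
    model_format="pytorch",
    model_name=None,  # no model name: fall back to the size-based layer heuristics
    kv_cache_dtype=16,
)
if mem is not None:
    # All fields are in MB: weights, KV cache, activations, fixed overhead, total.
    print(mem.model_mem, mem.kv_cache_mem, mem.activation_mem, mem.overhead, mem.total)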
xinference/model/llm/pytorch/chatglm.py
@@ -147,14 +147,26 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             )
         else:
             stream = generate_config.get("stream", False)
+            stream_options = generate_config.pop("stream_options", None)
+            include_usage = (
+                stream_options["include_usage"]
+                if isinstance(stream_options, dict)
+                else False
+            )
             if stream:
 
                 def _stream_generator():
                     last_chunk_text_length = 0
                     chunk_id = "chat-" + str(uuid.uuid1())
+                    prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+                    inputs = self._tokenizer([prompt], return_tensors="pt")
+                    inputs = inputs.to(self._model.device)
+                    prompt_tokens = len(inputs["input_ids"][0])
                     for chunk_text, _ in self._model.stream_chat(
                         self._tokenizer, prompt, chat_history, **kwargs
                     ):
+                        completion_tokens = completion_tokens + 1
+                        total_tokens = prompt_tokens + completion_tokens
                         chunk_text = chunk_text[last_chunk_text_length:]
                         last_chunk_text_length += len(chunk_text)
                         completion_choice = CompletionChoice(
@@ -166,7 +178,43 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                             created=int(time.time()),
                             model=self.model_uid,
                             choices=[completion_choice],
+                            usage=CompletionUsage(
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                                total_tokens=total_tokens,
+                            ),
+                        )
+                    completion_choice = CompletionChoice(
+                        text="", index=0, logprobs=None, finish_reason="stop"
+                    )
+                    chunk = CompletionChunk(
+                        id=chunk_id,
+                        object="text_completion",
+                        created=int(time.time()),
+                        model=self.model_uid,
+                        choices=[completion_choice],
+                    )
+                    completion_usage = CompletionUsage(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=total_tokens,
+                    )
+                    chunk["usage"] = completion_usage
+                    yield chunk
+                    if include_usage:
+                        chunk = CompletionChunk(
+                            id=chunk_id,
+                            object="text_completion",
+                            created=int(time.time()),
+                            model=self.model_uid,
+                            choices=[],
+                        )
+                        chunk["usage"] = CompletionUsage(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=total_tokens,
                         )
+                        yield chunk
 
                 return self._to_chat_completion_chunks(_stream_generator())
             else:
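The added code follows the OpenAI-style stream_options convention: when the caller sets {"include_usage": true}, the stream ends with one extra chunk whose choices list is empty and whose usage carries the accumulated token counts. A minimal consumption sketch, assuming the usual xinference RESTful client shape (Client, get_model, chat with a generate_config dict); the endpoint URL and model uid are placeholders:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model = client.get_model("chatglm3")  # placeholder model uid

for chunk in model.chat(
    prompt="Briefly introduce yourself.",
    generate_config={
        "stream": True,
        "stream_options": {"include_usage": True},
    },
):
    if not chunk["choices"]:
        # Trailing usage-only chunk, emitted because include_usage was set.
        print(chunk["usage"])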
xinference/model/llm/__init__.py
@@ -60,6 +60,8 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
     "OmniLMM",
     "yi-vl-chat",
     "deepseek-vl-chat",
+    "internvl-chat",
+    "mini-internvl-chat",
 ]
 
 
xinference/model/llm/pytorch/core.py
@@ -143,12 +145,17 @@ class PytorchModel(LLM):
                 f"Failed to import 'PeftModel' from 'peft'. Please make sure 'peft' is installed.\n\n"
             )
 
-        for peft_model in self._peft_model:
-            # Apply LoRA
-            self._model = PeftModel.from_pretrained(
-                self._model,
-                peft_model.local_path,
-            )
+        for i, peft_model in enumerate(self._peft_model):
+            if i == 0:
+                self._model = PeftModel.from_pretrained(
+                    self._model,
+                    peft_model.local_path,
+                    adapter_name=peft_model.lora_name,
+                )
+            else:
+                self._model.load_adapter(
+                    peft_model.local_path, adapter_name=peft_model.lora_name
+                )
             logger.info(
                 f"PEFT adaptor '{peft_model.lora_name}' successfully loaded for model '{self.model_uid}'."
             )
@@ -302,6 +309,18 @@ class PytorchModel(LLM):
         assert self._model is not None
         assert self._tokenizer is not None
 
+        lora_model = generate_config.pop("lora_name")
+
+        if lora_model is not None and self._peft_model is not None:
+            for lora in self._peft_model:
+                if lora_model == lora.lora_name:
+                    self._model.set_adapter(lora_model)
+                    logger.info(f"Set lora model to {lora_model}")
+                    break
+            else:
+                self._model.disable_adapter()
+                logger.info(f"No lora model {lora_model} found, skip setting")
+
         stream = generate_config.get("stream", False)
         if not stream:
             if "falcon" in model_family_name:
xinference/model/llm/pytorch/deepseek_vl.py
@@ -155,7 +155,12 @@ class DeepSeekVLChatModel(PytorchChatModel):
             generate_config = {}
 
         stream = generate_config.get("stream", False)
-
+        stream_options = generate_config.pop("stream_options", None)
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
         prompt, images = self._message_content_to_deepseek(prompt)
         prompt_messages: List[Dict[str, Any]] = [
             {
@@ -217,7 +222,7 @@ class DeepSeekVLChatModel(PytorchChatModel):
         )
 
         if stream:
-            it = self._generate_stream(streamer, stop_str)
+            it = self._generate_stream(streamer, stop_str, include_usage, prompt)
             return self._to_chat_completion_chunks(it)
         else:
             c = self._generate(streamer, stop_str)
@@ -246,8 +251,13 @@ class DeepSeekVLChatModel(PytorchChatModel):
         )
         return c
 
-    def _generate_stream(self, streamer, stop_str) -> Iterator[CompletionChunk]:
+    def _generate_stream(
+        self, streamer, stop_str, include_usage, prompt
+    ) -> Iterator[CompletionChunk]:
         completion_id = str(uuid.uuid1())
+        prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+        input_ids = self._tokenizer(prompt).input_ids
+        prompt_tokens = len(input_ids)
         for i, new_text in enumerate(streamer):
             if new_text.endswith(stop_str):
                 new_text = new_text[: -len(stop_str)]
@@ -261,10 +271,12 @@ class DeepSeekVLChatModel(PytorchChatModel):
                 model=self.model_uid,
                 choices=[completion_choice],
             )
+            completion_tokens = i
+            total_tokens = prompt_tokens + completion_tokens
            completion_usage = CompletionUsage(
-                prompt_tokens=-1,
-                completion_tokens=-1,
-                total_tokens=-1,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
             )
             chunk["usage"] = completion_usage
             yield chunk
@@ -280,9 +292,23 @@ class DeepSeekVLChatModel(PytorchChatModel):
             choices=[completion_choice],
         )
         completion_usage = CompletionUsage(
-            prompt_tokens=-1,
-            completion_tokens=-1,
-            total_tokens=-1,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
         )
         chunk["usage"] = completion_usage
         yield chunk
+        if include_usage:
+            chunk = CompletionChunk(
+                id=completion_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[],
+            )
+            chunk["usage"] = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            yield chunk
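The usage numbers are now real rather than -1: prompt tokens are counted once with the tokenizer and completion tokens are incremented per streamed piece. A standalone sketch of that accounting pattern, detached from the model classes above (tokenizer and streamer stand in for the HuggingFace tokenizer and TextIteratorStreamer):

def stream_with_usage(tokenizer, streamer, prompt):
    # Count prompt tokens once up front, as the diff does with
    # self._tokenizer(prompt).input_ids.
    prompt_tokens = len(tokenizer(prompt).input_ids)
    completion_tokens = 0
    for new_text in streamer:
        # Approximate completion tokens by the number of streamed increments.
        completion_tokens += 1
        yield {
            "text": new_text,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }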