xinference 0.11.0__py3-none-any.whl → 0.11.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.
Files changed (56)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +30 -0
  3. xinference/client/restful/restful_client.py +29 -0
  4. xinference/core/cache_tracker.py +12 -1
  5. xinference/core/chat_interface.py +10 -4
  6. xinference/core/model.py +2 -2
  7. xinference/core/supervisor.py +30 -2
  8. xinference/core/utils.py +12 -0
  9. xinference/core/worker.py +4 -1
  10. xinference/deploy/cmdline.py +126 -0
  11. xinference/deploy/test/test_cmdline.py +24 -0
  12. xinference/fields.py +3 -1
  13. xinference/model/llm/__init__.py +2 -0
  14. xinference/model/llm/ggml/chatglm.py +98 -13
  15. xinference/model/llm/ggml/llamacpp.py +49 -2
  16. xinference/model/llm/llm_family.json +633 -9
  17. xinference/model/llm/llm_family.py +84 -10
  18. xinference/model/llm/llm_family_modelscope.json +337 -10
  19. xinference/model/llm/memory.py +332 -0
  20. xinference/model/llm/pytorch/chatglm.py +48 -0
  21. xinference/model/llm/pytorch/core.py +25 -6
  22. xinference/model/llm/pytorch/deepseek_vl.py +35 -9
  23. xinference/model/llm/pytorch/intern_vl.py +387 -0
  24. xinference/model/llm/pytorch/internlm2.py +32 -1
  25. xinference/model/llm/pytorch/qwen_vl.py +38 -11
  26. xinference/model/llm/pytorch/utils.py +38 -1
  27. xinference/model/llm/pytorch/yi_vl.py +42 -14
  28. xinference/model/llm/sglang/core.py +31 -9
  29. xinference/model/llm/utils.py +38 -5
  30. xinference/model/llm/vllm/core.py +87 -5
  31. xinference/model/rerank/core.py +23 -1
  32. xinference/model/utils.py +17 -7
  33. xinference/thirdparty/deepseek_vl/models/processing_vlm.py +1 -1
  34. xinference/thirdparty/deepseek_vl/models/siglip_vit.py +2 -2
  35. xinference/thirdparty/llava/mm_utils.py +3 -2
  36. xinference/thirdparty/llava/model/llava_arch.py +1 -1
  37. xinference/thirdparty/omnilmm/chat.py +6 -5
  38. xinference/types.py +10 -1
  39. xinference/web/ui/build/asset-manifest.json +3 -3
  40. xinference/web/ui/build/index.html +1 -1
  41. xinference/web/ui/build/static/js/{main.8e44da4b.js → main.551aa479.js} +3 -3
  42. xinference/web/ui/build/static/js/main.551aa479.js.map +1 -0
  43. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +1 -0
  44. xinference/web/ui/node_modules/.cache/babel-loader/23caf6f1e52c43e983ca3bfd4189f41dbd645fa78f2dfdcd7f6b69bc41678665.json +1 -0
  45. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +1 -0
  46. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/METADATA +10 -8
  47. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/RECORD +52 -50
  48. xinference/web/ui/build/static/js/main.8e44da4b.js.map +0 -1
  49. xinference/web/ui/node_modules/.cache/babel-loader/1870cd6f7054d04e049e363c0a85526584fe25519378609d2838e28d7492bbf1.json +0 -1
  50. xinference/web/ui/node_modules/.cache/babel-loader/5393569d846332075b93b55656716a34f50e0a8c970be789502d7e6c49755fd7.json +0 -1
  51. xinference/web/ui/node_modules/.cache/babel-loader/ddaec68b88e5eff792df1e39a4b4b8b737bfc832293c015660c3c69334e3cf5c.json +0 -1
  52. /xinference/web/ui/build/static/js/{main.8e44da4b.js.LICENSE.txt → main.551aa479.js.LICENSE.txt} +0 -0
  53. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/LICENSE +0 -0
  54. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/WHEEL +0 -0
  55. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/entry_points.txt +0 -0
  56. {xinference-0.11.0.dist-info → xinference-0.11.2.dist-info}/top_level.txt +0 -0
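Most of the user-visible changes sit in the hunks below: a new InternVL chat model (intern_vl.py, added in full), a new xinference/model/llm/memory.py module, and OpenAI-style stream_options handling so that streaming responses can end with a usage-only chunk when include_usage is requested. As a rough sketch only (the endpoint address, model uid, prompt and chunk layout here are illustrative assumptions, not taken from this diff), turning the new usage reporting on through the Python client might look like:

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # assumed local endpoint
model = client.get_model("my-chat-model")  # uid of an already-launched chat model

for chunk in model.chat(
    prompt="Summarize this release in one sentence.",
    generate_config={
        "stream": True,
        # New in these hunks: request a trailing usage-only chunk.
        "stream_options": {"include_usage": True},
    },
):
    if chunk["choices"]:
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content", ""), end="", flush=True)
    else:
        # With include_usage, the final chunk has empty choices and only token counts.
        print("\n", chunk.get("usage"))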
xinference/model/llm/pytorch/intern_vl.py (new file)
@@ -0,0 +1,387 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import logging
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from io import BytesIO
+from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+import requests
+import torch
+import torchvision.transforms as T
+from PIL import Image
+from torchvision.transforms.functional import InterpolationMode
+
+from ....model.utils import select_device
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    Completion,
+    CompletionChoice,
+    CompletionUsage,
+)
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from .core import PytorchChatModel, PytorchGenerateConfig
+
+logger = logging.getLogger(__name__)
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+class InternVLChatModel(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._tokenizer = None
+        self._model = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        family = model_family.model_family or model_family.model_name
+        if "internvl" in family.lower():
+            return True
+        return False
+
+    def load(self, **kwargs):
+        from transformers import AutoModel, AutoTokenizer
+        from transformers.generation import GenerationConfig
+
+        device = self._pytorch_model_config.get("device", "auto")
+        device = select_device(device)
+        # for multiple GPU, set back to auto to make multiple devices work
+        device = "auto" if device == "cuda" else device
+
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+        )
+
+        kwargs = {
+            "torch_dtype": torch.bfloat16,
+            "low_cpu_mem_usage": True,
+            "trust_remote_code": True,
+            "device_map": device,
+        }
+
+        if "Int8" in self.model_spec.quantizations:
+            kwargs.update(
+                {
+                    "load_in_8bit": True,
+                    "device_map": device,
+                }
+            )
+        elif "mini" in self.model_family.model_name:
+            kwargs.pop("device_map")
+
+        self._model = AutoModel.from_pretrained(self.model_path, **kwargs).eval()
+
+        if "Int8" not in self.model_spec.quantizations:
+            self._model.cuda()
+
+        # Specify hyperparameters for generation
+        self._model.generation_config = GenerationConfig.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+        )
+
+    def _message_content_to_intern(self, content):
+        def _load_image(_url):
+            if _url.startswith("data:"):
+                logging.info("Parse url by base64 decoder.")
+                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                # e.g. f"data:image/jpeg;base64,{base64_image}"
+                _type, data = _url.split(";")
+                _, ext = _type.split("/")
+                data = data[len("base64,") :]
+                data = base64.b64decode(data.encode("utf-8"))
+                return Image.open(BytesIO(data)).convert("RGB")
+            else:
+                try:
+                    response = requests.get(_url)
+                except requests.exceptions.MissingSchema:
+                    return Image.open(_url).convert("RGB")
+                else:
+                    return Image.open(BytesIO(response.content)).convert("RGB")
+
+        if not isinstance(content, str):
+            texts = []
+            image_urls = []
+            for c in content:
+                c_type = c.get("type")
+                if c_type == "text":
+                    texts.append(c["text"])
+                elif c_type == "image_url":
+                    image_urls.append(c["image_url"]["url"])
+            image_futures = []
+            with ThreadPoolExecutor() as executor:
+                for image_url in image_urls:
+                    fut = executor.submit(_load_image, image_url)
+                    image_futures.append(fut)
+            images = [fut.result() for fut in image_futures]
+            text = " ".join(texts)
+            if len(images) == 0:
+                return text, None
+            else:
+                return text, images
+        return content, None
+
+    def _history_content_to_intern(
+        self,
+        chat_history: List[ChatCompletionMessage],
+        IMG_START_TOKEN="<img>",
+        IMG_END_TOKEN="</img>",
+        IMG_CONTEXT_TOKEN="<IMG_CONTEXT>",
+    ):
+        def _image_to_piexl_values(images):
+            load_images = []
+            for image in images:
+                if image.startswith("data:"):
+                    logging.info("Parse url by base64 decoder.")
+                    # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                    # e.g. f"data:image/jpeg;base64,{base64_image}"
+                    _type, data = image.split(";")
+                    _, ext = _type.split("/")
+                    data = data[len("base64,") :]
+                    data = base64.b64decode(data.encode("utf-8"))
+                    img = Image.open(BytesIO(data)).convert("RGB")
+                    pixel_value = (
+                        self._load_image(img, max_num=6).to(torch.bfloat16).cuda()
+                    )
+                    load_images.append(pixel_value)
+                else:
+                    try:
+                        response = requests.get(image)
+                    except requests.exceptions.MissingSchema:
+                        img = Image.open(image).convert("RGB")
+                    else:
+                        img = Image.open(BytesIO(response.content)).convert("RGB")
+                    pixel_value = (
+                        self._load_image(img, max_num=6).to(torch.bfloat16).cuda()
+                    )
+                    load_images.append(pixel_value)
+            return torch.cat(tuple(load_images), dim=0)
+
+        history: List[Tuple] = []
+        pixel_values = None
+        for i in range(0, len(chat_history), 2):
+            tmp = []
+            images: List[str] = []
+            user = chat_history[i]["content"]
+            if isinstance(user, List):
+                for content in user:
+                    c_type = content.get("type")
+                    if c_type == "text":
+                        tmp.append(content["text"])
+                    elif c_type == "image_url" and not history:
+                        images.append(content["image_url"]["url"])
+                if not history:
+                    pixel_values = _image_to_piexl_values(images)
+                    image_bs = pixel_values.shape[0]
+                    image_tokens = (
+                        IMG_START_TOKEN
+                        + IMG_CONTEXT_TOKEN * self._model.num_image_token * image_bs
+                        + IMG_END_TOKEN
+                    )
+                    tmp[0] = image_tokens + "\n" + tmp[0]
+            else:
+                tmp.append(user)
+            tmp.append(chat_history[i + 1]["content"])
+            history.append(tuple(tmp))
+        return history, pixel_values
+
+        def _load_image(_url):
+            if _url.startswith("data:"):
+                logging.info("Parse url by base64 decoder.")
+                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                # e.g. f"data:image/jpeg;base64,{base64_image}"
+                _type, data = _url.split(";")
+                _, ext = _type.split("/")
+                data = data[len("base64,") :]
+                data = base64.b64decode(data.encode("utf-8"))
+
+                return Image.open(BytesIO(data)).convert("RGB")
+            else:
+                try:
+                    response = requests.get(_url)
+                except requests.exceptions.MissingSchema:
+                    return Image.open(_url).convert("RGB")
+                else:
+                    return Image.open(BytesIO(response.content)).convert("RGB")
+
+        if not isinstance(content, str):
+            texts = []
+            image_urls = []
+            for c in content:
+                c_type = c.get("type")
+                if c_type == "text":
+                    texts.append(c["text"])
+                elif c_type == "image_url":
+                    image_urls.append(c["image_url"]["url"])
+            image_futures = []
+            with ThreadPoolExecutor() as executor:
+                for image_url in image_urls:
+                    fut = executor.submit(_load_image, image_url)
+                    image_futures.append(fut)
+            images = [fut.result() for fut in image_futures]
+            text = " ".join(texts)
+            if len(images) == 0:
+                return text
+            else:
+                return text, images
+        return content
+
+    def _find_closest_aspect_ratio(
+        self, aspect_ratio, target_ratios, width, height, image_size
+    ):
+        best_ratio_diff = float("inf")
+        best_ratio = (1, 1)
+        area = width * height
+        for ratio in target_ratios:
+            target_aspect_ratio = ratio[0] / ratio[1]
+            ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+            if ratio_diff < best_ratio_diff:
+                best_ratio_diff = ratio_diff
+                best_ratio = ratio
+            elif ratio_diff == best_ratio_diff:
+                if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                    best_ratio = ratio
+        return best_ratio
+
+    def _dynamic_preprocess(
+        self, image, min_num=1, max_num=6, image_size=448, use_thumbnail=False
+    ):
+        orig_width, orig_height = image.size
+        aspect_ratio = orig_width / orig_height
+
+        # calculate the existing image aspect ratio
+        target_ratios = set(
+            (i, j)
+            for n in range(min_num, max_num + 1)
+            for i in range(1, n + 1)
+            for j in range(1, n + 1)
+            if i * j <= max_num and i * j >= min_num
+        )
+        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+        # find the closest aspect ratio to the target
+        target_aspect_ratio = self._find_closest_aspect_ratio(
+            aspect_ratio, target_ratios, orig_width, orig_height, image_size
+        )
+
+        # calculate the target width and height
+        target_width = image_size * target_aspect_ratio[0]
+        target_height = image_size * target_aspect_ratio[1]
+        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+        # resize the image
+        resized_img = image.resize((target_width, target_height))
+        processed_images = []
+        for i in range(blocks):
+            box = (
+                (i % (target_width // image_size)) * image_size,
+                (i // (target_width // image_size)) * image_size,
+                ((i % (target_width // image_size)) + 1) * image_size,
+                ((i // (target_width // image_size)) + 1) * image_size,
+            )
+            # split the image
+            split_img = resized_img.crop(box)
+            processed_images.append(split_img)
+        assert len(processed_images) == blocks
+        if use_thumbnail and len(processed_images) != 1:
+            thumbnail_img = image.resize((image_size, image_size))
+            processed_images.append(thumbnail_img)
+        return processed_images
+
+    def _build_transform(self, input_size):
+        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+        transform = T.Compose(
+            [
+                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+                T.Resize(
+                    (input_size, input_size), interpolation=InterpolationMode.BICUBIC
+                ),
+                T.ToTensor(),
+                T.Normalize(mean=MEAN, std=STD),
+            ]
+        )
+        return transform
+
+    def _load_image(self, image_file, input_size=448, max_num=6):
+        transform = self._build_transform(input_size=input_size)
+        images = self._dynamic_preprocess(
+            image_file, image_size=input_size, use_thumbnail=True, max_num=max_num
+        )
+        pixel_values = [transform(image) for image in images]
+        pixel_values = torch.stack(pixel_values)
+        return pixel_values
+
+    def chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        if generate_config and generate_config.pop("stream"):
+            raise Exception(
+                f"Chat with model {self.model_family.model_name} does not support stream."
+            )
+        sanitized_config = {
+            "num_beams": 1,
+            "max_new_tokens": generate_config.get("max_tokens", 512)
+            if generate_config
+            else 512,
+            "do_sample": False,
+        }
+
+        content, image = self._message_content_to_intern(prompt)
+
+        history = None
+        if chat_history:
+            history, pixel_values = self._history_content_to_intern(chat_history)
+        else:
+            load_images = []
+            for img in image:
+                pixel_value = self._load_image(img, max_num=6).to(torch.bfloat16).cuda()
+                load_images.append(pixel_value)
+            pixel_values = torch.cat(tuple(load_images), dim=0)
+
+        response, history = self._model.chat(
+            self._tokenizer,
+            pixel_values,
+            content,
+            sanitized_config,
+            history=history,
+            return_history=True,
+        )
+        chunk = Completion(
+            id=str(uuid.uuid1()),
+            object="text_completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[
+                CompletionChoice(
+                    index=0, text=response, finish_reason="stop", logprobs=None
+                )
+            ],
+            usage=CompletionUsage(
+                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+            ),
+        )
+        return self._to_chat_completion(chunk)
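For orientation, _message_content_to_intern above accepts either a plain string or an OpenAI-style content list mixing "text" and "image_url" parts (http(s) URLs, local paths, or data: base64 URIs). A minimal sketch of a prompt that exercises the image path; the URL is a placeholder:

# Hypothetical prompt for InternVLChatModel.chat(); the URL is a placeholder.
prompt = [
    {"type": "text", "text": "What is shown in this picture?"},
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
]
# Text parts are joined with spaces; every image_url is fetched (or base64-decoded
# for data: URIs) in a ThreadPoolExecutor and converted to an RGB PIL image.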
xinference/model/llm/pytorch/internlm2.py
@@ -108,6 +108,12 @@ class Internlm2PytorchChatModel(PytorchChatModel):
             kwargs["max_length"] = int(max_new_tokens)
 
         stream = generate_config.get("stream", False)
+        stream_options = generate_config.pop("stream_options", None)
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
         if chat_history:
             input_history = [
                 (chat_history[i]["content"], (chat_history[i + 1]["content"]))
@@ -122,9 +128,15 @@ class Internlm2PytorchChatModel(PytorchChatModel):
             def _stream_generator():
                 last_chunk_text_length = 0
                 chunk_id = "chat-" + str(uuid.uuid1())
+                prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+                inputs = self._tokenizer([prompt], return_tensors="pt")
+                inputs = inputs.to(self._model.device)
+                prompt_tokens = len(inputs["input_ids"][0])
                 for chunk_text, _ in self._model.stream_chat(
-                    self._tokenizer, prompt, input_history, **kwargs
+                    self._tokenizer, prompt, chat_history, **kwargs
                 ):
+                    completion_tokens = completion_tokens + 1
+                    total_tokens = prompt_tokens + completion_tokens
                     chunk_text = chunk_text[last_chunk_text_length:]
                     last_chunk_text_length += len(chunk_text)
                     completion_choice = CompletionChoice(
@@ -136,7 +148,26 @@ class Internlm2PytorchChatModel(PytorchChatModel):
                         created=int(time.time()),
                         model=self.model_uid,
                         choices=[completion_choice],
+                        usage=CompletionUsage(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=total_tokens,
+                        ),
+                    )
+                if include_usage:
+                    chunk = CompletionChunk(
+                        id=chunk_id,
+                        object="text_completion",
+                        created=int(time.time()),
+                        model=self.model_uid,
+                        choices=[],
+                    )
+                    chunk["usage"] = CompletionUsage(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=total_tokens,
                     )
+                    yield chunk
 
             return self._to_chat_completion_chunks(_stream_generator())
         else:
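The usage figures emitted above are approximate: prompt_tokens is computed once by tokenizing the prompt, completion_tokens is incremented once per streamed chunk, and total_tokens is their sum at every step. A self-contained toy of that bookkeeping, detached from any model or tokenizer (all names here are illustrative):

def stream_with_usage(chunks, prompt_tokens):
    # Mirrors the counting scheme added in the hunks above.
    completion_tokens = 0
    for text in chunks:
        completion_tokens += 1
        yield {
            "text": text,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

for chunk in stream_with_usage(["Hel", "lo", "!"], prompt_tokens=12):
    print(chunk["usage"]["total_tokens"])  # 13, then 14, then 15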
xinference/model/llm/pytorch/qwen_vl.py
@@ -134,9 +134,16 @@ class QwenVLChatModel(PytorchChatModel):
         query_to_response = []
 
         stream = generate_config.get("stream", False) if generate_config else False
-
+        stream_options = (
+            generate_config.pop("stream_options", None) if generate_config else None
+        )
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
         if stream:
-            it = self._generate_stream(prompt, qwen_history)
+            it = self._generate_stream(prompt, qwen_history, include_usage)
             return self._to_chat_completion_chunks(it)
         else:
             c = self._generate(prompt, qwen_history)
@@ -163,12 +170,16 @@ class QwenVLChatModel(PytorchChatModel):
             return c
 
     def _generate_stream(
-        self, prompt: str, qwen_history: List
+        self, prompt: str, qwen_history: List, include_usage
     ) -> Iterator[CompletionChunk]:
         # response, history = model.chat(tokenizer, message, history=history)
         response_generator = self._model.chat_stream(
             self._tokenizer, query=prompt, history=qwen_history
         )
+        completion_id = str(uuid.uuid1())
+        prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+        input_ids = self._tokenizer(prompt, allowed_special="all").input_ids
+        prompt_tokens = len(input_ids)
         full_response = ""
         for response in response_generator:
             inc_content = response[len(full_response) :]
@@ -177,16 +188,18 @@ class QwenVLChatModel(PytorchChatModel):
                 text=inc_content, index=0, logprobs=None, finish_reason=None
             )
             completion_chunk = CompletionChunk(
-                id=str(uuid.uuid1()),
+                id=completion_id,
                 object="text_completion",
                 created=int(time.time()),
                 model=self.model_uid,
                 choices=[completion_choice],
             )
+            completion_tokens = completion_tokens + 1
+            total_tokens = prompt_tokens + completion_tokens
             completion_usage = CompletionUsage(
-                prompt_tokens=-1,
-                completion_tokens=-1,
-                total_tokens=-1,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
             )
             completion_chunk["usage"] = completion_usage
             yield completion_chunk
@@ -195,16 +208,30 @@ class QwenVLChatModel(PytorchChatModel):
             text="", index=0, logprobs=None, finish_reason="stop"
         )
         completion_chunk = CompletionChunk(
-            id=str(uuid.uuid1()),
+            id=completion_id,
             object="text_completion",
             created=int(time.time()),
             model=self.model_uid,
             choices=[completion_choice],
         )
         completion_usage = CompletionUsage(
-            prompt_tokens=-1,
-            completion_tokens=-1,
-            total_tokens=-1,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
         )
         completion_chunk["usage"] = completion_usage
         yield completion_chunk
+        if include_usage:
+            chunk = CompletionChunk(
+                id=completion_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[],
+            )
+            chunk["usage"] = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            yield chunk
xinference/model/llm/pytorch/utils.py
@@ -106,6 +106,10 @@ def generate_stream(
     context_len = get_context_length(model.config)
     stream_interval = generate_config.get("stream_interval", 2)
     stream = generate_config.get("stream", False)
+    stream_options = generate_config.pop("stream_options", None)
+    include_usage = (
+        stream_options["include_usage"] if isinstance(stream_options, dict) else False
+    )
 
     len_prompt = len(prompt)
 
@@ -333,6 +337,21 @@ def generate_stream(
 
         yield completion_chunk, completion_usage
 
+    if include_usage:
+        completion_chunk = CompletionChunk(
+            id=str(uuid.uuid1()),
+            object="text_completion",
+            created=int(time.time()),
+            model=model_uid,
+            choices=[],
+        )
+        completion_usage = CompletionUsage(
+            prompt_tokens=input_echo_len,
+            completion_tokens=i,
+            total_tokens=(input_echo_len + i),
+        )
+        yield completion_chunk, completion_usage
+
     # clean
     del past_key_values, out
     gc.collect()
@@ -352,7 +371,10 @@ def generate_stream_falcon(
     context_len = get_context_length(model.config)
     stream_interval = generate_config.get("stream_interval", 2)
     stream = generate_config.get("stream", False)
-
+    stream_options = generate_config.pop("stream_options", None)
+    include_usage = (
+        stream_options["include_usage"] if isinstance(stream_options, dict) else False
+    )
     len_prompt = len(prompt)
 
     temperature = float(generate_config.get("temperature", 1.0))
@@ -488,6 +510,21 @@ def generate_stream_falcon(
 
         yield completion_chunk, completion_usage
 
+    if include_usage:
+        completion_chunk = CompletionChunk(
+            id=str(uuid.uuid1()),
+            object="text_completion",
+            created=int(time.time()),
+            model=model_uid,
+            choices=[],
+        )
+        completion_usage = CompletionUsage(
+            prompt_tokens=input_echo_len,
+            completion_tokens=i,
+            total_tokens=(input_echo_len + i),
+        )
+        yield completion_chunk, completion_usage
+
     # clean
     gc.collect()
     empty_cache()
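On the consuming side, the extra chunk produced when include_usage is set is recognizable by its empty choices list. A hedged sketch of splitting such a stream into text plus a final usage record; the demo data below is toy input shaped like the CompletionChunk dicts above, not real output:

from typing import Dict, Iterable, List, Optional, Tuple

def collect_stream(chunks: Iterable[Dict]) -> Tuple[str, Optional[Dict]]:
    texts: List[str] = []
    final_usage: Optional[Dict] = None
    for chunk in chunks:
        if chunk.get("choices"):
            texts.append(chunk["choices"][0].get("text", ""))
        else:
            # The usage-only chunk introduced in this release carries no choices.
            final_usage = chunk.get("usage")
    return "".join(texts), final_usage

demo = [
    {"choices": [{"index": 0, "text": "Hel"}]},
    {"choices": [{"index": 0, "text": "lo"}]},
    {"choices": [], "usage": {"prompt_tokens": 12, "completion_tokens": 2, "total_tokens": 14}},
]
print(collect_stream(demo))  # ('Hello', {'prompt_tokens': 12, ...})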