xinference-0.11.0-py3-none-any.whl → xinference-0.11.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (37)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +10 -4
  3. xinference/core/model.py +2 -2
  4. xinference/fields.py +3 -1
  5. xinference/model/llm/ggml/chatglm.py +98 -13
  6. xinference/model/llm/ggml/llamacpp.py +49 -2
  7. xinference/model/llm/llm_family.json +132 -3
  8. xinference/model/llm/llm_family_modelscope.json +139 -3
  9. xinference/model/llm/pytorch/chatglm.py +48 -0
  10. xinference/model/llm/pytorch/core.py +23 -6
  11. xinference/model/llm/pytorch/deepseek_vl.py +35 -9
  12. xinference/model/llm/pytorch/internlm2.py +32 -1
  13. xinference/model/llm/pytorch/qwen_vl.py +38 -11
  14. xinference/model/llm/pytorch/utils.py +38 -1
  15. xinference/model/llm/pytorch/yi_vl.py +42 -14
  16. xinference/model/llm/sglang/core.py +31 -9
  17. xinference/model/llm/utils.py +25 -5
  18. xinference/model/llm/vllm/core.py +82 -3
  19. xinference/types.py +10 -1
  20. xinference/web/ui/build/asset-manifest.json +3 -3
  21. xinference/web/ui/build/index.html +1 -1
  22. xinference/web/ui/build/static/js/{main.8e44da4b.js → main.551aa479.js} +3 -3
  23. xinference/web/ui/build/static/js/main.551aa479.js.map +1 -0
  24. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +1 -0
  25. xinference/web/ui/node_modules/.cache/babel-loader/23caf6f1e52c43e983ca3bfd4189f41dbd645fa78f2dfdcd7f6b69bc41678665.json +1 -0
  26. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +1 -0
  27. {xinference-0.11.0.dist-info → xinference-0.11.1.dist-info}/METADATA +3 -2
  28. {xinference-0.11.0.dist-info → xinference-0.11.1.dist-info}/RECORD +33 -33
  29. xinference/web/ui/build/static/js/main.8e44da4b.js.map +0 -1
  30. xinference/web/ui/node_modules/.cache/babel-loader/1870cd6f7054d04e049e363c0a85526584fe25519378609d2838e28d7492bbf1.json +0 -1
  31. xinference/web/ui/node_modules/.cache/babel-loader/5393569d846332075b93b55656716a34f50e0a8c970be789502d7e6c49755fd7.json +0 -1
  32. xinference/web/ui/node_modules/.cache/babel-loader/ddaec68b88e5eff792df1e39a4b4b8b737bfc832293c015660c3c69334e3cf5c.json +0 -1
  33. /xinference/web/ui/build/static/js/{main.8e44da4b.js.LICENSE.txt → main.551aa479.js.LICENSE.txt} +0 -0
  34. {xinference-0.11.0.dist-info → xinference-0.11.1.dist-info}/LICENSE +0 -0
  35. {xinference-0.11.0.dist-info → xinference-0.11.1.dist-info}/WHEEL +0 -0
  36. {xinference-0.11.0.dist-info → xinference-0.11.1.dist-info}/entry_points.txt +0 -0
  37. {xinference-0.11.0.dist-info → xinference-0.11.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/llm_family_modelscope.json
@@ -1289,7 +1289,7 @@
   },
   {
     "version": 1,
-    "context_length": 204800,
+    "context_length": 262144,
     "model_name": "Yi-200k",
     "model_lang": [
       "en",
@@ -1328,7 +1328,7 @@
   },
   {
     "version": 1,
-    "context_length": 204800,
+    "context_length": 4096,
     "model_name": "Yi-chat",
     "model_lang": [
       "en",
@@ -1349,6 +1349,18 @@
         "model_id": "01ai/Yi-34B-Chat-{quantization}",
         "model_revision": "master"
       },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 6,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_hub": "modelscope",
+        "model_id": "01ai/Yi-6B-Chat",
+        "model_revision": "master"
+      },
       {
         "model_format": "pytorch",
         "model_size_in_billions": 34,
@@ -1385,6 +1397,130 @@
       ]
     }
   },
+  {
+    "version": 1,
+    "context_length": 4096,
+    "model_name": "Yi-1.5",
+    "model_lang": [
+      "en",
+      "zh"
+    ],
+    "model_ability": [
+      "generate"
+    ],
+    "model_description": "Yi-1.5 is an upgraded version of Yi. It is continuously pre-trained on Yi with a high-quality corpus of 500B tokens and fine-tuned on 3M diverse fine-tuning samples.",
+    "model_specs": [
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 6,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_hub": "modelscope",
+        "model_id": "01ai/Yi-1.5-6B",
+        "model_revision": "master"
+      },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 9,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_hub": "modelscope",
+        "model_id": "01ai/Yi-1.5-9B",
+        "model_revision": "master"
+      },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 34,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_hub": "modelscope",
+        "model_id": "01ai/Yi-1.5-34B",
+        "model_revision": "master"
+      }
+    ]
+  },
+  {
+    "version": 1,
+    "context_length": 4096,
+    "model_name": "Yi-1.5-chat",
+    "model_lang": [
+      "en",
+      "zh"
+    ],
+    "model_ability": [
+      "chat"
+    ],
+    "model_description": "Yi-1.5 is an upgraded version of Yi. It is continuously pre-trained on Yi with a high-quality corpus of 500B tokens and fine-tuned on 3M diverse fine-tuning samples.",
+    "model_specs": [
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 6,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_hub": "modelscope",
+        "model_id": "01ai/Yi-1.5-6B-Chat",
+        "model_revision": "master"
+      },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 9,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_hub": "modelscope",
+        "model_id": "01ai/Yi-1.5-9B-Chat",
+        "model_revision": "master"
+      },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 34,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_hub": "modelscope",
+        "model_id": "01ai/Yi-1.5-34B-Chat",
+        "model_revision": "master"
+      }
+    ],
+    "prompt_style": {
+      "style_name": "CHATML",
+      "system_prompt": "",
+      "roles": [
+        "<|im_start|>user",
+        "<|im_start|>assistant"
+      ],
+      "intra_message_sep": "<|im_end|>",
+      "inter_message_sep": "",
+      "stop_token_ids": [
+        2,
+        6,
+        7,
+        8
+      ],
+      "stop": [
+        "<|endoftext|>",
+        "<|im_start|>",
+        "<|im_end|>",
+        "<|im_sep|>"
+      ]
+    }
+  },
   {
     "version": 1,
     "context_length": 2048,
@@ -2755,7 +2891,7 @@
   },
   {
     "version": 1,
-    "context_length": 204800,
+    "context_length": 4096,
     "model_name": "yi-vl-chat",
     "model_lang": [
       "en",
xinference/model/llm/pytorch/chatglm.py
@@ -147,14 +147,26 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             )
         else:
             stream = generate_config.get("stream", False)
+            stream_options = generate_config.pop("stream_options", None)
+            include_usage = (
+                stream_options["include_usage"]
+                if isinstance(stream_options, dict)
+                else False
+            )
             if stream:
 
                 def _stream_generator():
                     last_chunk_text_length = 0
                     chunk_id = "chat-" + str(uuid.uuid1())
+                    prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+                    inputs = self._tokenizer([prompt], return_tensors="pt")
+                    inputs = inputs.to(self._model.device)
+                    prompt_tokens = len(inputs["input_ids"][0])
                     for chunk_text, _ in self._model.stream_chat(
                         self._tokenizer, prompt, chat_history, **kwargs
                     ):
+                        completion_tokens = completion_tokens + 1
+                        total_tokens = prompt_tokens + completion_tokens
                         chunk_text = chunk_text[last_chunk_text_length:]
                         last_chunk_text_length += len(chunk_text)
                         completion_choice = CompletionChoice(
@@ -166,7 +178,43 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                             created=int(time.time()),
                             model=self.model_uid,
                             choices=[completion_choice],
+                            usage=CompletionUsage(
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                                total_tokens=total_tokens,
+                            ),
+                        )
+                    completion_choice = CompletionChoice(
+                        text="", index=0, logprobs=None, finish_reason="stop"
+                    )
+                    chunk = CompletionChunk(
+                        id=chunk_id,
+                        object="text_completion",
+                        created=int(time.time()),
+                        model=self.model_uid,
+                        choices=[completion_choice],
+                    )
+                    completion_usage = CompletionUsage(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=total_tokens,
+                    )
+                    chunk["usage"] = completion_usage
+                    yield chunk
+                    if include_usage:
+                        chunk = CompletionChunk(
+                            id=chunk_id,
+                            object="text_completion",
+                            created=int(time.time()),
+                            model=self.model_uid,
+                            choices=[],
+                        )
+                        chunk["usage"] = CompletionUsage(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=total_tokens,
                         )
+                        yield chunk
 
                 return self._to_chat_completion_chunks(_stream_generator())
             else:
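The chatglm changes above, and the matching edits in the other PyTorch chat models below, pop a `stream_options` mapping out of `generate_config`, count prompt and completion tokens while streaming, and, when `include_usage` is set, yield one final chunk with empty `choices` that carries the usage totals. A hedged sketch of how a caller might opt in through the OpenAI-compatible endpoint; the base URL, API key and model uid are placeholders, and it assumes a recent openai-python release plus a server that accepts `stream_options`.

# Hedged sketch of a client requesting the trailing usage chunk.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9997/v1", api_key="not-used")
stream = client.chat.completions.create(
    model="chatglm3",  # hypothetical model uid
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage:  # final chunk: empty choices, usage totals attached
        print("\n", chunk.usage)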
xinference/model/llm/pytorch/core.py
@@ -143,12 +143,17 @@ class PytorchModel(LLM):
                 f"Failed to import 'PeftModel' from 'peft'. Please make sure 'peft' is installed.\n\n"
             )
 
-        for peft_model in self._peft_model:
-            # Apply LoRA
-            self._model = PeftModel.from_pretrained(
-                self._model,
-                peft_model.local_path,
-            )
+        for i, peft_model in enumerate(self._peft_model):
+            if i == 0:
+                self._model = PeftModel.from_pretrained(
+                    self._model,
+                    peft_model.local_path,
+                    adapter_name=peft_model.lora_name,
+                )
+            else:
+                self._model.load_adapter(
+                    peft_model.local_path, adapter_name=peft_model.lora_name
+                )
             logger.info(
                 f"PEFT adaptor '{peft_model.lora_name}' successfully loaded for model '{self.model_uid}'."
             )
@@ -302,6 +307,18 @@ class PytorchModel(LLM):
         assert self._model is not None
         assert self._tokenizer is not None
 
+        lora_model = generate_config.pop("lora_name")
+
+        if lora_model is not None and self._peft_model is not None:
+            for lora in self._peft_model:
+                if lora_model == lora.lora_name:
+                    self._model.set_adapter(lora_model)
+                    logger.info(f"Set lora model to {lora_model}")
+                    break
+            else:
+                self._model.disable_adapter()
+                logger.info(f"No lora model {lora_model} found, skip setting")
+
         stream = generate_config.get("stream", False)
         if not stream:
             if "falcon" in model_family_name:
xinference/model/llm/pytorch/deepseek_vl.py
@@ -155,7 +155,12 @@ class DeepSeekVLChatModel(PytorchChatModel):
             generate_config = {}
 
         stream = generate_config.get("stream", False)
-
+        stream_options = generate_config.pop("stream_options", None)
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
         prompt, images = self._message_content_to_deepseek(prompt)
         prompt_messages: List[Dict[str, Any]] = [
             {
@@ -217,7 +222,7 @@ class DeepSeekVLChatModel(PytorchChatModel):
         )
 
         if stream:
-            it = self._generate_stream(streamer, stop_str)
+            it = self._generate_stream(streamer, stop_str, include_usage, prompt)
             return self._to_chat_completion_chunks(it)
         else:
             c = self._generate(streamer, stop_str)
@@ -246,8 +251,13 @@ class DeepSeekVLChatModel(PytorchChatModel):
         )
         return c
 
-    def _generate_stream(self, streamer, stop_str) -> Iterator[CompletionChunk]:
+    def _generate_stream(
+        self, streamer, stop_str, include_usage, prompt
+    ) -> Iterator[CompletionChunk]:
         completion_id = str(uuid.uuid1())
+        prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+        input_ids = self._tokenizer(prompt).input_ids
+        prompt_tokens = len(input_ids)
         for i, new_text in enumerate(streamer):
             if new_text.endswith(stop_str):
                 new_text = new_text[: -len(stop_str)]
@@ -261,10 +271,12 @@ class DeepSeekVLChatModel(PytorchChatModel):
                 model=self.model_uid,
                 choices=[completion_choice],
             )
+            completion_tokens = i
+            total_tokens = prompt_tokens + completion_tokens
             completion_usage = CompletionUsage(
-                prompt_tokens=-1,
-                completion_tokens=-1,
-                total_tokens=-1,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
             )
             chunk["usage"] = completion_usage
             yield chunk
@@ -280,9 +292,23 @@ class DeepSeekVLChatModel(PytorchChatModel):
             choices=[completion_choice],
         )
         completion_usage = CompletionUsage(
-            prompt_tokens=-1,
-            completion_tokens=-1,
-            total_tokens=-1,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
         )
         chunk["usage"] = completion_usage
         yield chunk
+        if include_usage:
+            chunk = CompletionChunk(
+                id=completion_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[],
+            )
+            chunk["usage"] = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            yield chunk
xinference/model/llm/pytorch/internlm2.py
@@ -108,6 +108,12 @@ class Internlm2PytorchChatModel(PytorchChatModel):
             kwargs["max_length"] = int(max_new_tokens)
 
         stream = generate_config.get("stream", False)
+        stream_options = generate_config.pop("stream_options", None)
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
         if chat_history:
             input_history = [
                 (chat_history[i]["content"], (chat_history[i + 1]["content"]))
@@ -122,9 +128,15 @@ class Internlm2PytorchChatModel(PytorchChatModel):
            def _stream_generator():
                last_chunk_text_length = 0
                chunk_id = "chat-" + str(uuid.uuid1())
+                prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+                inputs = self._tokenizer([prompt], return_tensors="pt")
+                inputs = inputs.to(self._model.device)
+                prompt_tokens = len(inputs["input_ids"][0])
                for chunk_text, _ in self._model.stream_chat(
-                    self._tokenizer, prompt, input_history, **kwargs
+                    self._tokenizer, prompt, chat_history, **kwargs
                ):
+                    completion_tokens = completion_tokens + 1
+                    total_tokens = prompt_tokens + completion_tokens
                    chunk_text = chunk_text[last_chunk_text_length:]
                    last_chunk_text_length += len(chunk_text)
                    completion_choice = CompletionChoice(
@@ -136,7 +148,26 @@ class Internlm2PytorchChatModel(PytorchChatModel):
                        created=int(time.time()),
                        model=self.model_uid,
                        choices=[completion_choice],
+                        usage=CompletionUsage(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=total_tokens,
+                        ),
+                    )
+                if include_usage:
+                    chunk = CompletionChunk(
+                        id=chunk_id,
+                        object="text_completion",
+                        created=int(time.time()),
+                        model=self.model_uid,
+                        choices=[],
+                    )
+                    chunk["usage"] = CompletionUsage(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=total_tokens,
                    )
+                    yield chunk
 
            return self._to_chat_completion_chunks(_stream_generator())
        else:
xinference/model/llm/pytorch/qwen_vl.py
@@ -134,9 +134,16 @@ class QwenVLChatModel(PytorchChatModel):
             query_to_response = []
 
         stream = generate_config.get("stream", False) if generate_config else False
-
+        stream_options = (
+            generate_config.pop("stream_options", None) if generate_config else None
+        )
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
         if stream:
-            it = self._generate_stream(prompt, qwen_history)
+            it = self._generate_stream(prompt, qwen_history, include_usage)
             return self._to_chat_completion_chunks(it)
         else:
             c = self._generate(prompt, qwen_history)
@@ -163,12 +170,16 @@ class QwenVLChatModel(PytorchChatModel):
         return c
 
     def _generate_stream(
-        self, prompt: str, qwen_history: List
+        self, prompt: str, qwen_history: List, include_usage
     ) -> Iterator[CompletionChunk]:
         # response, history = model.chat(tokenizer, message, history=history)
         response_generator = self._model.chat_stream(
             self._tokenizer, query=prompt, history=qwen_history
         )
+        completion_id = str(uuid.uuid1())
+        prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+        input_ids = self._tokenizer(prompt, allowed_special="all").input_ids
+        prompt_tokens = len(input_ids)
         full_response = ""
         for response in response_generator:
             inc_content = response[len(full_response) :]
@@ -177,16 +188,18 @@ class QwenVLChatModel(PytorchChatModel):
                 text=inc_content, index=0, logprobs=None, finish_reason=None
             )
             completion_chunk = CompletionChunk(
-                id=str(uuid.uuid1()),
+                id=completion_id,
                 object="text_completion",
                 created=int(time.time()),
                 model=self.model_uid,
                 choices=[completion_choice],
             )
+            completion_tokens = completion_tokens + 1
+            total_tokens = prompt_tokens + completion_tokens
             completion_usage = CompletionUsage(
-                prompt_tokens=-1,
-                completion_tokens=-1,
-                total_tokens=-1,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
             )
             completion_chunk["usage"] = completion_usage
             yield completion_chunk
@@ -195,16 +208,30 @@ class QwenVLChatModel(PytorchChatModel):
             text="", index=0, logprobs=None, finish_reason="stop"
         )
         completion_chunk = CompletionChunk(
-            id=str(uuid.uuid1()),
+            id=completion_id,
             object="text_completion",
             created=int(time.time()),
             model=self.model_uid,
             choices=[completion_choice],
         )
         completion_usage = CompletionUsage(
-            prompt_tokens=-1,
-            completion_tokens=-1,
-            total_tokens=-1,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
         )
         completion_chunk["usage"] = completion_usage
         yield completion_chunk
+        if include_usage:
+            chunk = CompletionChunk(
+                id=completion_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[],
+            )
+            chunk["usage"] = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            yield chunk
xinference/model/llm/pytorch/utils.py
@@ -106,6 +106,10 @@ def generate_stream(
     context_len = get_context_length(model.config)
     stream_interval = generate_config.get("stream_interval", 2)
     stream = generate_config.get("stream", False)
+    stream_options = generate_config.pop("stream_options", None)
+    include_usage = (
+        stream_options["include_usage"] if isinstance(stream_options, dict) else False
+    )
 
     len_prompt = len(prompt)
 
@@ -333,6 +337,21 @@ def generate_stream(
 
     yield completion_chunk, completion_usage
 
+    if include_usage:
+        completion_chunk = CompletionChunk(
+            id=str(uuid.uuid1()),
+            object="text_completion",
+            created=int(time.time()),
+            model=model_uid,
+            choices=[],
+        )
+        completion_usage = CompletionUsage(
+            prompt_tokens=input_echo_len,
+            completion_tokens=i,
+            total_tokens=(input_echo_len + i),
+        )
+        yield completion_chunk, completion_usage
+
     # clean
     del past_key_values, out
     gc.collect()
@@ -352,7 +371,10 @@ def generate_stream_falcon(
     context_len = get_context_length(model.config)
     stream_interval = generate_config.get("stream_interval", 2)
     stream = generate_config.get("stream", False)
-
+    stream_options = generate_config.pop("stream_options", None)
+    include_usage = (
+        stream_options["include_usage"] if isinstance(stream_options, dict) else False
+    )
     len_prompt = len(prompt)
 
     temperature = float(generate_config.get("temperature", 1.0))
@@ -488,6 +510,21 @@ def generate_stream_falcon(
 
     yield completion_chunk, completion_usage
 
+    if include_usage:
+        completion_chunk = CompletionChunk(
+            id=str(uuid.uuid1()),
+            object="text_completion",
+            created=int(time.time()),
+            model=model_uid,
+            choices=[],
+        )
+        completion_usage = CompletionUsage(
+            prompt_tokens=input_echo_len,
+            completion_tokens=i,
+            total_tokens=(input_echo_len + i),
+        )
+        yield completion_chunk, completion_usage
+
     # clean
     gc.collect()
     empty_cache()