xinference 0.10.2.post1__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (92)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/auth_service.py +1 -1
  3. xinference/api/restful_api.py +53 -61
  4. xinference/client/restful/restful_client.py +52 -57
  5. xinference/conftest.py +1 -1
  6. xinference/core/cache_tracker.py +1 -1
  7. xinference/core/event.py +1 -1
  8. xinference/core/model.py +15 -4
  9. xinference/core/status_guard.py +1 -1
  10. xinference/core/supervisor.py +58 -72
  11. xinference/core/worker.py +73 -102
  12. xinference/deploy/cmdline.py +175 -6
  13. xinference/deploy/test/test_cmdline.py +2 -0
  14. xinference/deploy/utils.py +1 -1
  15. xinference/device_utils.py +29 -3
  16. xinference/fields.py +5 -1
  17. xinference/model/audio/model_spec.json +8 -1
  18. xinference/model/audio/whisper.py +88 -12
  19. xinference/model/core.py +2 -2
  20. xinference/model/embedding/core.py +13 -0
  21. xinference/model/image/__init__.py +29 -0
  22. xinference/model/image/core.py +6 -0
  23. xinference/model/image/custom.py +109 -0
  24. xinference/model/llm/__init__.py +92 -32
  25. xinference/model/llm/core.py +57 -102
  26. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +2 -2
  27. xinference/model/llm/llm_family.json +446 -2
  28. xinference/model/llm/llm_family.py +45 -41
  29. xinference/model/llm/llm_family_modelscope.json +208 -1
  30. xinference/model/llm/pytorch/deepseek_vl.py +89 -33
  31. xinference/model/llm/pytorch/qwen_vl.py +67 -12
  32. xinference/model/llm/pytorch/yi_vl.py +62 -45
  33. xinference/model/llm/utils.py +45 -15
  34. xinference/model/llm/vllm/core.py +21 -4
  35. xinference/model/rerank/core.py +48 -20
  36. xinference/thirdparty/omnilmm/chat.py +2 -1
  37. xinference/thirdparty/omnilmm/model/omnilmm.py +2 -1
  38. xinference/types.py +2 -0
  39. xinference/web/ui/build/asset-manifest.json +6 -3
  40. xinference/web/ui/build/index.html +1 -1
  41. xinference/web/ui/build/static/css/main.54bca460.css +2 -0
  42. xinference/web/ui/build/static/css/main.54bca460.css.map +1 -0
  43. xinference/web/ui/build/static/js/main.8e44da4b.js +3 -0
  44. xinference/web/ui/build/static/js/{main.26fdbfbe.js.LICENSE.txt → main.8e44da4b.js.LICENSE.txt} +7 -0
  45. xinference/web/ui/build/static/js/main.8e44da4b.js.map +1 -0
  46. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +1 -0
  47. xinference/web/ui/node_modules/.cache/babel-loader/29dda700ab913cf7f2cfabe450ddabfb283e96adfa3ec9d315b2fa6c63cd375c.json +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/2c63e940b945fd5817157e08a42b889b30d668ea4c91332f48ef2b1b9d26f520.json +1 -0
  49. xinference/web/ui/node_modules/.cache/babel-loader/4135fe8745434cbce6438d1ebfa47422e0c77d884db4edc75c8bf32ea1d50621.json +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/4de0a71074f9cbe1e7862750dcdd08cbc1bae7d9d9849a78b1783ca670017b3c.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/53f6c0c0afb51265cd8fb940daeb65523501879ac2a8c03a1ead22b9793c5041.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/8ccbb839002bc5bc03e0a0e7612362bf92f6ae64f87e094f8682d6a6fe4619bb.json +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/97ed30d6e22cf76f0733651e2c18364689a01665d0b5fe811c1b7ca3eb713c82.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/9c0c70f1838913aaa792a0d2260f17f90fd177b95698ed46b7bc3050eb712c1c.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/9cfd33238ca43e5bf9fc7e442690e8cc6027c73553db36de87e3597ed524ee4b.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/ada71518a429f821a9b1dea38bc951447f03c8db509887e0980b893acac938f3.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/b6c9558d28b5972bb8b2691c5a76a2c8814a815eb3443126da9f49f7d6a0c118.json +1 -0
  59. xinference/web/ui/node_modules/.cache/babel-loader/bb0f721c084a4d85c09201c984f02ee8437d3b6c5c38a57cb4a101f653daef1b.json +1 -0
  60. xinference/web/ui/node_modules/.cache/babel-loader/ddaec68b88e5eff792df1e39a4b4b8b737bfc832293c015660c3c69334e3cf5c.json +1 -0
  61. xinference/web/ui/node_modules/.package-lock.json +33 -0
  62. xinference/web/ui/node_modules/clipboard/.babelrc.json +11 -0
  63. xinference/web/ui/node_modules/clipboard/.eslintrc.json +24 -0
  64. xinference/web/ui/node_modules/clipboard/.prettierrc.json +9 -0
  65. xinference/web/ui/node_modules/clipboard/bower.json +18 -0
  66. xinference/web/ui/node_modules/clipboard/composer.json +25 -0
  67. xinference/web/ui/node_modules/clipboard/package.json +63 -0
  68. xinference/web/ui/node_modules/delegate/package.json +31 -0
  69. xinference/web/ui/node_modules/good-listener/bower.json +11 -0
  70. xinference/web/ui/node_modules/good-listener/package.json +35 -0
  71. xinference/web/ui/node_modules/select/bower.json +13 -0
  72. xinference/web/ui/node_modules/select/package.json +29 -0
  73. xinference/web/ui/node_modules/tiny-emitter/package.json +53 -0
  74. xinference/web/ui/package-lock.json +34 -0
  75. xinference/web/ui/package.json +1 -0
  76. {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/METADATA +14 -13
  77. {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/RECORD +81 -60
  78. xinference/client/oscar/__init__.py +0 -13
  79. xinference/client/oscar/actor_client.py +0 -611
  80. xinference/model/llm/pytorch/spec_decoding_utils.py +0 -531
  81. xinference/model/llm/pytorch/spec_model.py +0 -186
  82. xinference/web/ui/build/static/js/main.26fdbfbe.js +0 -3
  83. xinference/web/ui/build/static/js/main.26fdbfbe.js.map +0 -1
  84. xinference/web/ui/node_modules/.cache/babel-loader/63a4c48f0326d071c7772c46598215c006ae41fd3d4ff3577fe717de66ad6e89.json +0 -1
  85. xinference/web/ui/node_modules/.cache/babel-loader/de0299226173b0662b573f49e3992220f6611947073bd66ac079728a8bc8837d.json +0 -1
  86. xinference/web/ui/node_modules/.cache/babel-loader/e9b52d171223bb59fb918316297a051cdfd42dd453e8260fd918e90bc0a4ebdf.json +0 -1
  87. xinference/web/ui/node_modules/.cache/babel-loader/f4d5d1a41892a754c1ee0237450d804b20612d1b657945b59e564161ea47aa7a.json +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/fad4cd70de36ef6e6d5f8fd74a10ded58d964a8a91ef7681693fbb8376552da7.json +0 -1
  89. {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/LICENSE +0 -0
  90. {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/WHEEL +0 -0
  91. {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/entry_points.txt +0 -0
  92. {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/spec_decoding_utils.py
@@ -1,531 +0,0 @@
- # Copyright 2022-2023 XProbe Inc.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import gc
- import logging
- import time
- import uuid
- from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
-
- from ....device_utils import empty_cache
-
- try:
-     import torch
-     from torch.nn import functional as F
- except ImportError:
-     raise ImportError(
-         f"Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n"
-     )
-
- try:
-     from transformers import PreTrainedModel, PreTrainedTokenizer
-     from transformers.generation.logits_process import (
-         LogitsProcessorList,
-         TemperatureLogitsWarper,
-         TopKLogitsWarper,
-         TopPLogitsWarper,
-     )
- except ImportError:
-     error_message = "Failed to import module 'transformers'"
-     installation_guide = [
-         "Please make sure 'transformers' is installed. ",
-         "You can install it by `pip install transformers`\n",
-     ]
-
-     raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-
- from ....types import CompletionChoice, CompletionChunk, CompletionUsage
-
- logger = logging.getLogger(__name__)
-
-
- def prepare_logits_processor(
-     temperature: float, top_p: float, top_k: int
- ) -> LogitsProcessorList:
-     processor_list = LogitsProcessorList()
-     # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op, so we skip two cases.
-     if temperature >= 1e-5 and temperature != 1.0:
-         processor_list.append(TemperatureLogitsWarper(temperature))
-     if 1e-8 <= top_p < 1.0:
-         processor_list.append(TopPLogitsWarper(top_p))
-     if top_k > 0:
-         processor_list.append(TopKLogitsWarper(top_k))
-     return processor_list
-
-
- def get_context_length(config):
-     """Get the context length of a model from a huggingface model config."""
-     if (
-         hasattr(config, "max_sequence_length")
-         and config.max_sequence_length is not None
-     ):
-         return config.max_sequence_length
-     elif hasattr(config, "seq_length") and config.seq_length is not None:
-         return config.seq_length
-     elif (
-         hasattr(config, "max_position_embeddings")
-         and config.max_position_embeddings is not None
-     ):
-         return config.max_position_embeddings
-     else:
-         return 2048
-
-
- def normalize_logits(
-     logits_processor: LogitsProcessorList,
-     input_ids: List[int],
-     logits: torch.FloatTensor,  # [1, n_seq, n_vocab]
- ) -> torch.Tensor:
-     """
-     Parameters
-     ----------
-     logits : torch.Tensor
-         Logits of shape `(n_batch, n_seq, n_vocab)`.
-
-     Returns
-     -------
-     torch.Tensor
-         Normalized logits of shape `(n_batch, n_seq, n_vocab)`.
-     """
-
-     def _helper(
-         _input_ids: torch.LongTensor, _logits: torch.FloatTensor  # [1, n_vocab]
-     ) -> torch.Tensor:
-         if logits_processor:
-             last_token_logits = logits_processor(
-                 _input_ids,
-                 _logits,
-             )[0]
-         else:
-             return _logits[0]
-
-         return last_token_logits  # [n_vocab,]
-
-     input_ids = torch.as_tensor([input_ids], device=logits.device).long()
-     for i in range(logits.shape[1]):
-         normalized = _helper(
-             input_ids[
-                 : -logits.shape[1] + i
-             ],  # input_ids may not equal logits.shape[1]
-             logits[:, i, :],
-         )
-         logits[:, i, :] = normalized.clone()
-     return F.softmax(logits, dim=-1)
-
-
- def sample(
-     last_token_logits: torch.FloatTensor, temperature: float, top_p: float
- ) -> int:
-     """
-     Parameters
-     ----------
-     last_token_logits : torch.FloatTensor
-         Last token logits of shape [n_vocab,]
-
-     Returns
-     -------
-     int
-         Token ID.
-     """
-     if temperature < 1e-5 or top_p < 1e-8:  # greedy
-         _, indices = torch.topk(last_token_logits, 2)
-         tokens = [int(index) for index in indices.tolist()]
-     else:
-         indices = torch.multinomial(last_token_logits, num_samples=2)
-         tokens = [int(token) for token in indices.tolist()]
-     return tokens[0]
-
-
- def rollback_kv_cache(
-     kv_cache: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], n: int
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
-     ret = []
-     for k_cache, v_cache in kv_cache:
-         k_cache = k_cache[:, :, :-n, :]  # [1, n_head, n_seq - n, n_dim]
-         v_cache = v_cache[:, :, :-n, :]
-
-         assert isinstance(k_cache, torch.Tensor)
-         assert isinstance(v_cache, torch.Tensor)
-         ret.append((k_cache, v_cache))
-
-     return tuple(ret)
-
-
- def rollback_logits(logits: torch.Tensor, n: int):
-     return logits[:, :-n, :]  # [1, n_seq, n_vocab]
-
-
- def is_partial_stop(output: str, stop_str: str):
-     """Check whether the output contains a partial stop str."""
-     for i in range(0, min(len(output), len(stop_str))):
-         if stop_str.startswith(output[-i:]):
-             return True
-     return False
-
-
- def draft(
-     input_ids: List[int],
-     kv_cache: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]],
-     logits: Optional[torch.FloatTensor],
-     draft_model: "PreTrainedModel",
-     gamma: int,
-     logits_processor: LogitsProcessorList,
-     temperature: float,
-     top_p: float,
- ):
-     """
-     Parameters
-     ----------
-     input_ids : List[int]
-         On the prefill stage, `input_ids` are the prompt tokens.
-
-         On the decode stage. It includes the prompt tokens, the token generated by the original model
-         at the end of each full iteration, or the token generated by the draft model draft
-         iteration.
-
-     Returns
-     -------
-     int
-         The number of generated draft tokens.
-     List[int]
-         Outputs, including the draft tokens.
-     Tuple[Tuple[torch.Tensor, torch.Tensor], ...]
-         KV cache.
-     torch.FloatTensor
-         Logits.
-     """
-     draft_output_ids = input_ids.copy()
-
-     if kv_cache is not None:
-         input_ids = draft_output_ids[-2:]
-
-     num_draft_tokens = 0
-     while num_draft_tokens < gamma:
-         if kv_cache is None:
-             # prefill.
-             draft_model_out = draft_model(
-                 torch.as_tensor([input_ids], device=draft_model.device),
-                 use_cache=True,
-             )
-             logits = normalize_logits(
-                 logits_processor, input_ids, draft_model_out.logits
-             )
-         else:
-             draft_model_out = draft_model(
-                 torch.as_tensor([input_ids], device=draft_model.device),
-                 use_cache=True,
-                 past_key_values=kv_cache,
-             )
-             normalized = normalize_logits(
-                 logits_processor, draft_output_ids, draft_model_out.logits
-             )
-             assert logits is not None
-             logits = torch.cat((logits, normalized), dim=1)
-         kv_cache = draft_model_out.past_key_values
-         draft_token = sample(
-             logits[0, -1, :],
-             temperature,
-             top_p,
-         )
-         draft_output_ids.append(draft_token)
-         input_ids = [draft_token]
-         num_draft_tokens += 1
-
-     assert kv_cache is not None
-     return num_draft_tokens, draft_output_ids, kv_cache, logits
-
-
- @torch.inference_mode()
- def speculative_generate_stream(
-     model_uid: str,
-     draft_model: "PreTrainedModel",
-     model: "PreTrainedModel",
-     tokenizer: "PreTrainedTokenizer",
-     prompt: str,
-     generate_config: Dict[str, Any],
- ) -> Iterator[Tuple[CompletionChunk, CompletionUsage]]:
-     logger.debug(
-         f"Enter speculative_generate_stream, prompt: {prompt}, generate_config: {generate_config}"
-     )
-
-     # TODO: currently, repetition penalty leads to garbled outputs.
-     if float(generate_config.get("repetition_penalty", 1.0)) != 1.0:
-         raise ValueError(
-             "repetition penalty is not supported by speculative decoding yet"
-         )
-
-     gamma = generate_config.get("gamma", 4)
-     stream = generate_config.get("stream", False)
-     temperature = float(generate_config.get("temperature", 1.0))
-     top_p = float(generate_config.get("top_p", 1.0))
-     top_k = int(generate_config.get("top_k", -1))  # -1 means disable
-     max_new_tokens = int(generate_config.get("max_tokens", 256))
-     echo = bool(generate_config.get("echo", False))
-     stop_str = generate_config.get("stop", None)
-     stop_token_ids = generate_config.get("stop_token_ids", None) or []
-     stop_token_ids.append(tokenizer.eos_token_id)
-
-     logits_processor = prepare_logits_processor(temperature, top_p, top_k)
-     request_id = str(uuid.uuid1())
-
-     if "qwen" in str(type(model)).lower():
-         # TODO: hacky.
-         input_ids = tokenizer(prompt, allowed_special="all").input_ids
-     else:
-         input_ids = tokenizer(prompt).input_ids
-
-     num_prompt_tokens = len(input_ids)
-     output_ids = list(input_ids)
-
-     # internal states.
-     draft_kv_cache = None
-     draft_logits = None
-     kv_cache = None
-     logits = None
-     next_token = (
-         None  # the token generated by the original model at each full iteration.
-     )
-     last_output_length = 0
-     finish_reason = "stop"
-
-     # performance stats.
-     total_seconds_on_drafting = 0.0
-     total_seconds_on_eval = 0.0
-     total_seconds_on_accepting = 0.0
-     total_num_draft_tokens = 0
-     total_num_accepted_tokens = 0
-
-     while len(output_ids) < max_new_tokens + num_prompt_tokens:
-         # allow the draft model to generate more than max_tokens since some of the generated
-         # tokens could be rejected.
-         start = time.time()
-         num_draft_tokens, output_ids, draft_kv_cache, draft_logits = draft(
-             input_ids=output_ids,
-             kv_cache=draft_kv_cache,
-             logits=draft_logits,
-             draft_model=draft_model,
-             gamma=gamma,
-             logits_processor=logits_processor,
-             temperature=temperature
-             * 0.5,  # make the draft model outputs less random for better quality.
-             top_p=top_p,
-         )
-         total_seconds_on_drafting += time.time() - start
-         total_num_draft_tokens += num_draft_tokens
-
-         # eval stage.
-         start = time.time()
-         if kv_cache is None:
-             # prefill.
-             out = model(
-                 torch.as_tensor([output_ids], device=model.device), use_cache=True
-             )
-             logits = normalize_logits(logits_processor, output_ids, out.logits)
-         else:
-             out = model(
-                 torch.as_tensor(
-                     [[next_token] + output_ids[-num_draft_tokens:]], device=model.device
-                 ),
-                 use_cache=True,
-                 past_key_values=kv_cache,
-             )
-             normalized = normalize_logits(logits_processor, output_ids, out.logits)
-             logits = torch.cat((logits, normalized), dim=1)
-         kv_cache = out.past_key_values
-         total_seconds_on_eval += time.time() - start
-
-         # accepting stage.
-         start = time.time()
-         assert draft_logits is not None
-         assert draft_kv_cache is not None
-         accepted = 0
-         stopped = False
-         for draft_token_idx in range(-num_draft_tokens, 0):
-             r = torch.rand(1, device=logits.device)
-             draft_token = output_ids[draft_token_idx]
-             token_logits = logits[:, draft_token_idx - 1, :]  # [1, n_vocab,]
-             draft_token_logits = draft_logits[:, draft_token_idx, :].to(
-                 logits.device
-             )  # [1, n_vocab,]
-             if token_logits[0, draft_token] / draft_token_logits[0, draft_token] > r:
-                 accepted += 1
-                 total_num_accepted_tokens += 1
-                 if draft_token in stop_token_ids:
-                     stopped = True
-             else:
-                 if logger.getEffectiveLevel() <= logging.DEBUG:
-                     logger.debug(
-                         f"Accepted ({accepted}/{num_draft_tokens}): '{tokenizer.decode(output_ids[-num_draft_tokens: draft_token_idx])}'"
-                     )
-                     logger.debug(
-                         f"Rejected: '{tokenizer.decode(output_ids[draft_token_idx:])}'"
-                     )
-                 # rollback.
-                 output_ids = output_ids[:draft_token_idx]
-                 draft_kv_cache = rollback_kv_cache(
-                     draft_kv_cache, num_draft_tokens - accepted
-                 )
-                 kv_cache = rollback_kv_cache(kv_cache, num_draft_tokens - accepted)
-                 draft_logits = rollback_logits(
-                     draft_logits, num_draft_tokens - accepted
-                 )
-                 logits = rollback_logits(logits, num_draft_tokens - accepted)
-
-                 # sample the next token according to the modified distribution of shape [1, n_vocab]
-                 modified_dist = token_logits - draft_token_logits
-                 modified_dist = torch.where(
-                     modified_dist > 0, modified_dist, torch.zeros_like(modified_dist)
-                 )
-                 normalized = normalize_logits(
-                     logits_processor,
-                     output_ids,
-                     modified_dist.unsqueeze(1),  # [1, 1, n_vocab]
-                 )
-                 next_token = sample(
-                     normalized[0, -1, :],
-                     0,  # must be 0, since the dist is quiet unified, higher temperature results in garbled text
-                     top_p,
-                 )
-                 output_ids.append(next_token)
-                 if logger.getEffectiveLevel() <= logging.DEBUG:
-                     logger.debug(f"Generated: '{tokenizer.decode([next_token])}'")
-                 if next_token in stop_token_ids:
-                     stopped = True
-                 break
-
-         if accepted == num_draft_tokens:
-             if logger.getEffectiveLevel() <= logging.DEBUG:
-                 logger.debug(
-                     f"Accepted ({accepted}/{num_draft_tokens}): '{tokenizer.decode(output_ids[-num_draft_tokens:])}'"
-                 )
-             next_token = sample(
-                 logits[0, -1, :],
-                 temperature,
-                 top_p,
-             )
-             output_ids.append(next_token)
-             if logger.getEffectiveLevel() <= logging.DEBUG:
-                 logger.debug(f"Generated: '{tokenizer.decode([next_token])}'")
-             if next_token in stop_token_ids:
-                 stopped = True
-
-         total_seconds_on_accepting += time.time() - start
-
-         if (
-             accepted > 0  # more than 2 tokens has been generated, flush.
-             or len(output_ids) >= max_new_tokens
-             or stopped
-         ):
-             output = tokenizer.decode(
-                 output_ids if echo else output_ids[num_prompt_tokens:],
-                 spaces_between_special_tokens=False,
-                 clean_up_tokenization_spaces=True,
-             )
-             rfind_start = len(prompt) if echo else 0
-
-             partially_stopped = False
-             if stop_str:
-                 if isinstance(stop_str, str):
-                     pos = output.rfind(stop_str, rfind_start)
-                     if pos != -1:
-                         output = output[:pos]
-                         stopped = True
-                     else:
-                         partially_stopped = is_partial_stop(output, stop_str)
-                 elif isinstance(stop_str, Iterable):
-                     for each_stop in stop_str:
-                         pos = output.rfind(each_stop, rfind_start)
-                         if pos != -1:
-                             output = output[:pos]
-                             stopped = True
-                             break
-                         else:
-                             partially_stopped = is_partial_stop(output, each_stop)
-                             if partially_stopped:
-                                 break
-                 else:
-                     raise ValueError(f"Invalid stop field type {type(stop_str)}")
-
-             if stream:
-                 # return the delta.
-                 output_length = len(output)
-                 output = output[last_output_length:]
-                 last_output_length = output_length
-
-             # prevent yielding partial stop sequence.
-             if not partially_stopped:
-                 completion_choice = CompletionChoice(
-                     text=output, index=0, logprobs=None, finish_reason=None
-                 )
-                 completion_chunk = CompletionChunk(
-                     id=request_id,
-                     object="text_completion",
-                     created=int(time.time()),
-                     model=model_uid,
-                     choices=[completion_choice],
-                 )
-                 completion_usage = CompletionUsage(
-                     prompt_tokens=num_prompt_tokens,
-                     completion_tokens=len(output_ids) - num_prompt_tokens,
-                     total_tokens=len(output_ids),
-                 )
-
-                 yield completion_chunk, completion_usage
-         if stopped:
-             break
-     else:
-         finish_reason = "length"
-
-     logger.info(
-         f"In total, {total_num_accepted_tokens}/{total_num_draft_tokens} draft tokens are "
-         f"accepted, acceptance rate: {total_num_accepted_tokens / total_num_draft_tokens:.2f}"
-     )
-     total_seconds = (
-         total_seconds_on_drafting + total_seconds_on_eval + total_seconds_on_accepting
-     )
-     logger.info(
-         f"In total, {total_seconds_on_drafting:.2f}s, {total_seconds_on_eval:.2f}s and "
-         f"{total_seconds_on_accepting:.2f}s are spent on drafting, eval, and accepting "
-         f"respectively. Average generation speed: {(len(output_ids) - num_prompt_tokens) / total_seconds:.2f} tokens/s."
-     )
-
-     if stream:
-         completion_choice = CompletionChoice(
-             text="", index=0, logprobs=None, finish_reason=finish_reason
-         )
-     else:
-         completion_choice = CompletionChoice(
-             text=output, index=0, logprobs=None, finish_reason=finish_reason
-         )
-
-     completion_chunk = CompletionChunk(
-         id=request_id,
-         object="text_completion",
-         created=int(time.time()),
-         model=model_uid,
-         choices=[completion_choice],
-     )
-     completion_usage = CompletionUsage(
-         prompt_tokens=num_prompt_tokens,
-         completion_tokens=len(output_ids) - num_prompt_tokens,
-         total_tokens=len(output_ids),
-     )
-
-     yield completion_chunk, completion_usage
-
-     # clean up.
-     del kv_cache
-     del draft_kv_cache
-     gc.collect()
-     empty_cache()