sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +6 -0
- sglang/bench_latency.py +7 -3
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +9 -0
- sglang/launch_server.py +8 -1
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -1
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/fused_moe/layer.py +2 -2
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +60 -23
- sglang/srt/layers/radix_attention.py +3 -4
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +52 -167
- sglang/srt/managers/tokenizer_manager.py +192 -83
- sglang/srt/managers/tp_worker.py +130 -43
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +49 -11
- sglang/srt/model_executor/forward_batch_info.py +59 -27
- sglang/srt/model_executor/model_runner.py +210 -61
- sglang/srt/models/chatglm.py +4 -12
- sglang/srt/models/commandr.py +5 -1
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +5 -1
- sglang/srt/models/deepseek_v2.py +5 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +15 -7
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +16 -2
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +5 -1
- sglang/srt/models/mixtral.py +5 -1
- sglang/srt/models/mixtral_quant.py +5 -1
- sglang/srt/models/qwen.py +5 -2
- sglang/srt/models/qwen2.py +13 -3
- sglang/srt/models/qwen2_moe.py +5 -14
- sglang/srt/models/stablelm.py +5 -1
- sglang/srt/openai_api/adapter.py +117 -37
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
- sglang/srt/server.py +84 -56
- sglang/srt/server_args.py +43 -15
- sglang/srt/utils.py +26 -16
- sglang/test/runners.py +23 -31
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_utils.py +36 -53
- sglang/version.py +1 -1
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang-0.2.13.dist-info/RECORD +0 -112
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -333,6 +334,7 @@ class QuantMixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -343,9 +345,11 @@ class QuantMixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
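This release moves sampling into the model forward pass: each causal LM touched above now imports Sampler from sglang.srt.layers.sampler, instantiates it next to LogitsProcessor, and returns a (sample_output, logits_output) pair instead of raw logits. The sketch below only illustrates that calling contract under stated assumptions: ToySampler and ToyCausalLM are made-up stand-ins, and the real Sampler consumes input_metadata.sampling_info rather than a bare temperature.

import torch


class ToySampler(torch.nn.Module):
    """Stand-in for sglang's Sampler: turns logits into sampled token ids."""

    def forward(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        # Temperature-scale the logits, then draw one token id per sequence.
        probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)


class ToyCausalLM(torch.nn.Module):
    """Stand-in model showing the (sample_output, logits_output) return pair."""

    def __init__(self, vocab_size: int = 32, hidden: int = 16):
        super().__init__()
        self.backbone = torch.nn.Embedding(vocab_size, hidden)
        self.lm_head = torch.nn.Linear(hidden, vocab_size, bias=False)
        self.sampler = ToySampler()

    def forward(self, input_ids: torch.Tensor):
        hidden_states = self.backbone(input_ids)            # [batch, seq, hidden]
        logits_output = self.lm_head(hidden_states[:, -1])  # last-position logits
        sample_output = self.sampler(logits_output)         # sampled token ids
        return sample_output, logits_output                 # mirrors the 0.2.14 contract


next_tokens, logits = ToyCausalLM()(torch.randint(0, 32, (2, 5)))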
sglang/srt/models/qwen.py
CHANGED
@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -251,6 +252,7 @@ class QWenLMHeadModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -260,10 +262,11 @@ class QWenLMHeadModel(nn.Module):
         input_metadata: InputMetadata,
     ):
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/qwen2.py
CHANGED
@@ -38,7 +38,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

 Qwen2Config = None
@@ -275,6 +277,8 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     @torch.no_grad()
     def forward(
@@ -283,11 +287,17 @@ class Qwen2ForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
+        if not get_embedding:
+            logits_output = self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, input_metadata
+            )
+            sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+            return sample_output, logits_output
+        else:
+            return self.pooler(hidden_states, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
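qwen2.py is the one model in this diff that also gains an embedding path: when get_embedding is true, forward skips logits and sampling and returns self.pooler(hidden_states, input_metadata), where the pooler is built with PoolingType.LAST and normalize=True. The snippet below is a rough sketch of what last-token pooling with L2 normalization computes; the function name and the seq_lens argument are illustrative assumptions, not sglang's Pooler API.

import torch


def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: [batch, seq, hidden]; seq_lens: actual length of each sequence.
    idx = (seq_lens - 1).clamp(min=0)
    last = hidden_states[torch.arange(hidden_states.size(0)), idx]  # [batch, hidden]
    return torch.nn.functional.normalize(last, p=2, dim=-1)         # unit-norm embeddings


embeddings = last_token_pool(torch.randn(2, 7, 16), torch.tensor([7, 4]))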
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -35,10 +35,8 @@ from vllm.model_executor.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -49,6 +47,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -366,6 +365,7 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -376,20 +376,11 @@ class Qwen2MoeForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
-
-        self,
-        input_ids: torch.Tensor,
-        hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        logits = self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
-        return logits
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/stablelm.py
CHANGED
@@ -40,6 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -249,6 +250,7 @@ class StableLmForCausalLM(nn.Module):
         self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -259,9 +261,11 @@ class StableLmForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/openai_api/adapter.py
CHANGED
@@ -17,6 +17,7 @@ limitations under the License.

 import asyncio
 import json
+import logging
 import os
 import time
 import uuid
@@ -64,6 +65,8 @@ from sglang.srt.openai_api.protocol import (
     UsageInfo,
 )

+logger = logging.getLogger(__name__)
+
 chat_template_name = None


@@ -120,7 +123,7 @@ def create_streaming_error_response(
 def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
     global chat_template_name

-
+    logger.info(f"Use chat template: {chat_template_arg}")
     if not chat_template_exists(chat_template_arg):
         if not os.path.exists(chat_template_arg):
             raise RuntimeError(
@@ -276,6 +279,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             request_data = json.loads(line)
             file_request_list.append(request_data)
             body = request_data["body"]
+
+            # Although streaming is supported for standalone completions, it is not supported in
+            # batch mode (multiple completions in single request).
+            if body.get("stream", False):
+                raise ValueError("Streaming requests are not supported in batch mode")
+
             if end_point == "/v1/chat/completions":
                 all_requests.append(ChatCompletionRequest(**body))
             elif end_point == "/v1/completions":
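The new check in process_batch rejects any batch input line whose request body asks for streaming. Below is a hedged example of the kind of JSONL line that would now fail; the custom_id/method/url fields follow the OpenAI batch file format and are assumptions here, and only the body["stream"] condition mirrors the diff.

import json

# Hypothetical batch-file line; field layout other than "body" is assumed.
line = json.dumps({
    "custom_id": "req-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "my-model",
        "stream": True,  # this is what the new check rejects
        "messages": [{"role": "user", "content": "hi"}],
    },
})

body = json.loads(line)["body"]
if body.get("stream", False):  # same condition as the added code
    print("rejected: Streaming requests are not supported in batch mode")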
@@ -346,7 +355,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             }

     except Exception as e:
-
+        logger.error("error in SGLang:", e)
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
@@ -383,20 +392,33 @@ async def v1_retrieve_file_content(file_id: str):
     return StreamingResponse(iter_file(), media_type="application/octet-stream")


-def v1_generate_request(all_requests):
+def v1_generate_request(all_requests: List[CompletionRequest]):
     prompts = []
     sampling_params_list = []
     return_logprobs = []
+    logprob_start_lens = []
     top_logprobs_nums = []
-    first_prompt_type = type(all_requests[0].prompt)

+    # NOTE: with openai API, the prompt's logprobs are always not computed
+    first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
-        prompt = request.prompt
         assert (
-            type(prompt) == first_prompt_type
+            type(request.prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
-
+        if len(all_requests) > 1 and request.n > 1:
+            raise ValueError(
+                "Parallel sampling is not supported for completions from files"
+            )
+        if request.echo and request.logprobs:
+            logger.warning(
+                "Echo is not compatible with logprobs. "
+                "To compute logprobs of input prompt, please use SGLang /request API."
+            )
+
+    for request in all_requests:
+        prompts.append(request.prompt)
         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        logprob_start_lens.append(-1)
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
@@ -416,14 +438,11 @@ def v1_generate_request(all_requests):
                 "ignore_eos": request.ignore_eos,
             }
         )
-        if len(all_requests) > 1 and request.n > 1:
-            raise ValueError(
-                "Parallel sampling is not supported for completions from files"
-            )

     if len(all_requests) == 1:
         prompt = prompts[0]
         sampling_params_list = sampling_params_list[0]
+        logprob_start_lens = logprob_start_lens[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
@@ -441,6 +460,7 @@ def v1_generate_request(all_requests):
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
         top_logprobs_num=top_logprobs_nums,
+        logprob_start_len=logprob_start_lens,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
@@ -580,27 +600,45 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:

         async def generate_stream_resp():
-
-
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
+                    index = content["index"]
+
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
                     text = content["text"]
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]

                     if not stream_buffer:  # The first chunk
                         if request.echo:
                             if isinstance(request.prompt, str):
                                 # for the case of single str prompts
                                 prompts = request.prompt
-                            elif isinstance(request.prompt, list)
-                                request.prompt[0],
-
-
-
-
+                            elif isinstance(request.prompt, list):
+                                if isinstance(request.prompt[0], str):
+                                    # for the case of multiple str prompts
+                                    prompts = request.prompt[index // request.n]
+                                elif isinstance(request.prompt[0], int):
+                                    # for the case of single token ids prompt
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt, skip_special_tokens=True
+                                    )
+                                elif isinstance(request.prompt[0], list) and isinstance(
+                                    request.prompt[0][0], int
+                                ):
+                                    # for the case of multiple token ids prompts
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt[index // request.n],
+                                        skip_special_tokens=True,
+                                    )

                             # Prepend prompt in response text.
                             text = prompts + text
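With parallel sampling (n > 1), chunks for different completions arrive interleaved on a single generator, so the streaming handler now keys its running text buffers and token counts by content["index"] instead of a single set of local variables. A minimal sketch of that demultiplexing pattern follows, using fabricated chunk dicts in place of tokenizer_manager output.

stream_buffers, prompt_tokens, completion_tokens = {}, {}, {}

# Fabricated chunks standing in for tokenizer_manager.generate_request output.
chunks = [
    {"index": 0, "text": "Hello", "meta_info": {"prompt_tokens": 5, "completion_tokens": 1}},
    {"index": 1, "text": "Hi", "meta_info": {"prompt_tokens": 5, "completion_tokens": 1}},
    {"index": 0, "text": "Hello world", "meta_info": {"prompt_tokens": 5, "completion_tokens": 2}},
]

for content in chunks:
    index = content["index"]
    stream_buffer = stream_buffers.get(index, "")    # per-completion running text
    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
    completion_tokens[index] = content["meta_info"]["completion_tokens"]
    delta = content["text"][len(stream_buffer):]     # only the newly generated suffix
    stream_buffers[index] = stream_buffer + delta

print(stream_buffers)   # {0: 'Hello world', 1: 'Hi'}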
@@ -637,7 +675,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = CompletionResponseStreamChoice(
-                        index=
+                        index=index,
                         text=delta,
                         logprobs=logprobs,
                         finish_reason=format_finish_reason(
@@ -650,12 +688,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                     yield f"data: {chunk.model_dump_json()}\n\n"
                 if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                     usage = UsageInfo(
-                        prompt_tokens=
-                        completion_tokens=
-                        total_tokens=
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                     )

                     final_usage_chunk = CompletionStreamResponse(
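The usage totals above count prompt tokens once per prompt rather than once per sampled completion: with n parallel samples, indices i through i+n-1 share a prompt, so only entries with i % n == 0 contribute. A small worked example of that arithmetic, with made-up token counts, is shown here.

n = 2                                            # request.n: two samples per prompt
prompt_tokens = {0: 5, 1: 5, 2: 7, 3: 7}         # indices 0-1 share prompt A, 2-3 share prompt B
completion_tokens = {0: 3, 1: 4, 2: 2, 3: 6}

total_prompt_tokens = sum(t for i, t in prompt_tokens.items() if i % n == 0)  # 5 + 7 = 12
total_completion_tokens = sum(completion_tokens.values())                     # 3 + 4 + 2 + 6 = 15
total_tokens = total_prompt_tokens + total_completion_tokens                  # 27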
@@ -694,12 +744,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     return response


-def v1_chat_generate_request(all_requests, tokenizer_manager):
+def v1_chat_generate_request(
+    all_requests: List[ChatCompletionRequest], tokenizer_manager
+):
     input_ids = []
     sampling_params_list = []
     image_data_list = []
     return_logprobs = []
+    logprob_start_lens = []
     top_logprobs_nums = []
+
+    # NOTE: with openai API, the prompt's logprobs are always not computed
+
     for request in all_requests:
         # Prep the data needed for the underlying GenerateReqInput:
         #  - prompt: The full prompt string.
@@ -732,6 +788,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             image_data = None
         input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
+        logprob_start_lens.append(-1)
         top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
             {
@@ -758,17 +815,20 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
+        logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
     else:
         if isinstance(input_ids[0], str):
             prompt_kwargs = {"text": input_ids}
         else:
             prompt_kwargs = {"input_ids": input_ids}
+
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
+        logprob_start_len=logprob_start_lens,
         top_logprobs_num=top_logprobs_nums,
         stream=all_requests[0].stream,
         return_text_in_logprobs=True,
@@ -892,16 +952,23 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:

         async def generate_stream_resp():
-
-
-
-
+            is_firsts = {}
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
-
-
+                    index = content["index"]
+
+                    is_first = is_firsts.get(index, True)
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                     if request.logprobs:
                         logprobs = to_openai_style_logprobs(
                             output_token_logprobs=content["meta_info"][
@@ -951,7 +1018,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         # First chunk with role
                         is_first = False
                         choice_data = ChatCompletionResponseStreamChoice(
-                            index=
+                            index=index,
                             delta=DeltaMessage(role="assistant"),
                             finish_reason=format_finish_reason(
                                 content["meta_info"]["finish_reason"]
@@ -969,7 +1036,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = ChatCompletionResponseStreamChoice(
-                        index=
+                        index=index,
                         delta=DeltaMessage(content=delta),
                         finish_reason=format_finish_reason(
                             content["meta_info"]["finish_reason"]
@@ -981,12 +1048,25 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+
+                    is_firsts[index] = is_first
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                     yield f"data: {chunk.model_dump_json()}\n\n"
                 if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                     usage = UsageInfo(
-                        prompt_tokens=
-                        completion_tokens=
-                        total_tokens=
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                     )

                     final_usage_chunk = ChatCompletionStreamResponse(