sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +13 -1
- sglang/bench_latency.py +10 -5
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +5 -2
- sglang/lang/ir.py +22 -4
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -2
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +64 -27
- sglang/srt/layers/radix_attention.py +41 -18
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +59 -179
- sglang/srt/managers/tokenizer_manager.py +193 -84
- sglang/srt/managers/tp_worker.py +131 -50
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +97 -28
- sglang/srt/model_executor/forward_batch_info.py +188 -82
- sglang/srt/model_executor/model_runner.py +269 -87
- sglang/srt/models/chatglm.py +6 -14
- sglang/srt/models/commandr.py +6 -2
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +7 -3
- sglang/srt/models/deepseek_v2.py +12 -7
- sglang/srt/models/gemma.py +6 -2
- sglang/srt/models/gemma2.py +22 -8
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +66 -398
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +7 -3
- sglang/srt/models/mixtral.py +61 -255
- sglang/srt/models/mixtral_quant.py +6 -5
- sglang/srt/models/qwen.py +7 -4
- sglang/srt/models/qwen2.py +15 -5
- sglang/srt/models/qwen2_moe.py +7 -16
- sglang/srt/models/stablelm.py +6 -2
- sglang/srt/openai_api/adapter.py +149 -58
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
- sglang/srt/server.py +107 -71
- sglang/srt/server_args.py +49 -15
- sglang/srt/utils.py +27 -18
- sglang/test/runners.py +38 -38
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_programs.py +32 -5
- sglang/test/test_utils.py +37 -50
- sglang/version.py +1 -1
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.12.dist-info/RECORD +0 -112
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -28,27 +28,26 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -366,6 +365,7 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -376,20 +376,11 @@ class Qwen2MoeForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
-
-    def compute_logits(
-        self,
-        input_ids: torch.Tensor,
-        hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        logits = self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
-        return logits
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
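Both model diffs shown here (qwen2_moe above, stablelm below) move sampling into the model itself: `forward` now runs the in-tree `Sampler` on the logits output and returns `(sample_output, logits_output)` instead of raw logits. The sketch below only illustrates that calling convention; `LogitsOutput`, `GreedySampler`, and the field names are stand-ins, not the actual sglang classes.

```python
from dataclasses import dataclass

import torch


@dataclass
class LogitsOutput:
    # Illustrative stand-in for the object produced by LogitsProcessor.
    next_token_logits: torch.Tensor  # [batch, vocab]


class GreedySampler:
    # Illustrative stand-in for sglang.srt.layers.sampler.Sampler.
    def __call__(self, logits_output: LogitsOutput, sampling_info=None) -> torch.Tensor:
        return torch.argmax(logits_output.next_token_logits, dim=-1)


def forward_like_0_2_14(hidden_states: torch.Tensor, lm_head_weight: torch.Tensor):
    # Mirrors the new contract: compute logits, sample, return both objects.
    logits_output = LogitsOutput(next_token_logits=hidden_states @ lm_head_weight.t())
    sample_output = GreedySampler()(logits_output)
    return sample_output, logits_output
```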
sglang/srt/models/stablelm.py
CHANGED
@@ -24,7 +24,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -38,8 +37,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -249,6 +250,7 @@ class StableLmForCausalLM(nn.Module):
         self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -259,9 +261,11 @@ class StableLmForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
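The same two files also swap their `SiluAndMul` and `RMSNorm` imports from vllm to the new `sglang.srt.layers` modules. As a reminder of what those layers compute, here is a plain PyTorch reference sketch (not the fused sglang or vllm kernels):

```python
import torch


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Reference for SiluAndMul: silu(gate) * up, where x = [gate | up] on the last dim.
    gate, up = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Reference for RMSNorm: scale x by the reciprocal of its root-mean-square,
    # then by a learned per-channel weight.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight
```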
sglang/srt/openai_api/adapter.py
CHANGED
@@ -17,6 +17,7 @@ limitations under the License.
 
 import asyncio
 import json
+import logging
 import os
 import time
 import uuid
@@ -64,6 +65,8 @@ from sglang.srt.openai_api.protocol import (
     UsageInfo,
 )
 
+logger = logging.getLogger(__name__)
+
 chat_template_name = None
 
 
@@ -117,37 +120,48 @@ def create_streaming_error_response(
     return json_str
 
 
-def load_chat_template_for_openai_api(chat_template_arg):
+def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
     global chat_template_name
 
-    print(f"Use chat template: {chat_template_arg}")
+    logger.info(f"Use chat template: {chat_template_arg}")
     if not chat_template_exists(chat_template_arg):
         if not os.path.exists(chat_template_arg):
             raise RuntimeError(
                 f"Chat template {chat_template_arg} is not a built-in template name "
                 "or a valid chat template file path."
             )
-        with open(chat_template_arg, "r") as filep:
-            template = json.load(filep)
-            try:
-                sep_style = SeparatorStyle[template["sep_style"]]
-            except KeyError:
-                raise ValueError(
-                    f"Unknown separator style: {template['sep_style']}"
-                ) from None
-            register_conv_template(
-                Conversation(
-                    name=template["name"],
-                    system_template=template["system"] + "\n{system_message}",
-                    system_message=template.get("system_message", ""),
-                    roles=(template["user"], template["assistant"]),
-                    sep_style=sep_style,
-                    sep=template.get("sep", "\n"),
-                    stop_str=template["stop_str"],
-                ),
-                override=True,
+        if chat_template_arg.endswith(".jinja"):
+            with open(chat_template_arg, "r") as f:
+                chat_template = "".join(f.readlines()).strip("\n")
+            tokenizer_manager.tokenizer.chat_template = chat_template.replace(
+                "\\n", "\n"
             )
-        chat_template_name = template["name"]
+            chat_template_name = None
+        else:
+            assert chat_template_arg.endswith(
+                ".json"
+            ), "unrecognized format of chat template file"
+            with open(chat_template_arg, "r") as filep:
+                template = json.load(filep)
+                try:
+                    sep_style = SeparatorStyle[template["sep_style"]]
+                except KeyError:
+                    raise ValueError(
+                        f"Unknown separator style: {template['sep_style']}"
+                    ) from None
+                register_conv_template(
+                    Conversation(
+                        name=template["name"],
+                        system_template=template["system"] + "\n{system_message}",
+                        system_message=template.get("system_message", ""),
+                        roles=(template["user"], template["assistant"]),
+                        sep_style=sep_style,
+                        sep=template.get("sep", "\n"),
+                        stop_str=template["stop_str"],
+                    ),
+                    override=True,
+                )
+                chat_template_name = template["name"]
     else:
         chat_template_name = chat_template_arg
 
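The reworked `load_chat_template_for_openai_api` now accepts either a raw Jinja file, which is attached directly to the Hugging Face tokenizer, or the older JSON conversation-template format, which is registered as a named sglang template. A condensed sketch of that dispatch follows; `tokenizer` and `load_chat_template` are illustrative stand-ins for the real `tokenizer_manager.tokenizer` plumbing.

```python
import json


def load_chat_template(tokenizer, path: str):
    # Hedged sketch of the two file formats handled in 0.2.14.
    if path.endswith(".jinja"):
        # Raw Jinja: hand the template string straight to the HF tokenizer.
        with open(path) as f:
            tokenizer.chat_template = f.read().strip("\n").replace("\\n", "\n")
        return None  # no named sglang conversation template is registered
    if path.endswith(".json"):
        # JSON: describes a Conversation (roles, separators, stop string) that the
        # real code registers via register_conv_template.
        with open(path) as f:
            template = json.load(f)
        return template["name"]
    raise ValueError("unrecognized format of chat template file")
```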
@@ -265,6 +279,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             request_data = json.loads(line)
             file_request_list.append(request_data)
             body = request_data["body"]
+
+            # Although streaming is supported for standalone completions, it is not supported in
+            # batch mode (multiple completions in single request).
+            if body.get("stream", False):
+                raise ValueError("Streaming requests are not supported in batch mode")
+
             if end_point == "/v1/chat/completions":
                 all_requests.append(ChatCompletionRequest(**body))
             elif end_point == "/v1/completions":
@@ -335,7 +355,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         }
 
     except Exception as e:
-        print("error in SGLang:", e)
+        logger.error("error in SGLang:", e)
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
@@ -372,20 +392,33 @@ async def v1_retrieve_file_content(file_id: str):
     return StreamingResponse(iter_file(), media_type="application/octet-stream")
 
 
-def v1_generate_request(all_requests):
+def v1_generate_request(all_requests: List[CompletionRequest]):
     prompts = []
     sampling_params_list = []
     return_logprobs = []
+    logprob_start_lens = []
     top_logprobs_nums = []
-    first_prompt_type = type(all_requests[0].prompt)
 
+    # NOTE: with openai API, the prompt's logprobs are always not computed
+    first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
-        prompt = request.prompt
         assert (
-            type(prompt) == first_prompt_type
+            type(request.prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
-        prompts.append(prompt)
+        if len(all_requests) > 1 and request.n > 1:
+            raise ValueError(
+                "Parallel sampling is not supported for completions from files"
+            )
+        if request.echo and request.logprobs:
+            logger.warning(
+                "Echo is not compatible with logprobs. "
+                "To compute logprobs of input prompt, please use SGLang /request API."
+            )
+
+    for request in all_requests:
+        prompts.append(request.prompt)
         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        logprob_start_lens.append(-1)
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
@@ -405,14 +438,11 @@ def v1_generate_request(all_requests):
                 "ignore_eos": request.ignore_eos,
             }
         )
-        if len(all_requests) > 1 and request.n > 1:
-            raise ValueError(
-                "Parallel sampling is not supported for completions from files"
-            )
 
     if len(all_requests) == 1:
         prompt = prompts[0]
         sampling_params_list = sampling_params_list[0]
+        logprob_start_lens = logprob_start_lens[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
@@ -430,6 +460,7 @@ def v1_generate_request(all_requests):
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
         top_logprobs_num=top_logprobs_nums,
+        logprob_start_len=logprob_start_lens,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
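`v1_generate_request` now builds one more parallel list, `logprob_start_lens`, pinned to `-1` because prompt logprobs are not computed for OpenAI-style requests, and it still collapses every per-request list to a scalar when only one request is present. A hedged sketch of that batching shape, with illustrative names and plain dicts in place of `CompletionRequest` objects:

```python
def build_generate_args(requests):
    # Illustrative names; plain dicts stand in for CompletionRequest objects.
    prompts, return_logprobs, logprob_start_lens, top_logprobs_nums = [], [], [], []
    for r in requests:
        prompts.append(r["prompt"])
        return_logprobs.append(r.get("logprobs") is not None and r["logprobs"] > 0)
        # -1 mirrors the adapter's convention: prompt logprobs are never computed
        # for OpenAI-style requests.
        logprob_start_lens.append(-1)
        top_logprobs_nums.append(r["logprobs"] if r.get("logprobs") is not None else 0)

    if len(requests) == 1:
        # A single request is sent with scalar fields instead of one-element lists.
        return prompts[0], return_logprobs[0], logprob_start_lens[0], top_logprobs_nums[0]
    return prompts, return_logprobs, logprob_start_lens, top_logprobs_nums


print(build_generate_args([{"prompt": "Hello", "logprobs": 2}]))
```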
@@ -569,27 +600,45 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:
 
         async def generate_stream_resp():
-            stream_buffer = ""
-            n_prev_token = 0
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
+                    index = content["index"]
+
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
                     text = content["text"]
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
 
                     if not stream_buffer:  # The first chunk
                         if request.echo:
                             if isinstance(request.prompt, str):
                                 # for the case of single str prompts
                                 prompts = request.prompt
-                            elif isinstance(request.prompt, list) and isinstance(
-                                request.prompt[0], int
-                            ):
-                                prompts = tokenizer_manager.tokenizer.decode(
-                                    request.prompt, skip_special_tokens=True
-                                )
+                            elif isinstance(request.prompt, list):
+                                if isinstance(request.prompt[0], str):
+                                    # for the case of multiple str prompts
+                                    prompts = request.prompt[index // request.n]
+                                elif isinstance(request.prompt[0], int):
+                                    # for the case of single token ids prompt
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt, skip_special_tokens=True
+                                    )
+                                elif isinstance(request.prompt[0], list) and isinstance(
+                                    request.prompt[0][0], int
+                                ):
+                                    # for the case of multiple token ids prompts
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt[index // request.n],
+                                        skip_special_tokens=True,
+                                    )
 
                             # Prepend prompt in response text.
                             text = prompts + text
@@ -626,7 +675,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = CompletionResponseStreamChoice(
-                        index=
+                        index=index,
                         text=delta,
                         logprobs=logprobs,
                         finish_reason=format_finish_reason(
@@ -639,12 +688,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                     yield f"data: {chunk.model_dump_json()}\n\n"
                 if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                     usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                     )
 
                     final_usage_chunk = CompletionStreamResponse(
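With `n > 1`, stream chunks for different samples interleave and each carries an `index`, so the adapter replaces its single `stream_buffer` and token counters with per-index dictionaries and counts prompt tokens only once per prompt when reporting usage. The toy model below mirrors that bookkeeping on hand-written chunk dicts rather than real server output:

```python
def accumulate_stream(chunks, n: int):
    # Toy model of the per-index bookkeeping added in 0.2.14; `chunks` are plain
    # dicts, not actual tokenizer_manager.generate_request output.
    stream_buffers, prompt_tokens, completion_tokens = {}, {}, {}
    for content in chunks:
        index = content["index"]
        stream_buffer = stream_buffers.get(index, "")
        delta = content["text"][len(stream_buffer):]  # only the new suffix is emitted
        stream_buffers[index] = stream_buffer + delta
        prompt_tokens[index] = content["prompt_tokens"]
        completion_tokens[index] = content["completion_tokens"]

    # Prompt tokens are shared by the n samples of one prompt, so count them once.
    total_prompt = sum(t for i, t in prompt_tokens.items() if i % n == 0)
    total_completion = sum(completion_tokens.values())
    return stream_buffers, total_prompt, total_completion


chunks = [
    {"index": 0, "text": "Hi", "prompt_tokens": 5, "completion_tokens": 1},
    {"index": 1, "text": "Hey", "prompt_tokens": 5, "completion_tokens": 1},
    {"index": 0, "text": "Hi there", "prompt_tokens": 5, "completion_tokens": 2},
]
buffers, p, c = accumulate_stream(chunks, n=2)  # p == 5, c == 3
```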
@@ -683,12 +744,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     return response
 
 
-def v1_chat_generate_request(all_requests, tokenizer_manager):
+def v1_chat_generate_request(
+    all_requests: List[ChatCompletionRequest], tokenizer_manager
+):
     input_ids = []
     sampling_params_list = []
     image_data_list = []
     return_logprobs = []
+    logprob_start_lens = []
     top_logprobs_nums = []
+
+    # NOTE: with openai API, the prompt's logprobs are always not computed
+
     for request in all_requests:
         # Prep the data needed for the underlying GenerateReqInput:
         # - prompt: The full prompt string.
@@ -721,6 +788,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             image_data = None
         input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
+        logprob_start_lens.append(-1)
         top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
             {
@@ -747,17 +815,20 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
+        logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
     else:
         if isinstance(input_ids[0], str):
             prompt_kwargs = {"text": input_ids}
         else:
             prompt_kwargs = {"input_ids": input_ids}
+
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
+        logprob_start_len=logprob_start_lens,
         top_logprobs_num=top_logprobs_nums,
         stream=all_requests[0].stream,
         return_text_in_logprobs=True,
@@ -881,16 +952,23 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:
 
         async def generate_stream_resp():
-            is_first = True
-            stream_buffer = ""
-            n_prev_token = 0
-
+            is_firsts = {}
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    index = content["index"]
+
+                    is_first = is_firsts.get(index, True)
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                     if request.logprobs:
                         logprobs = to_openai_style_logprobs(
                             output_token_logprobs=content["meta_info"][
@@ -940,7 +1018,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         # First chunk with role
                         is_first = False
                         choice_data = ChatCompletionResponseStreamChoice(
-                            index=
+                            index=index,
                             delta=DeltaMessage(role="assistant"),
                             finish_reason=format_finish_reason(
                                 content["meta_info"]["finish_reason"]
@@ -958,7 +1036,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = ChatCompletionResponseStreamChoice(
-                        index=
+                        index=index,
                         delta=DeltaMessage(content=delta),
                         finish_reason=format_finish_reason(
                             content["meta_info"]["finish_reason"]
@@ -970,12 +1048,25 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+
+                    is_firsts[index] = is_first
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                     yield f"data: {chunk.model_dump_json()}\n\n"
                 if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                     usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                     )
 
                     final_usage_chunk = ChatCompletionStreamResponse(