sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +6 -0
- sglang/bench_latency.py +7 -3
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +9 -0
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/fsm_cache.py +11 -2
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +100 -1
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/fused_moe/layer.py +2 -2
- sglang/srt/layers/logits_processor.py +56 -19
- sglang/srt/layers/radix_attention.py +3 -4
- sglang/srt/layers/sampler.py +101 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +46 -166
- sglang/srt/managers/tokenizer_manager.py +192 -83
- sglang/srt/managers/tp_worker.py +118 -24
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +32 -8
- sglang/srt/model_executor/forward_batch_info.py +51 -26
- sglang/srt/model_executor/model_runner.py +201 -58
- sglang/srt/models/gemma2.py +10 -6
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +11 -1
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/qwen2.py +9 -3
- sglang/srt/openai_api/adapter.py +200 -39
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_batch_info.py +136 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
- sglang/srt/server.py +92 -57
- sglang/srt/server_args.py +43 -15
- sglang/srt/utils.py +26 -16
- sglang/test/runners.py +22 -30
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_utils.py +36 -53
- sglang/version.py +1 -1
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
- sglang-0.2.14.post1.dist-info/RECORD +114 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang-0.2.13.dist-info/RECORD +0 -112
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py
CHANGED
```diff
@@ -17,6 +17,7 @@ limitations under the License.
 
 import asyncio
 import json
+import logging
 import os
 import time
 import uuid
@@ -64,6 +65,8 @@ from sglang.srt.openai_api.protocol import (
     UsageInfo,
 )
 
+logger = logging.getLogger(__name__)
+
 chat_template_name = None
 
 
@@ -120,7 +123,7 @@ def create_streaming_error_response(
 def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
     global chat_template_name
 
-
+    logger.info(f"Use chat template: {chat_template_arg}")
     if not chat_template_exists(chat_template_arg):
         if not os.path.exists(chat_template_arg):
             raise RuntimeError(
@@ -272,20 +275,32 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         end_point = batch_storage[batch_id].endpoint
         file_request_list = []
         all_requests = []
+        request_ids = []
         for line in lines:
             request_data = json.loads(line)
             file_request_list.append(request_data)
             body = request_data["body"]
+            request_ids.append(request_data["custom_id"])
+
+            # Although streaming is supported for standalone completions, it is not supported in
+            # batch mode (multiple completions in single request).
+            if body.get("stream", False):
+                raise ValueError("Streaming requests are not supported in batch mode")
+
             if end_point == "/v1/chat/completions":
                 all_requests.append(ChatCompletionRequest(**body))
             elif end_point == "/v1/completions":
                 all_requests.append(CompletionRequest(**body))
+
         if end_point == "/v1/chat/completions":
             adapted_request, request = v1_chat_generate_request(
-                all_requests, tokenizer_manager
+                all_requests, tokenizer_manager, request_ids=request_ids
             )
         elif end_point == "/v1/completions":
-            adapted_request, request = v1_generate_request(
+            adapted_request, request = v1_generate_request(
+                all_requests, request_ids=request_ids
+            )
+
         try:
             ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
             if not isinstance(ret, list):
@@ -317,6 +332,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
                 }
                 all_ret.append(response_json)
                 completed_requests += 1
+
         # Write results to a new file
         output_file_id = f"backend_result_file-{uuid.uuid4()}"
         global storage_dir
@@ -346,7 +362,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         }
 
     except Exception as e:
-
+        logger.error("error in SGLang:", e)
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
@@ -363,6 +379,72 @@ async def v1_retrieve_batch(batch_id: str):
     return batch_response
 
 
+async def v1_cancel_batch(tokenizer_manager, batch_id: str):
+    # Retrieve the batch job from the in-memory storage
+    batch_response = batch_storage.get(batch_id)
+    if batch_response is None:
+        raise HTTPException(status_code=404, detail="Batch not found")
+
+    # Only do cancal when status is "validating" or "in_progress"
+    if batch_response.status in ["validating", "in_progress"]:
+        # Start cancelling the batch asynchronously
+        asyncio.create_task(
+            cancel_batch(
+                tokenizer_manager=tokenizer_manager,
+                batch_id=batch_id,
+                input_file_id=batch_response.input_file_id,
+            )
+        )
+
+        # Update batch status to "cancelling"
+        batch_response.status = "cancelling"
+
+        return batch_response
+    else:
+        raise HTTPException(
+            status_code=500,
+            detail=f"Current status is {batch_response.status}, no need to cancel",
+        )
+
+
+async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
+    try:
+        # Update the batch status to "cancelling"
+        batch_storage[batch_id].status = "cancelling"
+
+        # Retrieve the input file content
+        input_file_request = file_id_request.get(input_file_id)
+        if not input_file_request:
+            raise ValueError("Input file not found")
+
+        # Parse the JSONL file and process each request
+        input_file_path = file_id_storage.get(input_file_id)
+        with open(input_file_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+
+        file_request_list = []
+        request_ids = []
+        for line in lines:
+            request_data = json.loads(line)
+            file_request_list.append(request_data)
+            request_ids.append(request_data["custom_id"])
+
+        # Cancel requests by request_ids
+        for rid in request_ids:
+            tokenizer_manager.abort_request(rid=rid)
+
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "cancelled"
+
+    except Exception as e:
+        logger.error("error in SGLang:", e)
+        # Update batch status to "failed"
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "failed"
+        retrieve_batch.failed_at = int(time.time())
+        retrieve_batch.errors = {"message": str(e)}
+
+
 async def v1_retrieve_file(file_id: str):
     # Retrieve the batch job from the in-memory storage
     file_response = file_id_response.get(file_id)
@@ -383,20 +465,35 @@ async def v1_retrieve_file_content(file_id: str):
     return StreamingResponse(iter_file(), media_type="application/octet-stream")
 
 
-def v1_generate_request(
+def v1_generate_request(
+    all_requests: List[CompletionRequest], request_ids: List[str] = None
+):
     prompts = []
     sampling_params_list = []
     return_logprobs = []
+    logprob_start_lens = []
     top_logprobs_nums = []
-    first_prompt_type = type(all_requests[0].prompt)
 
+    # NOTE: with openai API, the prompt's logprobs are always not computed
+    first_prompt_type = type(all_requests[0].prompt)
    for request in all_requests:
-        prompt = request.prompt
         assert (
-            type(prompt) == first_prompt_type
+            type(request.prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
-
+        if len(all_requests) > 1 and request.n > 1:
+            raise ValueError(
+                "Parallel sampling is not supported for completions from files"
+            )
+        if request.echo and request.logprobs:
+            logger.warning(
+                "Echo is not compatible with logprobs. "
+                "To compute logprobs of input prompt, please use SGLang /request API."
+            )
+
+    for request in all_requests:
+        prompts.append(request.prompt)
         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        logprob_start_lens.append(-1)
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
@@ -412,18 +509,16 @@ def v1_generate_request(all_requests):
                 "frequency_penalty": request.frequency_penalty,
                 "repetition_penalty": request.repetition_penalty,
                 "regex": request.regex,
+                "json_schema": request.json_schema,
                 "n": request.n,
                 "ignore_eos": request.ignore_eos,
             }
         )
-        if len(all_requests) > 1 and request.n > 1:
-            raise ValueError(
-                "Parallel sampling is not supported for completions from files"
-            )
 
     if len(all_requests) == 1:
         prompt = prompts[0]
         sampling_params_list = sampling_params_list[0]
+        logprob_start_lens = logprob_start_lens[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
@@ -441,8 +536,10 @@ def v1_generate_request(all_requests):
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
         top_logprobs_num=top_logprobs_nums,
+        logprob_start_len=logprob_start_lens,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
+        rid=request_ids,
     )
 
     if len(all_requests) == 1:
@@ -580,27 +677,45 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:
 
         async def generate_stream_resp():
-
-
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
+                    index = content["index"]
+
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
                     text = content["text"]
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
 
                     if not stream_buffer:  # The first chunk
                         if request.echo:
                             if isinstance(request.prompt, str):
                                 # for the case of single str prompts
                                 prompts = request.prompt
-                            elif isinstance(request.prompt, list)
-                                request.prompt[0],
-
-
-
-
+                            elif isinstance(request.prompt, list):
+                                if isinstance(request.prompt[0], str):
+                                    # for the case of multiple str prompts
+                                    prompts = request.prompt[index // request.n]
+                                elif isinstance(request.prompt[0], int):
+                                    # for the case of single token ids prompt
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt, skip_special_tokens=True
+                                    )
+                                elif isinstance(request.prompt[0], list) and isinstance(
+                                    request.prompt[0][0], int
+                                ):
+                                    # for the case of multiple token ids prompts
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt[index // request.n],
+                                        skip_special_tokens=True,
+                                    )
 
                             # Prepend prompt in response text.
                             text = prompts + text
@@ -637,7 +752,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = CompletionResponseStreamChoice(
-                        index=
+                        index=index,
                         text=delta,
                         logprobs=logprobs,
                         finish_reason=format_finish_reason(
@@ -650,12 +765,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                    yield f"data: {chunk.model_dump_json()}\n\n"
                 if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                     usage = UsageInfo(
-                        prompt_tokens=
-                        completion_tokens=
-                        total_tokens=
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                     )
 
                     final_usage_chunk = CompletionStreamResponse(
@@ -694,12 +821,20 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     return response
 
 
-def v1_chat_generate_request(
+def v1_chat_generate_request(
+    all_requests: List[ChatCompletionRequest],
+    tokenizer_manager,
+    request_ids: List[str] = None,
+):
     input_ids = []
     sampling_params_list = []
     image_data_list = []
     return_logprobs = []
+    logprob_start_lens = []
     top_logprobs_nums = []
+
+    # NOTE: with openai API, the prompt's logprobs are always not computed
+
     for request in all_requests:
         # Prep the data needed for the underlying GenerateReqInput:
         #  - prompt: The full prompt string.
@@ -732,6 +867,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             image_data = None
         input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
+        logprob_start_lens.append(-1)
         top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
             {
@@ -745,6 +881,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
                 "frequency_penalty": request.frequency_penalty,
                 "repetition_penalty": request.repetition_penalty,
                 "regex": request.regex,
+                "json_schema": request.json_schema,
                 "n": request.n,
             }
         )
@@ -758,20 +895,24 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
+        logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
     else:
         if isinstance(input_ids[0], str):
             prompt_kwargs = {"text": input_ids}
         else:
             prompt_kwargs = {"input_ids": input_ids}
+
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
+        logprob_start_len=logprob_start_lens,
         top_logprobs_num=top_logprobs_nums,
         stream=all_requests[0].stream,
         return_text_in_logprobs=True,
+        rid=request_ids,
     )
     if len(all_requests) == 1:
         return adapted_request, all_requests[0]
@@ -892,16 +1033,23 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:
 
         async def generate_stream_resp():
-
-
-
-
+            is_firsts = {}
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
-
-
+                    index = content["index"]
+
+                    is_first = is_firsts.get(index, True)
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                     if request.logprobs:
                         logprobs = to_openai_style_logprobs(
                             output_token_logprobs=content["meta_info"][
@@ -951,7 +1099,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         # First chunk with role
                         is_first = False
                         choice_data = ChatCompletionResponseStreamChoice(
-                            index=
+                            index=index,
                             delta=DeltaMessage(role="assistant"),
                             finish_reason=format_finish_reason(
                                 content["meta_info"]["finish_reason"]
@@ -969,7 +1117,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = ChatCompletionResponseStreamChoice(
-                        index=
+                        index=index,
                         delta=DeltaMessage(content=delta),
                         finish_reason=format_finish_reason(
                             content["meta_info"]["finish_reason"]
@@ -981,12 +1129,25 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+
+                    is_firsts[index] = is_first
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                    yield f"data: {chunk.model_dump_json()}\n\n"
                 if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                     usage = UsageInfo(
-                        prompt_tokens=
-                        completion_tokens=
-                        total_tokens=
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                     )
 
                     final_usage_chunk = ChatCompletionStreamResponse(
```
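The batch-processing changes above key each sub-request off the `custom_id` field of the uploaded JSONL file (it becomes the `rid` that `v1_cancel_batch` later aborts) and reject streaming bodies. Below is a hedged sketch of what one line of such an input file might look like; the field values, and the `method`/`url` keys borrowed from the OpenAI batch file format, are illustrative assumptions rather than content of this diff.

```python
import json

# Illustrative only: one line of a batch input JSONL file for the batch flow above.
batch_line = {
    "custom_id": "request-1",       # assumed caller-chosen id; used as the request rid
    "method": "POST",               # assumed, following the OpenAI batch file convention
    "url": "/v1/chat/completions",  # must match the batch endpoint handled by process_batch()
    "body": {
        "model": "my-model",        # placeholder model name
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 32,
        "stream": False,            # streaming is rejected in batch mode
    },
}
print(json.dumps(batch_line))
```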
sglang/srt/openai_api/protocol.py
CHANGED
```diff
@@ -161,6 +161,7 @@ class CompletionRequest(BaseModel):
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
+    json_schema: Optional[str] = None
     ignore_eos: Optional[bool] = False
     min_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
@@ -262,6 +263,7 @@ class ChatCompletionRequest(BaseModel):
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
+    json_schema: Optional[str] = None
     min_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
```
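With the new `json_schema` field exposed on both request models, a client can pass a JSON schema string as an SRT-only extra parameter (mutually exclusive with `regex`, per the sampling_params change further down). A hedged usage sketch follows; the host, port, and model name are assumptions and not part of this diff.

```python
import json
import requests

# Illustrative only: exercising the new `json_schema` extra parameter on the
# OpenAI-compatible chat endpoint of a locally running sglang server (assumed URL).
schema = json.dumps({
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
})
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",  # assumed local server address
    json={
        "model": "default",  # placeholder
        "messages": [{"role": "user", "content": "Describe a person as JSON."}],
        "json_schema": schema,  # SRT-only extra parameter added in this release
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```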
sglang/srt/sampling/sampling_batch_info.py
ADDED
```diff
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import dataclasses
+from typing import TYPE_CHECKING, List
+
+import torch
+
+import sglang.srt.sampling.penaltylib as penaltylib
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
+
+
+@dataclasses.dataclass
+class SamplingBatchInfo:
+    # Basic Info
+    vocab_size: int
+
+    # Batched sampling params
+    temperatures: torch.Tensor = None
+    top_ps: torch.Tensor = None
+    top_ks: torch.Tensor = None
+    min_ps: torch.Tensor = None
+    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
+    logit_bias: torch.Tensor = None
+    vocab_mask: torch.Tensor = None
+
+    @classmethod
+    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+        device = "cuda"
+        reqs = batch.reqs
+        ret = cls(vocab_size=vocab_size)
+
+        ret.temperatures = torch.tensor(
+            [r.sampling_params.temperature for r in reqs],
+            dtype=torch.float,
+            device=device,
+        ).view(-1, 1)
+        ret.top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+        )
+        ret.top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
+        )
+        ret.min_ps = torch.tensor(
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
+        )
+
+        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
+        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
+        # should not add hefty computation overhead other than simple checks.
+        #
+        # While we choose not to even create the class instances if they are not required, this
+        # could add additional complexity to the {ScheduleBatch} class, especially we need to
+        # handle {filter_batch()} and {merge()} cases as well.
+        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+            vocab_size=vocab_size,
+            batch=batch,
+            device=device,
+            Penalizers={
+                penaltylib.BatchedFrequencyPenalizer,
+                penaltylib.BatchedMinNewTokensPenalizer,
+                penaltylib.BatchedPresencePenalizer,
+                penaltylib.BatchedRepetitionPenalizer,
+            },
+        )
+
+        # Handle logit bias but only allocate when needed
+        ret.logit_bias = None
+
+        ret.update_regex_vocab_mask(batch)
+
+        return ret
+
+    def update_regex_vocab_mask(self, batch: ScheduleBatch):
+        bs, reqs = batch.batch_size(), batch.reqs
+        device = "cuda"
+        has_regex = any(req.regex_fsm is not None for req in reqs)
+
+        # Reset the vocab mask
+        self.vocab_mask = None
+
+        if has_regex:
+            for i, req in enumerate(reqs):
+                if req.regex_fsm is not None:
+                    if self.vocab_mask is None:
+                        self.vocab_mask = torch.zeros(
+                            bs, self.vocab_size, dtype=torch.bool, device=device
+                        )
+                    self.vocab_mask[i][
+                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+                    ] = 1
+
+    def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
+        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+            "logit_bias",
+        ]:
+            self_val = getattr(self, item, None)
+            if self_val is not None:  # logit_bias can be None
+                setattr(self, item, self_val[new_indices])
+
+    def merge(self, other: "SamplingBatchInfo"):
+        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        # logit_bias can be None
+        if self.logit_bias is not None or other.logit_bias is not None:
+            vocab_size = (
+                self.logit_bias.shape[1]
+                if self.logit_bias is not None
+                else other.logit_bias.shape[1]
+            )
+            if self.logit_bias is None:
+                self.logit_bias = torch.zeros(
+                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            if other.logit_bias is None:
+                other.logit_bias = torch.zeros(
+                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
```
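The new `min_ps` tensor feeds a per-request min-p cutoff at sampling time; the sampler itself (sglang/srt/layers/sampler.py, +101 lines in this release) is not shown in this diff. The following is a hedged sketch of the standard min-p rule, using tensors shaped like the ones `from_schedule_batch` builds, and is not sglang's actual sampler code.

```python
import torch


def apply_min_p(logits: torch.Tensor, min_ps: torch.Tensor) -> torch.Tensor:
    """Standard min-p filtering sketch: per request, drop tokens whose probability
    falls below min_p * p(most likely token). Not sglang's implementation."""
    probs = torch.softmax(logits, dim=-1)               # [batch, vocab]
    top_probs = probs.max(dim=-1, keepdim=True).values  # [batch, 1]
    cutoff = min_ps.unsqueeze(-1) * top_probs           # [batch, 1], broadcast over vocab
    return logits.masked_fill(probs < cutoff, float("-inf"))


# Tiny example: request 0 uses min_p=0.0 (no-op), request 1 uses min_p=0.3,
# which masks the lowest-probability token of its row.
logits = torch.tensor([[2.0, 1.0, 0.5], [2.0, 1.0, 0.5]])
min_ps = torch.tensor([0.0, 0.3])
print(apply_min_p(logits, min_ps))
```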
sglang/srt/{sampling_params.py → sampling/sampling_params.py}
RENAMED
```diff
@@ -30,6 +30,7 @@ class SamplingParams:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repetition_penalty: float = 1.0,
@@ -38,10 +39,12 @@ class SamplingParams:
         spaces_between_special_tokens: bool = True,
         regex: Optional[str] = None,
         n: int = 1,
+        json_schema: Optional[str] = None,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
+        self.min_p = min_p
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
@@ -54,6 +57,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.regex = regex
         self.n = n
+        self.json_schema = json_schema
 
         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
@@ -69,6 +73,8 @@ class SamplingParams:
             )
         if not 0.0 < self.top_p <= 1.0:
             raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
+        if not 0.0 <= self.min_p <= 1.0:
+            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(
                 f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
@@ -102,6 +108,8 @@ class SamplingParams:
                 f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
                 f"{self.min_new_tokens}."
             )
+        if self.regex is not None and self.json_schema is not None:
+            raise ValueError("regex and json_schema cannot be both set.")
 
     def normalize(self, tokenizer):
         # Process stop strings
@@ -123,3 +131,17 @@ class SamplingParams:
             else:
                 stop_str_max_len = max(stop_str_max_len, len(stop_str))
         self.stop_str_max_len = stop_str_max_len
+
+    def to_srt_kwargs(self):
+        return {
+            "max_new_tokens": self.max_new_tokens,
+            "stop": self.stop_strs,
+            "stop_token_ids": list(self.stop_token_ids),
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+            "ignore_eos": self.ignore_eos,
+            "regex": self.regex,
+        }
```
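A hedged usage sketch of the reworked class under its new module path (the path is confirmed by the rename in the file list). Only keyword arguments visible in the hunks above are used; the remaining constructor defaults (stop strings, max_new_tokens, and so on) are assumed to be populated by the parts of `__init__` this diff does not show, so `to_srt_kwargs()` is assumed to work on a freshly constructed instance.

```python
from sglang.srt.sampling.sampling_params import SamplingParams

# Illustrative only; argument values are placeholders.
params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    min_p=0.05,                        # new in 0.2.14: must lie in [0, 1]
    json_schema='{"type": "object"}',  # new in 0.2.14: mutually exclusive with regex
)
# New helper that maps the params to the SRT backend's generate kwargs.
print(params.to_srt_kwargs())
```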