sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those published versions.
Files changed (61)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/constrained/fsm_cache.py +11 -2
  11. sglang/srt/constrained/jump_forward.py +1 -0
  12. sglang/srt/conversation.py +50 -1
  13. sglang/srt/hf_transformers_utils.py +22 -23
  14. sglang/srt/layers/activation.py +100 -1
  15. sglang/srt/layers/decode_attention.py +338 -50
  16. sglang/srt/layers/fused_moe/layer.py +2 -2
  17. sglang/srt/layers/logits_processor.py +56 -19
  18. sglang/srt/layers/radix_attention.py +3 -4
  19. sglang/srt/layers/sampler.py +101 -0
  20. sglang/srt/managers/controller_multi.py +2 -8
  21. sglang/srt/managers/controller_single.py +7 -10
  22. sglang/srt/managers/detokenizer_manager.py +20 -9
  23. sglang/srt/managers/io_struct.py +44 -11
  24. sglang/srt/managers/policy_scheduler.py +5 -2
  25. sglang/srt/managers/schedule_batch.py +46 -166
  26. sglang/srt/managers/tokenizer_manager.py +192 -83
  27. sglang/srt/managers/tp_worker.py +118 -24
  28. sglang/srt/mem_cache/memory_pool.py +82 -8
  29. sglang/srt/mm_utils.py +79 -7
  30. sglang/srt/model_executor/cuda_graph_runner.py +32 -8
  31. sglang/srt/model_executor/forward_batch_info.py +51 -26
  32. sglang/srt/model_executor/model_runner.py +201 -58
  33. sglang/srt/models/gemma2.py +10 -6
  34. sglang/srt/models/gpt_bigcode.py +1 -1
  35. sglang/srt/models/grok.py +11 -1
  36. sglang/srt/models/llama_embedding.py +4 -0
  37. sglang/srt/models/llava.py +176 -59
  38. sglang/srt/models/qwen2.py +9 -3
  39. sglang/srt/openai_api/adapter.py +200 -39
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_batch_info.py +136 -0
  42. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
  43. sglang/srt/server.py +92 -57
  44. sglang/srt/server_args.py +43 -15
  45. sglang/srt/utils.py +26 -16
  46. sglang/test/runners.py +22 -30
  47. sglang/test/simple_eval_common.py +9 -10
  48. sglang/test/simple_eval_gpqa.py +2 -1
  49. sglang/test/simple_eval_humaneval.py +2 -2
  50. sglang/test/simple_eval_math.py +2 -1
  51. sglang/test/simple_eval_mmlu.py +2 -1
  52. sglang/test/test_activation.py +55 -0
  53. sglang/test/test_utils.py +36 -53
  54. sglang/version.py +1 -1
  55. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
  56. sglang-0.2.14.post1.dist-info/RECORD +114 -0
  57. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
  58. sglang/launch_server_llavavid.py +0 -29
  59. sglang-0.2.13.dist-info/RECORD +0 -112
  60. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
  61. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py

@@ -17,6 +17,7 @@ limitations under the License.
 
  import asyncio
  import json
+ import logging
  import os
  import time
  import uuid
@@ -64,6 +65,8 @@ from sglang.srt.openai_api.protocol import (
  UsageInfo,
  )
 
+ logger = logging.getLogger(__name__)
+
  chat_template_name = None
 
 
@@ -120,7 +123,7 @@ def create_streaming_error_response(
  def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
  global chat_template_name
 
- print(f"Use chat template: {chat_template_arg}")
+ logger.info(f"Use chat template: {chat_template_arg}")
  if not chat_template_exists(chat_template_arg):
  if not os.path.exists(chat_template_arg):
  raise RuntimeError(
@@ -272,20 +275,32 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
  end_point = batch_storage[batch_id].endpoint
  file_request_list = []
  all_requests = []
+ request_ids = []
  for line in lines:
  request_data = json.loads(line)
  file_request_list.append(request_data)
  body = request_data["body"]
+ request_ids.append(request_data["custom_id"])
+
+ # Although streaming is supported for standalone completions, it is not supported in
+ # batch mode (multiple completions in single request).
+ if body.get("stream", False):
+ raise ValueError("Streaming requests are not supported in batch mode")
+
  if end_point == "/v1/chat/completions":
  all_requests.append(ChatCompletionRequest(**body))
  elif end_point == "/v1/completions":
  all_requests.append(CompletionRequest(**body))
+
  if end_point == "/v1/chat/completions":
  adapted_request, request = v1_chat_generate_request(
- all_requests, tokenizer_manager
+ all_requests, tokenizer_manager, request_ids=request_ids
  )
  elif end_point == "/v1/completions":
- adapted_request, request = v1_generate_request(all_requests)
+ adapted_request, request = v1_generate_request(
+ all_requests, request_ids=request_ids
+ )
+
  try:
  ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
  if not isinstance(ret, list):
@@ -317,6 +332,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
  }
  all_ret.append(response_json)
  completed_requests += 1
+
  # Write results to a new file
  output_file_id = f"backend_result_file-{uuid.uuid4()}"
  global storage_dir
@@ -346,7 +362,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
  }
 
  except Exception as e:
- print("error in SGLang:", e)
+ logger.error("error in SGLang:", e)
  # Update batch status to "failed"
  retrieve_batch = batch_storage[batch_id]
  retrieve_batch.status = "failed"
@@ -363,6 +379,72 @@ async def v1_retrieve_batch(batch_id: str):
  return batch_response
 
 
+ async def v1_cancel_batch(tokenizer_manager, batch_id: str):
+ # Retrieve the batch job from the in-memory storage
+ batch_response = batch_storage.get(batch_id)
+ if batch_response is None:
+ raise HTTPException(status_code=404, detail="Batch not found")
+
+ # Only do cancal when status is "validating" or "in_progress"
+ if batch_response.status in ["validating", "in_progress"]:
+ # Start cancelling the batch asynchronously
+ asyncio.create_task(
+ cancel_batch(
+ tokenizer_manager=tokenizer_manager,
+ batch_id=batch_id,
+ input_file_id=batch_response.input_file_id,
+ )
+ )
+
+ # Update batch status to "cancelling"
+ batch_response.status = "cancelling"
+
+ return batch_response
+ else:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Current status is {batch_response.status}, no need to cancel",
+ )
+
+
+ async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
+ try:
+ # Update the batch status to "cancelling"
+ batch_storage[batch_id].status = "cancelling"
+
+ # Retrieve the input file content
+ input_file_request = file_id_request.get(input_file_id)
+ if not input_file_request:
+ raise ValueError("Input file not found")
+
+ # Parse the JSONL file and process each request
+ input_file_path = file_id_storage.get(input_file_id)
+ with open(input_file_path, "r", encoding="utf-8") as f:
+ lines = f.readlines()
+
+ file_request_list = []
+ request_ids = []
+ for line in lines:
+ request_data = json.loads(line)
+ file_request_list.append(request_data)
+ request_ids.append(request_data["custom_id"])
+
+ # Cancel requests by request_ids
+ for rid in request_ids:
+ tokenizer_manager.abort_request(rid=rid)
+
+ retrieve_batch = batch_storage[batch_id]
+ retrieve_batch.status = "cancelled"
+
+ except Exception as e:
+ logger.error("error in SGLang:", e)
+ # Update batch status to "failed"
+ retrieve_batch = batch_storage[batch_id]
+ retrieve_batch.status = "failed"
+ retrieve_batch.failed_at = int(time.time())
+ retrieve_batch.errors = {"message": str(e)}
+
+
  async def v1_retrieve_file(file_id: str):
  # Retrieve the batch job from the in-memory storage
  file_response = file_id_response.get(file_id)
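Note on the cancellation flow added above: v1_cancel_batch only flips a "validating"/"in_progress" batch to "cancelling" and schedules cancel_batch(), which re-reads the input JSONL and aborts each request through tokenizer_manager.abort_request(rid=custom_id) before marking the batch "cancelled". A minimal client-side sketch of that flow follows; it assumes the server exposes the usual OpenAI-compatible routes /v1/batches/{batch_id}/cancel and /v1/batches/{batch_id} (the route wiring lives in sglang/srt/server.py and is not shown in this hunk) and that an SGLang server is running on localhost:30000.

import time

import requests

BASE_URL = "http://localhost:30000"  # assumed server address


def cancel_and_wait(batch_id: str, timeout_s: float = 30.0) -> dict:
    # Ask the server to cancel; the batch moves to "cancelling" and a
    # background task aborts the underlying requests by their custom_id.
    requests.post(f"{BASE_URL}/v1/batches/{batch_id}/cancel").raise_for_status()

    # Poll until the background task settles the batch.
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        batch = requests.get(f"{BASE_URL}/v1/batches/{batch_id}").json()
        if batch["status"] in ("cancelled", "failed"):
            return batch
        time.sleep(0.5)
    raise TimeoutError(f"batch {batch_id} did not settle within {timeout_s}s")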
@@ -383,20 +465,35 @@ async def v1_retrieve_file_content(file_id: str):
  return StreamingResponse(iter_file(), media_type="application/octet-stream")
 
 
- def v1_generate_request(all_requests):
+ def v1_generate_request(
+ all_requests: List[CompletionRequest], request_ids: List[str] = None
+ ):
  prompts = []
  sampling_params_list = []
  return_logprobs = []
+ logprob_start_lens = []
  top_logprobs_nums = []
- first_prompt_type = type(all_requests[0].prompt)
 
+ # NOTE: with openai API, the prompt's logprobs are always not computed
+ first_prompt_type = type(all_requests[0].prompt)
  for request in all_requests:
- prompt = request.prompt
  assert (
- type(prompt) == first_prompt_type
+ type(request.prompt) == first_prompt_type
  ), "All prompts must be of the same type in file input settings"
- prompts.append(prompt)
+ if len(all_requests) > 1 and request.n > 1:
+ raise ValueError(
+ "Parallel sampling is not supported for completions from files"
+ )
+ if request.echo and request.logprobs:
+ logger.warning(
+ "Echo is not compatible with logprobs. "
+ "To compute logprobs of input prompt, please use SGLang /request API."
+ )
+
+ for request in all_requests:
+ prompts.append(request.prompt)
  return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+ logprob_start_lens.append(-1)
  top_logprobs_nums.append(
  request.logprobs if request.logprobs is not None else 0
  )
@@ -412,18 +509,16 @@ def v1_generate_request(all_requests):
  "frequency_penalty": request.frequency_penalty,
  "repetition_penalty": request.repetition_penalty,
  "regex": request.regex,
+ "json_schema": request.json_schema,
  "n": request.n,
  "ignore_eos": request.ignore_eos,
  }
  )
- if len(all_requests) > 1 and request.n > 1:
- raise ValueError(
- "Parallel sampling is not supported for completions from files"
- )
 
  if len(all_requests) == 1:
  prompt = prompts[0]
  sampling_params_list = sampling_params_list[0]
+ logprob_start_lens = logprob_start_lens[0]
  return_logprobs = return_logprobs[0]
  top_logprobs_nums = top_logprobs_nums[0]
  if isinstance(prompt, str) or isinstance(prompt[0], str):
@@ -441,8 +536,10 @@ def v1_generate_request(all_requests):
  sampling_params=sampling_params_list,
  return_logprob=return_logprobs,
  top_logprobs_num=top_logprobs_nums,
+ logprob_start_len=logprob_start_lens,
  return_text_in_logprobs=True,
  stream=all_requests[0].stream,
+ rid=request_ids,
  )
 
  if len(all_requests) == 1:
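For context on the request_ids plumbing above: process_batch() collects each line's custom_id and v1_generate_request / v1_chat_generate_request forward that list as rid on GenerateReqInput, which is what later lets cancel_batch() abort the individual requests. A sketch of one batch input line follows; only custom_id and body are read here, while method and url follow the OpenAI batch file format (all values are hypothetical).

import json

# One JSONL line of a batch input file. "stream": True would be rejected,
# since streaming is not supported in batch mode.
line = {
    "custom_id": "req-0001",          # becomes the rid used for cancellation
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "default",           # placeholder model name
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 32,
        "stream": False,
    },
}
print(json.dumps(line))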
@@ -580,27 +677,45 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
  if adapted_request.stream:
 
  async def generate_stream_resp():
- stream_buffer = ""
- n_prev_token = 0
+ stream_buffers = {}
+ n_prev_tokens = {}
+ prompt_tokens = {}
+ completion_tokens = {}
  try:
  async for content in tokenizer_manager.generate_request(
  adapted_request, raw_request
  ):
+ index = content["index"]
+
+ stream_buffer = stream_buffers.get(index, "")
+ n_prev_token = n_prev_tokens.get(index, 0)
+
  text = content["text"]
- prompt_tokens = content["meta_info"]["prompt_tokens"]
- completion_tokens = content["meta_info"]["completion_tokens"]
+ prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+ completion_tokens[index] = content["meta_info"]["completion_tokens"]
 
  if not stream_buffer: # The first chunk
  if request.echo:
  if isinstance(request.prompt, str):
  # for the case of single str prompts
  prompts = request.prompt
- elif isinstance(request.prompt, list) and isinstance(
- request.prompt[0], int
- ):
- prompts = tokenizer_manager.tokenizer.decode(
- request.prompt, skip_special_tokens=True
- )
+ elif isinstance(request.prompt, list):
+ if isinstance(request.prompt[0], str):
+ # for the case of multiple str prompts
+ prompts = request.prompt[index // request.n]
+ elif isinstance(request.prompt[0], int):
+ # for the case of single token ids prompt
+ prompts = tokenizer_manager.tokenizer.decode(
+ request.prompt, skip_special_tokens=True
+ )
+ elif isinstance(request.prompt[0], list) and isinstance(
+ request.prompt[0][0], int
+ ):
+ # for the case of multiple token ids prompts
+ prompts = tokenizer_manager.tokenizer.decode(
+ request.prompt[index // request.n],
+ skip_special_tokens=True,
+ )
 
  # Prepend prompt in response text.
  text = prompts + text
@@ -637,7 +752,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
  delta = text[len(stream_buffer) :]
  stream_buffer = stream_buffer + delta
  choice_data = CompletionResponseStreamChoice(
- index=0,
+ index=index,
  text=delta,
  logprobs=logprobs,
  finish_reason=format_finish_reason(
@@ -650,12 +765,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
  choices=[choice_data],
  model=request.model,
  )
+
+ stream_buffers[index] = stream_buffer
+ n_prev_tokens[index] = n_prev_token
+
  yield f"data: {chunk.model_dump_json()}\n\n"
  if request.stream_options and request.stream_options.include_usage:
+ total_prompt_tokens = sum(
+ tokens
+ for i, tokens in prompt_tokens.items()
+ if i % request.n == 0
+ )
+ total_completion_tokens = sum(
+ tokens for tokens in completion_tokens.values()
+ )
  usage = UsageInfo(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
+ prompt_tokens=total_prompt_tokens,
+ completion_tokens=total_completion_tokens,
+ total_tokens=total_prompt_tokens + total_completion_tokens,
  )
 
  final_usage_chunk = CompletionStreamResponse(
@@ -694,12 +821,20 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
  return response
 
 
- def v1_chat_generate_request(all_requests, tokenizer_manager):
+ def v1_chat_generate_request(
+ all_requests: List[ChatCompletionRequest],
+ tokenizer_manager,
+ request_ids: List[str] = None,
+ ):
  input_ids = []
  sampling_params_list = []
  image_data_list = []
  return_logprobs = []
+ logprob_start_lens = []
  top_logprobs_nums = []
+
+ # NOTE: with openai API, the prompt's logprobs are always not computed
+
  for request in all_requests:
  # Prep the data needed for the underlying GenerateReqInput:
  # - prompt: The full prompt string.
@@ -732,6 +867,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
  image_data = None
  input_ids.append(prompt_ids)
  return_logprobs.append(request.logprobs)
+ logprob_start_lens.append(-1)
  top_logprobs_nums.append(request.top_logprobs)
  sampling_params_list.append(
  {
@@ -745,6 +881,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
  "frequency_penalty": request.frequency_penalty,
  "repetition_penalty": request.repetition_penalty,
  "regex": request.regex,
+ "json_schema": request.json_schema,
  "n": request.n,
  }
  )
@@ -758,20 +895,24 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
  sampling_params_list = sampling_params_list[0]
  image_data = image_data_list[0]
  return_logprobs = return_logprobs[0]
+ logprob_start_lens = logprob_start_lens[0]
  top_logprobs_nums = top_logprobs_nums[0]
  else:
  if isinstance(input_ids[0], str):
  prompt_kwargs = {"text": input_ids}
  else:
  prompt_kwargs = {"input_ids": input_ids}
+
  adapted_request = GenerateReqInput(
  **prompt_kwargs,
  image_data=image_data,
  sampling_params=sampling_params_list,
  return_logprob=return_logprobs,
+ logprob_start_len=logprob_start_lens,
  top_logprobs_num=top_logprobs_nums,
  stream=all_requests[0].stream,
  return_text_in_logprobs=True,
+ rid=request_ids,
  )
  if len(all_requests) == 1:
  return adapted_request, all_requests[0]
@@ -892,16 +1033,23 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  if adapted_request.stream:
 
  async def generate_stream_resp():
- is_first = True
-
- stream_buffer = ""
- n_prev_token = 0
+ is_firsts = {}
+ stream_buffers = {}
+ n_prev_tokens = {}
+ prompt_tokens = {}
+ completion_tokens = {}
  try:
  async for content in tokenizer_manager.generate_request(
  adapted_request, raw_request
  ):
- prompt_tokens = content["meta_info"]["prompt_tokens"]
- completion_tokens = content["meta_info"]["completion_tokens"]
+ index = content["index"]
+
+ is_first = is_firsts.get(index, True)
+ stream_buffer = stream_buffers.get(index, "")
+ n_prev_token = n_prev_tokens.get(index, 0)
+
+ prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+ completion_tokens[index] = content["meta_info"]["completion_tokens"]
  if request.logprobs:
  logprobs = to_openai_style_logprobs(
  output_token_logprobs=content["meta_info"][
@@ -951,7 +1099,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  # First chunk with role
  is_first = False
  choice_data = ChatCompletionResponseStreamChoice(
- index=0,
+ index=index,
  delta=DeltaMessage(role="assistant"),
  finish_reason=format_finish_reason(
  content["meta_info"]["finish_reason"]
@@ -969,7 +1117,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  delta = text[len(stream_buffer) :]
  stream_buffer = stream_buffer + delta
  choice_data = ChatCompletionResponseStreamChoice(
- index=0,
+ index=index,
  delta=DeltaMessage(content=delta),
  finish_reason=format_finish_reason(
  content["meta_info"]["finish_reason"]
@@ -981,12 +1129,25 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  choices=[choice_data],
  model=request.model,
  )
+
+ is_firsts[index] = is_first
+ stream_buffers[index] = stream_buffer
+ n_prev_tokens[index] = n_prev_token
+
  yield f"data: {chunk.model_dump_json()}\n\n"
  if request.stream_options and request.stream_options.include_usage:
+ total_prompt_tokens = sum(
+ tokens
+ for i, tokens in prompt_tokens.items()
+ if i % request.n == 0
+ )
+ total_completion_tokens = sum(
+ tokens for tokens in completion_tokens.values()
+ )
  usage = UsageInfo(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
+ prompt_tokens=total_prompt_tokens,
+ completion_tokens=total_completion_tokens,
+ total_tokens=total_prompt_tokens + total_completion_tokens,
  )
 
  final_usage_chunk = ChatCompletionStreamResponse(
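The streaming handlers above (for both /v1/completions and /v1/chat/completions) now key their per-stream state, stream_buffers, n_prev_tokens, is_firsts, and the token counters, by content["index"], so a request with multiple prompts or n > 1 can interleave chunks from several generations. index // request.n recovers the originating prompt (used for echo), and the usage accounting counts prompt tokens once per prompt via i % request.n == 0. A standalone illustration of that arithmetic, with made-up token counts:

# Assume 2 prompts and n = 3 completions per prompt: output indices run 0..5.
n = 3
prompts = ["What is SGLang?", "Write a haiku."]
prompt_tokens = {0: 7, 1: 7, 2: 7, 3: 5, 4: 5, 5: 5}          # reported per output index
completion_tokens = {0: 12, 1: 9, 2: 11, 3: 17, 4: 16, 5: 15}

for index in range(len(prompts) * n):
    origin = prompts[index // n]      # prompt that produced this output index
    assert isinstance(origin, str)

# Prompt tokens are counted once per prompt; completion tokens are summed fully.
total_prompt_tokens = sum(t for i, t in prompt_tokens.items() if i % n == 0)
total_completion_tokens = sum(completion_tokens.values())
print(total_prompt_tokens, total_completion_tokens)  # 12 80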
sglang/srt/openai_api/protocol.py

@@ -161,6 +161,7 @@ class CompletionRequest(BaseModel):
 
  # Extra parameters for SRT backend only and will be ignored by OpenAI models.
  regex: Optional[str] = None
+ json_schema: Optional[str] = None
  ignore_eos: Optional[bool] = False
  min_tokens: Optional[int] = 0
  repetition_penalty: Optional[float] = 1.0
@@ -262,6 +263,7 @@ class ChatCompletionRequest(BaseModel):
 
  # Extra parameters for SRT backend only and will be ignored by OpenAI models.
  regex: Optional[str] = None
+ json_schema: Optional[str] = None
  min_tokens: Optional[int] = 0
  repetition_penalty: Optional[float] = 1.0
  stop_token_ids: Optional[List[int]] = Field(default_factory=list)
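Both request models gain a json_schema field (mutually exclusive with regex, see the SamplingParams check further down) for constraining decoding to a JSON schema. Since it is an SRT-only extra, OpenAI client libraries have to send it through extra_body. A sketch assuming an SGLang server at localhost:30000 and a placeholder model name:

import json

from openai import OpenAI  # assumes the openai>=1.x client is installed

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

schema = json.dumps(
    {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }
)

# json_schema is not a standard OpenAI parameter, so it goes in extra_body.
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Invent a person and answer as JSON."}],
    extra_body={"json_schema": schema},
)
print(response.choices[0].message.content)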
sglang/srt/sampling/sampling_batch_info.py (new file)

@@ -0,0 +1,136 @@
+ from __future__ import annotations
+
+ import dataclasses
+ from typing import TYPE_CHECKING, List
+
+ import torch
+
+ import sglang.srt.sampling.penaltylib as penaltylib
+
+ if TYPE_CHECKING:
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
+
+
+ @dataclasses.dataclass
+ class SamplingBatchInfo:
+ # Basic Info
+ vocab_size: int
+
+ # Batched sampling params
+ temperatures: torch.Tensor = None
+ top_ps: torch.Tensor = None
+ top_ks: torch.Tensor = None
+ min_ps: torch.Tensor = None
+ penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
+ logit_bias: torch.Tensor = None
+ vocab_mask: torch.Tensor = None
+
+ @classmethod
+ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+ device = "cuda"
+ reqs = batch.reqs
+ ret = cls(vocab_size=vocab_size)
+
+ ret.temperatures = torch.tensor(
+ [r.sampling_params.temperature for r in reqs],
+ dtype=torch.float,
+ device=device,
+ ).view(-1, 1)
+ ret.top_ps = torch.tensor(
+ [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+ )
+ ret.top_ks = torch.tensor(
+ [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
+ )
+ ret.min_ps = torch.tensor(
+ [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
+ )
+
+ # Each penalizers will do nothing if they evaluate themselves as not required by looking at
+ # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
+ # should not add hefty computation overhead other than simple checks.
+ #
+ # While we choose not to even create the class instances if they are not required, this
+ # could add additional complexity to the {ScheduleBatch} class, especially we need to
+ # handle {filter_batch()} and {merge()} cases as well.
+ ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+ vocab_size=vocab_size,
+ batch=batch,
+ device=device,
+ Penalizers={
+ penaltylib.BatchedFrequencyPenalizer,
+ penaltylib.BatchedMinNewTokensPenalizer,
+ penaltylib.BatchedPresencePenalizer,
+ penaltylib.BatchedRepetitionPenalizer,
+ },
+ )
+
+ # Handle logit bias but only allocate when needed
+ ret.logit_bias = None
+
+ ret.update_regex_vocab_mask(batch)
+
+ return ret
+
+ def update_regex_vocab_mask(self, batch: ScheduleBatch):
+ bs, reqs = batch.batch_size(), batch.reqs
+ device = "cuda"
+ has_regex = any(req.regex_fsm is not None for req in reqs)
+
+ # Reset the vocab mask
+ self.vocab_mask = None
+
+ if has_regex:
+ for i, req in enumerate(reqs):
+ if req.regex_fsm is not None:
+ if self.vocab_mask is None:
+ self.vocab_mask = torch.zeros(
+ bs, self.vocab_size, dtype=torch.bool, device=device
+ )
+ self.vocab_mask[i][
+ req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+ ] = 1
+
+ def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
+ self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+
+ for item in [
+ "temperatures",
+ "top_ps",
+ "top_ks",
+ "min_ps",
+ "logit_bias",
+ ]:
+ self_val = getattr(self, item, None)
+ if self_val is not None: # logit_bias can be None
+ setattr(self, item, self_val[new_indices])
+
+ def merge(self, other: "SamplingBatchInfo"):
+ self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+
+ for item in [
+ "temperatures",
+ "top_ps",
+ "top_ks",
+ "min_ps",
+ ]:
+ self_val = getattr(self, item, None)
+ other_val = getattr(other, item, None)
+ setattr(self, item, torch.concat([self_val, other_val]))
+
+ # logit_bias can be None
+ if self.logit_bias is not None or other.logit_bias is not None:
+ vocab_size = (
+ self.logit_bias.shape[1]
+ if self.logit_bias is not None
+ else other.logit_bias.shape[1]
+ )
+ if self.logit_bias is None:
+ self.logit_bias = torch.zeros(
+ (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+ )
+ if other.logit_bias is None:
+ other.logit_bias = torch.zeros(
+ (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+ )
+ self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
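SamplingBatchInfo (new in this release) gathers per-request sampling state into batched tensors, temperatures, top_ps, top_ks, the new min_ps, an optional logit_bias, and a regex-FSM vocab_mask, and keeps them in sync with the running batch through filter() and merge(). These tensors are consumed by the new sampler layer (sglang/srt/layers/sampler.py in the file list above); the snippet below is only an independent sketch of how such fields are typically applied to a [batch, vocab] logits tensor, not the sampler's actual code.

import torch


def apply_sampling_info(logits, temperatures, vocab_mask=None, logit_bias=None):
    # Independent illustration: add bias, mask disallowed tokens, scale by temperature.
    if logit_bias is not None:
        logits = logits + logit_bias
    if vocab_mask is not None:
        # vocab_mask marks the tokens the regex FSM allows; everything else is removed.
        logits = logits.masked_fill(~vocab_mask, float("-inf"))
    return logits / temperatures  # temperatures has shape [batch, 1]


logits = torch.randn(2, 8)
temperatures = torch.tensor([[1.0], [0.7]])
vocab_mask = torch.zeros(2, 8, dtype=torch.bool)
vocab_mask[:, :4] = True  # pretend the FSM only allows the first four tokens
probs = torch.softmax(apply_sampling_info(logits, temperatures, vocab_mask), dim=-1)
print(probs)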
sglang/srt/sampling/sampling_params.py (moved from sglang/srt/sampling_params.py)

@@ -30,6 +30,7 @@ class SamplingParams:
  temperature: float = 1.0,
  top_p: float = 1.0,
  top_k: int = -1,
+ min_p: float = 0.0,
  frequency_penalty: float = 0.0,
  presence_penalty: float = 0.0,
  repetition_penalty: float = 1.0,
@@ -38,10 +39,12 @@ class SamplingParams:
  spaces_between_special_tokens: bool = True,
  regex: Optional[str] = None,
  n: int = 1,
+ json_schema: Optional[str] = None,
  ) -> None:
  self.temperature = temperature
  self.top_p = top_p
  self.top_k = top_k
+ self.min_p = min_p
  self.frequency_penalty = frequency_penalty
  self.presence_penalty = presence_penalty
  self.repetition_penalty = repetition_penalty
@@ -54,6 +57,7 @@ class SamplingParams:
  self.spaces_between_special_tokens = spaces_between_special_tokens
  self.regex = regex
  self.n = n
+ self.json_schema = json_schema
 
  # Process some special cases
  if self.temperature < _SAMPLING_EPS:
@@ -69,6 +73,8 @@ class SamplingParams:
  )
  if not 0.0 < self.top_p <= 1.0:
  raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
+ if not 0.0 <= self.min_p <= 1.0:
+ raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
  if self.top_k < -1 or self.top_k == 0:
  raise ValueError(
  f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
@@ -102,6 +108,8 @@ class SamplingParams:
  f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
  f"{self.min_new_tokens}."
  )
+ if self.regex is not None and self.json_schema is not None:
+ raise ValueError("regex and json_schema cannot be both set.")
 
  def normalize(self, tokenizer):
  # Process stop strings
@@ -123,3 +131,17 @@ class SamplingParams:
  else:
  stop_str_max_len = max(stop_str_max_len, len(stop_str))
  self.stop_str_max_len = stop_str_max_len
+
+ def to_srt_kwargs(self):
+ return {
+ "max_new_tokens": self.max_new_tokens,
+ "stop": self.stop_strs,
+ "stop_token_ids": list(self.stop_token_ids),
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ "top_k": self.top_k,
+ "frequency_penalty": self.frequency_penalty,
+ "presence_penalty": self.presence_penalty,
+ "ignore_eos": self.ignore_eos,
+ "regex": self.regex,
+ }
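SamplingParams picks up two new knobs in this release: min_p (validated to lie in [0, 1]) and json_schema (mutually exclusive with regex), plus a to_srt_kwargs() helper that maps the pre-existing fields to SGLang-native keyword arguments; note that min_p and json_schema are not part of that mapping as shown. min_p follows the common "min-p sampling" rule of keeping only tokens whose probability is at least min_p times the top token's probability. The snippet below is an independent illustration of that rule, not the code from sglang/srt/layers/sampler.py.

import torch


def min_p_filter(logits: torch.Tensor, min_p: float) -> torch.Tensor:
    # Keep tokens whose probability is >= min_p * (max probability); mask the rest.
    probs = torch.softmax(logits, dim=-1)
    threshold = min_p * probs.max(dim=-1, keepdim=True).values
    return logits.masked_fill(probs < threshold, float("-inf"))


logits = torch.tensor([[2.0, 1.5, 0.1, -3.0]])
filtered = min_p_filter(logits, min_p=0.3)
print(torch.softmax(filtered, dim=-1))  # the two low-probability tokens drop to 0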