sglang 0.3.4.post1__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +92 -0
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +76 -15
  11. sglang/srt/constrained/__init__.py +18 -0
  12. sglang/srt/constrained/bnf_cache.py +61 -0
  13. sglang/srt/constrained/fsm_cache.py +10 -3
  14. sglang/srt/constrained/grammar.py +190 -0
  15. sglang/srt/hf_transformers_utils.py +20 -5
  16. sglang/srt/layers/attention/flashinfer_backend.py +5 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
  18. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  19. sglang/srt/layers/fused_moe/fused_moe.py +4 -3
  20. sglang/srt/layers/fused_moe/layer.py +28 -0
  21. sglang/srt/layers/logits_processor.py +5 -5
  22. sglang/srt/layers/quantization/base_config.py +16 -1
  23. sglang/srt/layers/rotary_embedding.py +15 -48
  24. sglang/srt/layers/sampler.py +51 -39
  25. sglang/srt/layers/vocab_parallel_embedding.py +486 -0
  26. sglang/srt/managers/data_parallel_controller.py +8 -7
  27. sglang/srt/managers/detokenizer_manager.py +11 -9
  28. sglang/srt/managers/image_processor.py +4 -3
  29. sglang/srt/managers/io_struct.py +80 -78
  30. sglang/srt/managers/schedule_batch.py +46 -52
  31. sglang/srt/managers/schedule_policy.py +24 -13
  32. sglang/srt/managers/scheduler.py +145 -82
  33. sglang/srt/managers/tokenizer_manager.py +236 -334
  34. sglang/srt/managers/tp_worker.py +5 -5
  35. sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
  36. sglang/srt/mem_cache/flush_cache.py +1 -1
  37. sglang/srt/mem_cache/memory_pool.py +10 -3
  38. sglang/srt/model_executor/cuda_graph_runner.py +34 -23
  39. sglang/srt/model_executor/forward_batch_info.py +6 -9
  40. sglang/srt/model_executor/model_runner.py +10 -19
  41. sglang/srt/models/baichuan.py +4 -4
  42. sglang/srt/models/chatglm.py +4 -4
  43. sglang/srt/models/commandr.py +1 -1
  44. sglang/srt/models/dbrx.py +5 -5
  45. sglang/srt/models/deepseek.py +4 -4
  46. sglang/srt/models/deepseek_v2.py +4 -4
  47. sglang/srt/models/exaone.py +4 -4
  48. sglang/srt/models/gemma.py +1 -1
  49. sglang/srt/models/gemma2.py +1 -1
  50. sglang/srt/models/gpt2.py +287 -0
  51. sglang/srt/models/gpt_bigcode.py +1 -1
  52. sglang/srt/models/grok.py +4 -4
  53. sglang/srt/models/internlm2.py +4 -4
  54. sglang/srt/models/llama.py +15 -7
  55. sglang/srt/models/llama_embedding.py +2 -10
  56. sglang/srt/models/llama_reward.py +5 -0
  57. sglang/srt/models/minicpm.py +4 -4
  58. sglang/srt/models/minicpm3.py +4 -4
  59. sglang/srt/models/mixtral.py +7 -5
  60. sglang/srt/models/mixtral_quant.py +4 -4
  61. sglang/srt/models/mllama.py +5 -5
  62. sglang/srt/models/olmo.py +4 -4
  63. sglang/srt/models/olmoe.py +4 -4
  64. sglang/srt/models/qwen.py +4 -4
  65. sglang/srt/models/qwen2.py +4 -4
  66. sglang/srt/models/qwen2_moe.py +4 -4
  67. sglang/srt/models/qwen2_vl.py +4 -8
  68. sglang/srt/models/stablelm.py +4 -4
  69. sglang/srt/models/torch_native_llama.py +4 -4
  70. sglang/srt/models/xverse.py +4 -4
  71. sglang/srt/models/xverse_moe.py +4 -4
  72. sglang/srt/openai_api/adapter.py +52 -66
  73. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  74. sglang/srt/sampling/sampling_batch_info.py +7 -13
  75. sglang/srt/sampling/sampling_params.py +5 -7
  76. sglang/srt/server.py +41 -33
  77. sglang/srt/server_args.py +34 -5
  78. sglang/srt/utils.py +40 -56
  79. sglang/test/run_eval.py +2 -0
  80. sglang/test/runners.py +2 -1
  81. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  82. sglang/test/test_utils.py +151 -6
  83. sglang/utils.py +62 -1
  84. sglang/version.py +1 -1
  85. sglang-0.3.5.dist-info/METADATA +344 -0
  86. sglang-0.3.5.dist-info/RECORD +152 -0
  87. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
  88. sglang-0.3.4.post1.dist-info/METADATA +0 -900
  89. sglang-0.3.4.post1.dist-info/RECORD +0 -148
  90. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
@@ -29,10 +29,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
@@ -47,6 +43,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import lru_cache, partial
-from typing import Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union
+from typing import Iterable, List, Optional, Tuple, Type, TypedDict

 import numpy as np
 import torch
@@ -35,9 +35,7 @@ from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import SupportsMultiModal

 from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
 from sglang.srt.hf_transformers_utils import get_processor
@@ -47,6 +45,7 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -486,7 +485,7 @@ class Qwen2VisionTransformer(nn.Module):
 cached_get_processor = lru_cache(get_processor)


-class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Qwen2VLForConditionalGeneration(nn.Module):
     def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
         processor = cached_get_processor(self.config._name_or_path)
         grid_t, grid_h, grid_w = image_grid_thw
@@ -536,15 +535,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
     def __init__(
         self,
         config: Qwen2VLConfig,
-        multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()

         self.config = config
-        self.multimodal_config = multimodal_config
-
         self.visual = Qwen2VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
@@ -622,7 +618,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
         extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
         prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
         for i, image in enumerate(forward_batch.image_inputs):
-            if image == None:
+            if image is None:
                 continue
             start_idx = extend_start_loc_cpu[i]
             prefix_len = prefix_lens_cpu[i]
@@ -24,10 +24,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
@@ -39,6 +35,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch


@@ -26,10 +26,6 @@ from torch.nn.parameter import Parameter
 from transformers import LlamaConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
@@ -38,6 +34,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -31,15 +31,15 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.model_runner import ForwardBatch


@@ -34,15 +34,15 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch


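The model-file hunks above all make the same change: ParallelLMHead and VocabParallelEmbedding now come from sglang's own sglang/srt/layers/vocab_parallel_embedding.py (added in this release, +486 lines) instead of from vLLM. A minimal sketch of the new import path for downstream code; the class names are taken directly from the diff, while their constructor signatures are assumed to mirror the vLLM originals:

    # sglang >= 0.3.5: vocab-parallel embedding layers are vendored in sglang itself.
    from sglang.srt.layers.vocab_parallel_embedding import (
        ParallelLMHead,
        VocabParallelEmbedding,
    )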
@@ -71,6 +71,7 @@ from sglang.srt.openai_api.protocol import (
     TopLogprob,
     UsageInfo,
 )
+from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

@@ -314,6 +315,8 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         )

     except Exception as e:
+        logger.error(f"error: {get_exception_traceback()}")
+        responses = []
         error_json = {
             "id": f"batch_req_{uuid.uuid4()}",
             "custom_id": request_data.get("custom_id"),
@@ -363,7 +366,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         }

     except Exception as e:
-        logger.error("error in SGLang:", e)
+        logger.error(f"error: {e}")
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
@@ -469,80 +472,67 @@ async def v1_retrieve_file_content(file_id: str):
 def v1_generate_request(
     all_requests: List[CompletionRequest], request_ids: List[str] = None
 ):
+    if len(all_requests) > 1:
+        first_prompt_type = type(all_requests[0].prompt)
+        for request in all_requests:
+            assert (
+                type(request.prompt) is first_prompt_type
+            ), "All prompts must be of the same type in file input settings"
+            if request.n > 1:
+                raise ValueError(
+                    "Parallel sampling is not supported for completions from files"
+                )
+
     prompts = []
     sampling_params_list = []
     return_logprobs = []
     logprob_start_lens = []
     top_logprobs_nums = []

-    # NOTE: with openai API, the prompt's logprobs are always not computed
-    first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
-        assert (
-            type(request.prompt) is first_prompt_type
-        ), "All prompts must be of the same type in file input settings"
-        if len(all_requests) > 1 and request.n > 1:
-            raise ValueError(
-                "Parallel sampling is not supported for completions from files"
-            )
+        # NOTE: with openai API, the prompt's logprobs are always not computed
         if request.echo and request.logprobs:
             logger.warning(
                 "Echo is not compatible with logprobs. "
-                "To compute logprobs of input prompt, please use SGLang /request API."
+                "To compute logprobs of input prompt, please use the native /generate API."
             )

-    for request in all_requests:
         prompts.append(request.prompt)
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "min_new_tokens": request.min_tokens,
+                "stop": request.stop,
+                "stop_token_ids": request.stop_token_ids,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "repetition_penalty": request.repetition_penalty,
+                "regex": request.regex,
+                "json_schema": request.json_schema,
+                "n": request.n,
+                "ignore_eos": request.ignore_eos,
+                "no_stop_trim": request.no_stop_trim,
+            }
+        )
         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
         logprob_start_lens.append(-1)
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
-        sampling_params = []
-        if isinstance(request.no_stop_trim, list):
-            num_reqs = len(request.prompt)
-        else:
-            num_reqs = 1
-        for i in range(num_reqs):
-            sampling_params.append(
-                {
-                    "temperature": request.temperature,
-                    "max_new_tokens": request.max_tokens,
-                    "min_new_tokens": request.min_tokens,
-                    "stop": request.stop,
-                    "stop_token_ids": request.stop_token_ids,
-                    "top_p": request.top_p,
-                    "presence_penalty": request.presence_penalty,
-                    "frequency_penalty": request.frequency_penalty,
-                    "repetition_penalty": request.repetition_penalty,
-                    "regex": request.regex,
-                    "json_schema": request.json_schema,
-                    "n": request.n,
-                    "ignore_eos": request.ignore_eos,
-                    "no_stop_trim": (
-                        request.no_stop_trim
-                        if not isinstance(request.no_stop_trim, list)
-                        else request.no_stop_trim[i]
-                    ),
-                }
-            )
-        if num_reqs == 1:
-            sampling_params_list.append(sampling_params[0])
-        else:
-            sampling_params_list.append(sampling_params)

     if len(all_requests) == 1:
-        prompt = prompts[0]
+        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
+            prompt_kwargs = {"text": prompts[0]}
+        else:
+            prompt_kwargs = {"input_ids": prompts[0]}
         sampling_params_list = sampling_params_list[0]
-        logprob_start_lens = logprob_start_lens[0]
         return_logprobs = return_logprobs[0]
+        logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
-        if isinstance(prompt, str) or isinstance(prompt[0], str):
-            prompt_kwargs = {"text": prompt}
-        else:
-            prompt_kwargs = {"input_ids": prompt}
     else:
-        if isinstance(prompts[0], str):
+        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
@@ -558,9 +548,7 @@ def v1_generate_request(
         rid=request_ids,
     )

-    if len(all_requests) == 1:
-        return adapted_request, all_requests[0]
-    return adapted_request, all_requests
+    return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


 def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
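In the rewritten v1_generate_request, a string (or list-of-strings) prompt is forwarded to the tokenizer manager as "text", while a prompt given as token ids is forwarded as "input_ids". A hedged client-side sketch of using the OpenAI-compatible completions route either way; the server address, model name, and token ids are placeholders, not values taken from this diff:

    import requests

    base_url = "http://localhost:30000"  # placeholder: local sglang server

    # A plain-text prompt is passed through internally as "text".
    requests.post(
        f"{base_url}/v1/completions",
        json={"model": "default", "prompt": "Say hi", "max_tokens": 8},
    )

    # A prompt given as a list of token ids is passed through as "input_ids".
    requests.post(
        f"{base_url}/v1/completions",
        json={"model": "default", "prompt": [9906, 1917], "max_tokens": 8},
    )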
595
583
  if isinstance(request, list) and request[idx].echo:
596
584
  echo = True
597
585
  text = request[idx].prompt + text
598
- if (not isinstance(request, list)) and echo:
586
+ if echo and not isinstance(request, list):
599
587
  prompt_index = idx // request.n
600
588
  text = prompts[prompt_index] + text
601
589
 
@@ -709,7 +697,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
709
697
  async for content in tokenizer_manager.generate_request(
710
698
  adapted_request, raw_request
711
699
  ):
712
- index = content["index"]
700
+ index = content.get("index", 0)
713
701
 
714
702
  stream_buffer = stream_buffers.get(index, "")
715
703
  n_prev_token = n_prev_tokens.get(index, 0)
@@ -945,19 +933,18 @@ def v1_chat_generate_request(
945
933
  sampling_params_list.append(sampling_params)
946
934
 
947
935
  image_data_list.append(image_data)
948
- modalities_list.extend(modalities)
936
+ modalities_list.append(modalities)
949
937
  if len(all_requests) == 1:
950
- input_ids = input_ids[0]
951
- if isinstance(input_ids, str):
952
- prompt_kwargs = {"text": input_ids}
938
+ if isinstance(input_ids[0], str):
939
+ prompt_kwargs = {"text": input_ids[0]}
953
940
  else:
954
- prompt_kwargs = {"input_ids": input_ids}
941
+ prompt_kwargs = {"input_ids": input_ids[0]}
955
942
  sampling_params_list = sampling_params_list[0]
956
943
  image_data_list = image_data_list[0]
957
944
  return_logprobs = return_logprobs[0]
958
945
  logprob_start_lens = logprob_start_lens[0]
959
946
  top_logprobs_nums = top_logprobs_nums[0]
960
- modalities_list = modalities_list[:1]
947
+ modalities_list = modalities_list[0]
961
948
  else:
962
949
  if isinstance(input_ids[0], str):
963
950
  prompt_kwargs = {"text": input_ids}
@@ -976,9 +963,8 @@ def v1_chat_generate_request(
976
963
  rid=request_ids,
977
964
  modalities=modalities_list,
978
965
  )
979
- if len(all_requests) == 1:
980
- return adapted_request, all_requests[0]
981
- return adapted_request, all_requests
966
+
967
+ return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
982
968
 
983
969
 
984
970
  def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
@@ -1116,7 +1102,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
1116
1102
  async for content in tokenizer_manager.generate_request(
1117
1103
  adapted_request, raw_request
1118
1104
  ):
1119
- index = content["index"]
1105
+ index = content.get("index", 0)
1120
1106
 
1121
1107
  is_first = is_firsts.get(index, True)
1122
1108
  stream_buffer = stream_buffers.get(index, "")
@@ -31,9 +31,12 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
             sequences=[
                 torch.tensor(
-                    data=list(
-                        req.sampling_params.stop_token_ids
-                        | {req.tokenizer.eos_token_id}
+                    data=(
+                        list(
+                            (req.sampling_params.stop_token_ids or set())
+                            | (req.tokenizer.additional_stop_token_ids or set())
+                            | {req.tokenizer.eos_token_id}
+                        )
                     ),
                     dtype=torch.int64,
                     device=self.orchestrator.device,
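This pairs with the SamplingParams change further below: stop_token_ids may now be None, so every operand is normalized with "or set()" before the union. A standalone sketch of the same normalization; the names follow the diff, and the helper itself is illustrative only, not part of sglang's API:

    def union_stop_token_ids(sampling_stop_ids, extra_stop_ids, eos_token_id):
        # Treat None as the empty set before taking the union, as in the hunk above.
        return (
            (sampling_stop_ids or set())
            | (extra_stop_ids or set())
            | {eos_token_id}
        )

    assert union_stop_token_ids(None, {7}, 2) == {2, 7}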
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List, Optional
 import torch

 import sglang.srt.sampling.penaltylib as penaltylib
-from sglang.srt.constrained import RegexGuide
+from sglang.srt.constrained.grammar import Grammar

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -29,11 +29,9 @@ class SamplingBatchInfo:
     # Bias Tensors
     vocab_size: int
     logit_bias: torch.Tensor = None
-    vocab_mask: torch.Tensor = None
+    vocab_mask: Optional[torch.Tensor] = None

-    # FSM states
-    regex_fsms: List[RegexGuide] = None
-    regex_fsm_states: List[int] = None
+    grammars: Optional[List[Optional[Grammar]]] = None

     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -136,8 +134,7 @@ class SamplingBatchInfo:
             self.linear_penalties = penalizer.apply(self.linear_penalties)

     def update_regex_vocab_mask(self):
-        has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
-        if not has_regex:
+        if not self.grammars or not any(grammar for grammar in self.grammars):
             self.vocab_mask = None
             return

@@ -147,12 +144,9 @@ class SamplingBatchInfo:
             dtype=torch.bool,
             device=self.device,
         )
-        for i, regex_fsm in enumerate(self.regex_fsms):
-            if regex_fsm is not None:
-                self.vocab_mask[i].fill_(1)
-                self.vocab_mask[i][
-                    regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
-                ] = 0
+        for i, grammar in enumerate(self.grammars):
+            if grammar is not None:
+                grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         if self.penalizer_orchestrator:
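The per-request FSM bookkeeping is replaced by a Grammar object that fills a boolean vocab mask in place. A hedged sketch of how such a mask is typically applied to logits, assuming True marks disallowed tokens (the convention implied by the old fill_(1) / zero-the-allowed-tokens code removed above); this is not sglang's sampler, only an illustration:

    import torch

    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> torch.Tensor:
        # Banned positions (True) are pushed to -inf so they can never be sampled.
        return logits.masked_fill(vocab_mask, float("-inf"))

    logits = torch.randn(2, 8)                 # (batch_size, vocab_size)
    mask = torch.zeros(2, 8, dtype=torch.bool)
    mask[0, 3:] = True                         # pretend a grammar bans tokens 3..7 for request 0
    masked = apply_vocab_mask(logits, mask)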
@@ -50,9 +50,10 @@ class SamplingParams:
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
-        if stop_token_ids is None:
-            stop_token_ids = []
-        self.stop_token_ids = {*stop_token_ids}
+        if stop_token_ids:
+            self.stop_token_ids = set(stop_token_ids)
+        else:
+            self.stop_token_ids = None
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
         self.ignore_eos = ignore_eos
@@ -119,10 +120,7 @@ class SamplingParams:
         # Process stop strings
         if self.stop_strs is None:
             self.stop_strs = []
-            if self.stop_token_ids is None:
-                self.stop_str_max_len = 0
-            else:
-                self.stop_str_max_len = 1
+            self.stop_str_max_len = 0
         else:
             if isinstance(self.stop_strs, str):
                 self.stop_strs = [self.stop_strs]
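Net effect for callers: an omitted or empty stop_token_ids is now stored as None rather than an empty set, and it no longer inflates stop_str_max_len. A tiny sketch, assuming SamplingParams can be constructed with its default arguments:

    from sglang.srt.sampling.sampling_params import SamplingParams

    assert SamplingParams().stop_token_ids is None                 # was an empty set before 0.3.5
    assert SamplingParams(stop_token_ids=[2]).stop_token_ids == {2}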
sglang/srt/server.py
@@ -53,7 +53,6 @@ from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
-    RewardReqInput,
     UpdateWeightReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
@@ -91,7 +90,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


 app = FastAPI()
-tokenizer_manager = None
+tokenizer_manager: TokenizerManager = None

 app.add_middleware(
     CORSMiddleware,
@@ -139,7 +138,7 @@ async def get_server_args():
     return dataclasses.asdict(tokenizer_manager.server_args)


-@app.get("/flush_cache")
+@app.post("/flush_cache")
 async def flush_cache():
     """Flush the radix cache."""
     tokenizer_manager.flush_cache()
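/flush_cache is now a POST route, and the next hunk registers a new /get_memory_pool_size route for both GET and POST. A hedged sketch of calling both over HTTP; the server address is a placeholder, and the exact response payload is not shown in this diff:

    import requests

    base_url = "http://localhost:30000"  # placeholder: local sglang server

    requests.post(f"{base_url}/flush_cache")                         # was GET in 0.3.4.post1
    print(requests.get(f"{base_url}/get_memory_pool_size").json())   # new in 0.3.5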
@@ -172,6 +171,19 @@ async def stop_profile():
     )


+@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
+async def get_memory_pool_size():
+    """Get the memory pool size in number of tokens"""
+    try:
+        ret = await tokenizer_manager.get_memory_pool_size()
+
+        return ret
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
 @app.post("/update_weights")
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
     """Update the weights inplace without re-launching the server."""
@@ -241,8 +253,8 @@ app.post("/encode")(encode_request)
 app.put("/encode")(encode_request)


-async def judge_request(obj: RewardReqInput, request: Request):
-    """Handle a reward model request."""
+async def judge_request(obj: EmbeddingReqInput, request: Request):
+    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
     try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
@@ -429,7 +441,7 @@ def launch_server(

     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
     )
     t.start()

@@ -484,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)


-def _wait_and_warmup(server_args, pipe_finish_writer, pid):
+def _wait_and_warmup(server_args, pipe_finish_writer):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -507,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_child_process(pid, including_parent=False)
+        kill_child_process(include_self=True)
         return

     model_info = res.json()
@@ -539,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_child_process(pid, including_parent=False)
+        kill_child_process(include_self=True)
         return

     # logger.info(f"{res.json()=}")
@@ -605,7 +617,7 @@ class Runtime:

     def shutdown(self):
         if self.pid is not None:
-            kill_child_process(self.pid)
+            kill_child_process(self.pid, include_self=True)
             self.pid = None

     def cache_prefix(self, prefix: str):
@@ -684,24 +696,8 @@ class Runtime:
         self,
         prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
     ):
-        if isinstance(prompt, str) or isinstance(prompt[0], str):
-            # embedding
-            json_data = {
-                "text": prompt,
-            }
-            response = requests.post(
-                self.url + "/encode",
-                json=json_data,
-            )
-        else:
-            # reward
-            json_data = {
-                "conv": prompt,
-            }
-            response = requests.post(
-                self.url + "/judge",
-                json=json_data,
-            )
+        json_data = {"text": prompt}
+        response = requests.post(self.url + "/encode", json=json_data)
         return json.dumps(response.json())

     def __del__(self):
@@ -724,24 +720,32 @@ class Engine:

         # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
         atexit.register(self.shutdown)
+
+        # runtime server default log level is log
+        # offline engine works in scripts, so we set it to error
+
+        if 'log_level' not in kwargs:
+            kwargs['log_level'] = 'error'

         server_args = ServerArgs(*args, **kwargs)
         launch_engine(server_args=server_args)

     def generate(
         self,
-        prompt: Union[str, List[str]],
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
         sampling_params: Optional[Dict] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
         lora_path: Optional[List[Optional[str]]] = None,
         stream: bool = False,
     ):
-        # TODO (ByronHsu): refactor to reduce the duplicated code
-
         obj = GenerateReqInput(
             text=prompt,
+            input_ids=input_ids,
             sampling_params=sampling_params,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
@@ -779,8 +783,11 @@ class Engine:

     async def async_generate(
         self,
-        prompt: Union[str, List[str]],
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
         sampling_params: Optional[Dict] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -789,6 +796,7 @@ class Engine:
     ):
         obj = GenerateReqInput(
             text=prompt,
+            input_ids=input_ids,
             sampling_params=sampling_params,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
@@ -822,7 +830,7 @@ class Engine:
         return ret

     def shutdown(self):
-        kill_child_process(os.getpid(), including_parent=False)
+        kill_child_process()

     def get_tokenizer(self):
         global tokenizer_manager
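Taken together, the Engine changes above mean the offline engine now defaults log_level to "error", accepts token-id input via input_ids, and shuts down with a plain kill_child_process(). A hedged usage sketch; the model path and token ids are placeholders, and the sampling-params keys follow the adapter diff earlier in this listing:

    from sglang.srt.server import Engine

    engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
    out = engine.generate(
        input_ids=[[1, 15043, 3186]],  # one prompt supplied as token ids (new in 0.3.5)
        sampling_params={"temperature": 0.0, "max_new_tokens": 16},
    )
    print(out)
    engine.shutdown()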