sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/bench_one_batch.py +3 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/check_env.py +3 -3
  4. sglang/lang/chat_template.py +44 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/deepseekvl2.py +3 -0
  7. sglang/srt/configs/device_config.py +1 -1
  8. sglang/srt/configs/internvl.py +696 -0
  9. sglang/srt/configs/janus_pro.py +3 -0
  10. sglang/srt/configs/kimi_vl.py +38 -0
  11. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  12. sglang/srt/configs/model_config.py +32 -0
  13. sglang/srt/constrained/xgrammar_backend.py +11 -19
  14. sglang/srt/conversation.py +151 -3
  15. sglang/srt/disaggregation/decode.py +4 -1
  16. sglang/srt/disaggregation/mini_lb.py +74 -23
  17. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  18. sglang/srt/disaggregation/nixl/conn.py +241 -71
  19. sglang/srt/disaggregation/utils.py +44 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  21. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  22. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  24. sglang/srt/distributed/parallel_state.py +22 -1
  25. sglang/srt/entrypoints/engine.py +58 -24
  26. sglang/srt/entrypoints/http_server.py +28 -1
  27. sglang/srt/entrypoints/verl_engine.py +3 -2
  28. sglang/srt/function_call_parser.py +97 -0
  29. sglang/srt/hf_transformers_utils.py +22 -1
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  31. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  32. sglang/srt/layers/attention/flashinfer_backend.py +129 -94
  33. sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
  34. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  35. sglang/srt/layers/attention/merge_state.py +46 -0
  36. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  37. sglang/srt/layers/attention/vision.py +290 -163
  38. sglang/srt/layers/dp_attention.py +5 -2
  39. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  40. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
  49. sglang/srt/layers/quantization/__init__.py +2 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  52. sglang/srt/layers/quantization/deep_gemm.py +6 -1
  53. sglang/srt/layers/quantization/fp8.py +108 -95
  54. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  55. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  56. sglang/srt/layers/quantization/kv_cache.py +3 -10
  57. sglang/srt/layers/quantization/utils.py +0 -5
  58. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  59. sglang/srt/layers/utils.py +35 -0
  60. sglang/srt/lora/layers.py +35 -9
  61. sglang/srt/lora/lora_manager.py +81 -35
  62. sglang/srt/managers/cache_controller.py +115 -119
  63. sglang/srt/managers/data_parallel_controller.py +52 -34
  64. sglang/srt/managers/io_struct.py +10 -0
  65. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  66. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  67. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  68. sglang/srt/managers/schedule_batch.py +44 -16
  69. sglang/srt/managers/schedule_policy.py +11 -5
  70. sglang/srt/managers/scheduler.py +291 -72
  71. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  72. sglang/srt/managers/tokenizer_manager.py +24 -13
  73. sglang/srt/managers/tp_worker.py +60 -28
  74. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  75. sglang/srt/mem_cache/chunk_cache.py +2 -0
  76. sglang/srt/mem_cache/memory_pool.py +70 -36
  77. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  78. sglang/srt/model_executor/forward_batch_info.py +31 -1
  79. sglang/srt/model_executor/model_runner.py +159 -90
  80. sglang/srt/model_loader/loader.py +18 -11
  81. sglang/srt/models/clip.py +4 -4
  82. sglang/srt/models/deepseek_janus_pro.py +1 -1
  83. sglang/srt/models/deepseek_nextn.py +2 -277
  84. sglang/srt/models/deepseek_v2.py +132 -37
  85. sglang/srt/models/gemma3_mm.py +1 -1
  86. sglang/srt/models/internlm2.py +3 -0
  87. sglang/srt/models/internvl.py +670 -0
  88. sglang/srt/models/kimi_vl.py +308 -0
  89. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  90. sglang/srt/models/llama.py +93 -31
  91. sglang/srt/models/llama4.py +54 -7
  92. sglang/srt/models/llama_eagle.py +4 -1
  93. sglang/srt/models/llama_eagle3.py +4 -1
  94. sglang/srt/models/minicpmv.py +1 -1
  95. sglang/srt/models/mllama.py +1 -1
  96. sglang/srt/models/phi3_small.py +16 -2
  97. sglang/srt/models/qwen2_5_vl.py +8 -4
  98. sglang/srt/models/qwen2_moe.py +8 -3
  99. sglang/srt/models/qwen2_vl.py +4 -16
  100. sglang/srt/models/qwen3_moe.py +8 -3
  101. sglang/srt/models/xiaomi_mimo.py +171 -0
  102. sglang/srt/openai_api/adapter.py +58 -62
  103. sglang/srt/openai_api/protocol.py +38 -16
  104. sglang/srt/reasoning_parser.py +2 -2
  105. sglang/srt/sampling/sampling_batch_info.py +54 -2
  106. sglang/srt/sampling/sampling_params.py +2 -0
  107. sglang/srt/server_args.py +93 -24
  108. sglang/srt/speculative/eagle_worker.py +3 -2
  109. sglang/srt/utils.py +123 -10
  110. sglang/test/runners.py +4 -0
  111. sglang/test/test_block_fp8.py +2 -2
  112. sglang/test/test_deepep_utils.py +219 -0
  113. sglang/test/test_utils.py +32 -1
  114. sglang/version.py +1 -1
  115. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
  116. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
  117. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  118. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/kimi_vl.py
@@ -0,0 +1,73 @@
+import asyncio
+import math
+from typing import List, Union
+
+import torch
+from PIL import Image
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor as SGLangBaseProcessor,
+)
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.kimi_vl import KimiVLForConditionalGeneration
+
+
+# Compatible with KimiVLForConditionalGeneration
+class KimiVLImageProcessor(SGLangBaseProcessor):
+    models = [KimiVLForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "<|media_pad|>"
+        self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
+
+        self.im_start = "<|media_start|>"
+        self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)
+
+        self.im_end = "<|media_end|>"
+        self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)
+
+        self.im_content = "<|media_content|>"
+        self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
+            max_req_input_len=max_req_input_len,
+        )
+        ret = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=base_output.images,
+        )
+        return {
+            "input_ids": ret["input_ids"].flatten().tolist(),
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=ret["pixel_values"],
+                    image_grid_thws=ret["image_grid_hws"],
+                    modality=Modality.IMAGE,
+                )
+            ],
+            "im_token_id": self.im_token_id,
+            "im_start_id": self.im_start_id,
+            "im_end_id": self.im_end_id,
+            "im_content_id": self.im_content_id,
+        }
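For reference, a minimal sketch of how the four media special tokens above resolve to vocabulary ids through a Hugging Face tokenizer; the checkpoint id is an illustrative assumption, not taken from this diff:

    # Minimal sketch: look up the Kimi-VL media special tokens by string.
    # The checkpoint id below is an assumed example, not part of this diff.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        "moonshotai/Kimi-VL-A3B-Instruct", trust_remote_code=True
    )
    for token in ("<|media_start|>", "<|media_content|>", "<|media_pad|>", "<|media_end|>"):
        # convert_tokens_to_ids returns the id, or the unk id if the token is absent
        print(token, tok.convert_tokens_to_ids(token))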
sglang/srt/managers/schedule_batch.py
@@ -66,23 +66,24 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 # Put some global args for easy access
 global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
-    "sampling_backend": ServerArgs.sampling_backend,
-    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "torchao_config": ServerArgs.torchao_config,
-    "enable_nan_detection": ServerArgs.enable_nan_detection,
-    "enable_dp_attention": ServerArgs.enable_dp_attention,
-    "enable_ep_moe": ServerArgs.enable_ep_moe,
-    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "deepep_mode": ServerArgs.deepep_mode,
     "device": ServerArgs.device,
-    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
-    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
+    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
+    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "max_micro_batch_size": ServerArgs.max_micro_batch_size,
     "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
-    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
+    "sampling_backend": ServerArgs.sampling_backend,
+    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
+    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
+    "torchao_config": ServerArgs.torchao_config,
+    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
 }

 logger = logging.getLogger(__name__)
@@ -728,6 +729,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Events
     launch_done: Optional[threading.Event] = None

+    # For chunked prefill in PP
+    chunked_req: Optional[Req] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
@@ -741,6 +745,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     out_cache_loc: torch.Tensor = None  # shape: [b], int64
     output_ids: torch.Tensor = None  # shape: [b], int64

+    # For multimodal inputs
+    multimodal_inputs: Optional[List] = None
+
     # The sum of all sequence lengths
     seq_lens_sum: int = None

@@ -761,7 +768,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # For extend and mixed chunekd prefill
     prefix_lens: List[int] = None
     extend_lens: List[int] = None
-    extend_num_tokens: int = None
+    extend_num_tokens: Optional[int] = None
     decoding_reqs: List[Req] = None
     extend_logprob_start_lens: List[int] = None
     # It comes empty list if logprob is not required.
@@ -803,6 +810,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
+        chunked_req: Optional[Req] = None,
     ):
         return_logprob = any(req.return_logprob for req in reqs)

@@ -820,6 +828,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            chunked_req=chunked_req,
         )

     def batch_size(self):
@@ -1044,6 +1053,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Copy prefix and do some basic check
         input_embeds = []
         extend_input_logprob_token_ids = []
+        multimodal_inputs = []

         for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
             req.req_pool_idx = req_pool_indices[i]
@@ -1059,6 +1069,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

+            multimodal_inputs.append(req.multimodal_inputs)
+
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
@@ -1141,6 +1153,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             if input_embeds
             else None
         )
+        for mm_input in multimodal_inputs:
+            if mm_input is None:
+                continue
+            for mm_item in mm_input.mm_items:
+                pixel_values = getattr(mm_item, "pixel_values", None)
+                if isinstance(pixel_values, torch.Tensor):
+                    mm_item.pixel_values = pixel_values.to(
+                        self.device, non_blocking=True
+                    )
+        self.multimodal_inputs = multimodal_inputs
         self.seq_lens_sum = sum(seq_lens)

         if self.return_logprob:
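The added loop copies each request's pixel_values to the batch device with non_blocking=True. A small, self-contained sketch of the same transfer pattern (generic names, not sglang's API); note that a CPU-to-GPU copy only truly overlaps with other work when the source tensor is in pinned (page-locked) host memory, otherwise it behaves synchronously:

    import torch

    def to_device_async(t, device):
        # Enqueue a host-to-device copy without blocking the Python thread,
        # mirroring the pixel_values.to(self.device, non_blocking=True) call above.
        # Overlap with compute requires the source tensor to be pinned.
        if isinstance(t, torch.Tensor):
            return t.to(device, non_blocking=True)
        return t

    pixel_values = torch.randn(3, 224, 224).pin_memory()
    if torch.cuda.is_available():
        pixel_values = to_device_async(pixel_values, torch.device("cuda"))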
@@ -1236,7 +1258,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

     def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
-        sorted_indices = [i for i in range(len(self.reqs))]
+        sorted_indices = list(range(len(self.reqs)))

         # TODO(lsyin): improve retraction policy for radix cache
         # For spec decoding, filter_batch API can only filter
@@ -1413,15 +1435,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

     def filter_batch(
         self,
-        chunked_req_to_exclude: Optional[Req] = None,
+        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
+            if isinstance(chunked_req_to_exclude, Req):
+                chunked_req_to_exclude = [chunked_req_to_exclude]
+            elif chunked_req_to_exclude is None:
+                chunked_req_to_exclude = []
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and self.reqs[i] is not chunked_req_to_exclude
+                and not self.reqs[i] in chunked_req_to_exclude
             ]

         if keep_indices is None or len(keep_indices) == 0:
@@ -1442,6 +1468,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

         self.reqs = [self.reqs[i] for i in keep_indices]
+        self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
         self.out_cache_loc = None
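The filter_batch changes above widen chunked_req_to_exclude from a single Req to optionally a list, normalizing it before the membership test ("not ... in" replaces the old identity check). The same normalize-then-filter pattern in isolation (generic types, not sglang's API):

    from typing import List, Optional, Union

    def filter_items(
        items: List[str], exclude: Optional[Union[str, List[str]]] = None
    ) -> List[str]:
        # Normalize a single value / list / None into a list so the body
        # can always use a membership test instead of an identity check.
        if isinstance(exclude, str):
            exclude = [exclude]
        elif exclude is None:
            exclude = []
        return [x for x in items if x not in exclude]

    assert filter_items(["a", "b", "c"], "b") == ["a", "c"]
    assert filter_items(["a", "b", "c"], ["a", "c"]) == ["b"]
    assert filter_items(["a", "b", "c"]) == ["a", "b", "c"]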
@@ -1490,6 +1517,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
         self.reqs.extend(other.reqs)
+        self.multimodal_inputs.extend(other.multimodal_inputs)

         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
@@ -1548,7 +1576,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
+            multimodal_inputs=self.multimodal_inputs,
             encoder_cached=self.encoder_cached,
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,
sglang/srt/managers/schedule_policy.py
@@ -455,7 +455,10 @@ class PrefillAdder:
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
-        input_tokens = req.extend_input_len
+        input_tokens = (
+            -(-req.extend_input_len // self.tree_cache.page_size)
+            * self.tree_cache.page_size
+        )
         prefix_len = len(req.prefix_indices)

         if total_tokens >= self.rem_total_tokens:
@@ -477,7 +480,10 @@
                     req.last_node_global, req.prefix_indices
                 )
                 req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
-                input_tokens = req.extend_input_len
+                input_tokens = (
+                    -(-req.extend_input_len // self.tree_cache.page_size)
+                    * self.tree_cache.page_size
+                )
                 prefix_len = len(req.prefix_indices)

             if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
@@ -493,12 +499,12 @@
                     ),
                 )
             else:
-                if self.rem_chunk_tokens == 0:
+                # Make sure at least one page is available
+                trunc_len = self.rem_chunk_tokens - self.tree_cache.page_size + 1
+                if trunc_len <= 0:
                     return AddReqResult.OTHER

                 # Chunked prefill
-                trunc_len = self.rem_chunk_tokens
-
                 req.extend_input_len = trunc_len
                 req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
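Both PrefillAdder hunks above round the extend length up to a whole number of pages via negative floor division, and the chunked-prefill branch picks trunc_len = rem_chunk_tokens - page_size + 1. A small self-checking sketch of that arithmetic, independent of sglang; it also verifies that the page-rounded truncated chunk never exceeds the remaining chunk budget:

    def round_up_to_page(num_tokens: int, page_size: int) -> int:
        # -(-x // p) is ceil(x / p) using floor division, as in the hunks above.
        return -(-num_tokens // page_size) * page_size

    assert round_up_to_page(1, 4) == 4
    assert round_up_to_page(8, 4) == 8
    assert round_up_to_page(9, 4) == 12

    # With trunc_len = rem_chunk_tokens - page_size + 1, the page-rounded chunk
    # stays within the remaining chunk budget (brute-force check).
    for page_size in (1, 4, 16):
        for rem_chunk_tokens in range(page_size, 65):
            trunc_len = rem_chunk_tokens - page_size + 1
            assert round_up_to_page(trunc_len, page_size) <= rem_chunk_tokens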