sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -271,12 +271,13 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
271
271
  batch,
272
272
  dp_size=model_runner.server_args.dp_size,
273
273
  attn_tp_size=1,
274
- tp_cpu_group=model_runner.tp_group.cpu_group,
274
+ tp_group=model_runner.tp_group,
275
275
  get_idle_batch=None,
276
276
  disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
277
277
  spec_algorithm=SpeculativeAlgorithm.NONE,
278
278
  speculative_num_draft_tokens=None,
279
279
  require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
280
+ disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
280
281
  )
281
282
 
282
283
 
@@ -73,6 +73,8 @@ async def benchmark(args):
73
73
 
74
74
  tasks: List[asyncio.Task] = []
75
75
  for idx, ex in enumerate(dataset):
76
+ if idx >= args.num_prompts:
77
+ break
76
78
  tasks.append(
77
79
  asyncio.create_task(
78
80
  fetch_response(
@@ -103,6 +105,8 @@ def analyse(args):
103
105
  hyps: List[str] = []
104
106
  refs: List[str] = []
105
107
  for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
108
+ if idx >= args.num_prompts:
109
+ break
106
110
  pkl_file = output_dir / f"response_{idx}.pkl"
107
111
  if not pkl_file.exists():
108
112
  raise FileNotFoundError(pkl_file)
@@ -150,6 +154,9 @@ if __name__ == "__main__":
150
154
  parser.add_argument(
151
155
  "--output-dir", default="tmp-output-dir", help="Directory for cached responses"
152
156
  )
157
+ parser.add_argument(
158
+ "--num-prompts", type=int, default=10000, help="Number of prompts to run"
159
+ )
153
160
  args = parser.parse_args()
154
161
 
155
162
  asyncio.run(benchmark(args))
sglang/srt/_custom_ops.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
2
2
  import logging
3
- from typing import List, Tuple
3
+ from typing import List, Optional, Tuple
4
4
 
5
5
  import torch
6
6
 
@@ -114,6 +114,34 @@ else:
114
114
  def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
115
115
  return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
116
116
 
117
+ # ROCM custom quick allreduce
118
+
119
+ def init_custom_qr(
120
+ rank: int, world_size: int, qr_max_size: Optional[int] = None
121
+ ) -> int:
122
+ return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size)
123
+
124
+ def qr_get_handle(fa: int) -> torch.Tensor:
125
+ return sgl_kernel.allreduce.qr_get_handle(fa)
126
+
127
+ def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
128
+ sgl_kernel.allreduce.qr_open_handles(fa, handles)
129
+
130
+ def qr_all_reduce(
131
+ fa: int,
132
+ inp: torch.Tensor,
133
+ out: torch.Tensor,
134
+ quant_level: int,
135
+ cast_bf2half: bool,
136
+ ) -> None:
137
+ sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)
138
+
139
+ def qr_destroy(fa: int) -> None:
140
+ sgl_kernel.allreduce.qr_destroy(fa)
141
+
142
+ def qr_max_size() -> int:
143
+ return sgl_kernel.allreduce.qr_max_size()
144
+
117
145
 
118
146
  def mscclpp_generate_unique_id() -> bytes:
119
147
  return sgl_kernel.allreduce.mscclpp_generate_unique_id()
@@ -42,6 +42,9 @@ def select_best_resolution(image_size, candidate_resolutions):
42
42
 
43
43
 
44
44
  class DictOutput(object):
45
+ def items(self):
46
+ return self.__dict__.items()
47
+
45
48
  def keys(self):
46
49
  return self.__dict__.keys()
47
50
 
@@ -59,7 +62,9 @@ class DictOutput(object):
59
62
  class VLChatProcessorOutput(DictOutput):
60
63
  input_ids: torch.LongTensor
61
64
  target_ids: torch.LongTensor
62
- images: torch.Tensor
65
+ pixel_values: (
66
+ torch.Tensor
67
+ ) # rename from "images" to "pixel_values" for compatibility
63
68
  images_seq_mask: torch.BoolTensor
64
69
  images_spatial_crop: torch.LongTensor
65
70
 
@@ -312,10 +317,14 @@ class DeepseekVLV2Processor(ProcessorMixin):
312
317
  images = torch.stack(images_list, dim=0)
313
318
  images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
314
319
 
320
+ images_spatial_crop = torch.stack(
321
+ [images_spatial_crop], dim=0
322
+ ) # stack the tensor to make it a batch of 1
323
+
315
324
  prepare = VLChatProcessorOutput(
316
325
  input_ids=input_ids,
317
326
  target_ids=target_ids,
318
- images=images,
327
+ pixel_values=images,
319
328
  images_seq_mask=images_seq_mask,
320
329
  images_spatial_crop=images_spatial_crop,
321
330
  )
@@ -9,6 +9,7 @@ from transformers import (
9
9
  LlamaConfig,
10
10
  PretrainedConfig,
11
11
  PreTrainedTokenizer,
12
+ Qwen2Config,
12
13
  )
13
14
 
14
15
  from sglang.utils import logger
@@ -311,6 +312,8 @@ class InternVLChatConfig(PretrainedConfig):
311
312
  self.llm_config = LlamaConfig(**llm_config)
312
313
  elif llm_config.get("architectures")[0] == "InternLM2ForCausalLM":
313
314
  self.llm_config = InternLM2Config(**llm_config)
315
+ elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
316
+ self.llm_config = Qwen2Config(**llm_config)
314
317
  else:
315
318
  raise ValueError(
316
319
  "Unsupported architecture: {}".format(
@@ -284,6 +284,9 @@ class VLMImageProcessor(BaseImageProcessor):
284
284
 
285
285
 
286
286
  class DictOutput(object):
287
+ def items(self):
288
+ return self.__dict__.items()
289
+
287
290
  def keys(self):
288
291
  return self.__dict__.keys()
289
292
 
@@ -53,7 +53,7 @@ class ModelConfig:
53
53
  trust_remote_code: bool = True,
54
54
  revision: Optional[str] = None,
55
55
  context_length: Optional[int] = None,
56
- model_override_args: Optional[str] = None,
56
+ model_override_args: str = "{}",
57
57
  is_embedding: Optional[bool] = None,
58
58
  enable_multimodal: Optional[bool] = None,
59
59
  dtype: str = "auto",
@@ -61,13 +61,13 @@ class ModelConfig:
61
61
  override_config_file: Optional[str] = None,
62
62
  is_draft_model: bool = False,
63
63
  hybrid_kvcache_ratio: Optional[float] = None,
64
- impl: Union[str, ModelImpl] = ModelImpl.AUTO,
64
+ model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
65
65
  ) -> None:
66
66
 
67
67
  self.model_path = model_path
68
68
  self.revision = revision
69
69
  self.quantization = quantization
70
- self.impl = impl
70
+ self.model_impl = model_impl
71
71
 
72
72
  # Parse args
73
73
  self.maybe_pull_model_tokenizer_from_remote()
@@ -286,7 +286,7 @@ class ModelConfig:
286
286
  dtype=server_args.dtype,
287
287
  quantization=server_args.quantization,
288
288
  hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
289
- impl=server_args.impl,
289
+ model_impl=server_args.model_impl,
290
290
  **kwargs,
291
291
  )
292
292
 
@@ -391,6 +391,7 @@ class ModelConfig:
391
391
  "compressed-tensors",
392
392
  "fbgemm_fp8",
393
393
  "w8a8_fp8",
394
+ "petit_nvfp4",
394
395
  ]
395
396
  optimized_quantization_methods = [
396
397
  "fp8",
@@ -408,9 +409,11 @@ class ModelConfig:
408
409
  "moe_wna16",
409
410
  "qoq",
410
411
  "w4afp8",
412
+ "petit_nvfp4",
411
413
  ]
412
414
  compatible_quantization_methods = {
413
415
  "modelopt_fp4": ["modelopt"],
416
+ "petit_nvfp4": ["modelopt"],
414
417
  "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
415
418
  "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
416
419
  }
@@ -472,7 +475,7 @@ class ModelConfig:
472
475
 
473
476
  def get_hf_eos_token_id(self) -> Optional[Set[int]]:
474
477
  eos_ids = getattr(self.hf_config, "eos_token_id", None)
475
- if eos_ids:
478
+ if eos_ids is not None:
476
479
  # it can be either int or list of int
477
480
  eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
478
481
  if eos_ids is None:
@@ -711,7 +714,6 @@ def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int)
711
714
  i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
712
715
  ]
713
716
  else:
714
- raise ValueError(
715
- "get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
716
- )
717
+ swa_attention_layer_ids = None
718
+ full_attention_layer_ids = None
717
719
  return swa_attention_layer_ids, full_attention_layer_ids
@@ -115,5 +115,7 @@ def adjust_config_with_unaligned_cpu_tp(
115
115
  model_config = update_intermediate_size(
116
116
  model_config, "intermediate_size", intermediate_padding_size
117
117
  )
118
-
118
+ model_config = update_intermediate_size(
119
+ model_config, "intermediate_size_mlp", intermediate_padding_size
120
+ )
119
121
  return model_config
@@ -729,6 +729,7 @@ register_conv_template(
729
729
  sep="<|end|>",
730
730
  stop_str="<|end|>",
731
731
  image_token="<|endoftext10|>",
732
+ audio_token="<|endoftext11|>",
732
733
  )
733
734
  )
734
735
 
@@ -983,7 +984,7 @@ register_conv_template(
983
984
 
984
985
  @register_conv_template_matching_function
985
986
  def match_internvl(model_path: str):
986
- if re.search(r"internvl2_5", model_path, re.IGNORECASE):
987
+ if re.search(r"internvl", model_path, re.IGNORECASE):
987
988
  return "internvl-2-5"
988
989
 
989
990
 
sglang/srt/custom_op.py CHANGED
@@ -29,15 +29,18 @@ class CustomOp(nn.Module):
29
29
 
30
30
  self._original_forward_method = self._forward_method
31
31
  # NOTE: Temporarily workaround MoE
32
+ # The performance of torch.compile on this layer is not always good when bs > 1,
33
+ # so we decide to only use torch.compile when bs=1
32
34
  if "FusedMoE" in self.__class__.__name__:
33
35
  if num_tokens == 1:
34
36
  from sglang.srt.layers.moe.fused_moe_native import (
35
37
  fused_moe_forward_native,
36
38
  )
37
39
 
38
- # The performance of torch.compile on this layer is not always good when bs > 1,
39
- # so we decide to only use torch.compile when bs =1
40
40
  self._forward_method = fused_moe_forward_native
41
+ elif "TopK" in self.__class__.__name__:
42
+ if num_tokens == 1:
43
+ self._forward_method = self.forward_native
41
44
  else:
42
45
  self._forward_method = self.forward_native
43
46
  self.is_torch_compile = True
@@ -23,7 +23,14 @@ from sglang.srt.disaggregation.base.conn import (
23
23
  )
24
24
  from sglang.srt.disaggregation.utils import DisaggregationMode
25
25
  from sglang.srt.server_args import ServerArgs
26
- from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
26
+ from sglang.srt.utils import (
27
+ format_tcp_address,
28
+ get_free_port,
29
+ get_ip,
30
+ get_local_ip_by_remote,
31
+ is_valid_ipv6_address,
32
+ maybe_wrap_ipv6_address,
33
+ )
27
34
 
28
35
  logger = logging.getLogger(__name__)
29
36
 
@@ -65,11 +72,18 @@ class CommonKVManager(BaseKVManager):
65
72
  def _register_to_bootstrap(self):
66
73
  """Register KVSender to bootstrap server via HTTP POST."""
67
74
  if self.dist_init_addr:
68
- ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
75
+ if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
76
+ if self.dist_init_addr.endswith("]"):
77
+ host = self.dist_init_addr
78
+ else:
79
+ host, _ = self.dist_init_addr.rsplit(":", 1)
80
+ else:
81
+ host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
69
82
  else:
70
- ip_address = get_ip()
83
+ host = get_ip()
84
+ host = maybe_wrap_ipv6_address(host)
71
85
 
72
- bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
86
+ bootstrap_server_url = f"{host}:{self.bootstrap_port}"
73
87
  url = f"http://{bootstrap_server_url}/route"
74
88
  payload = {
75
89
  "role": "Prefill",
@@ -92,8 +106,10 @@ class CommonKVManager(BaseKVManager):
92
106
  logger.error(f"Prefill Failed to register to bootstrap server: {e}")
93
107
 
94
108
  @cache
95
- def _connect(self, endpoint: str):
109
+ def _connect(self, endpoint: str, is_ipv6: bool = False):
96
110
  socket = zmq.Context().socket(zmq.PUSH)
111
+ if is_ipv6:
112
+ socket.setsockopt(zmq.IPV6, 1)
97
113
  socket.connect(endpoint)
98
114
  return socket
99
115
 
@@ -263,15 +279,27 @@ class CommonKVReceiver(BaseKVReceiver):
263
279
  return None
264
280
 
265
281
  @classmethod
266
- def _connect(cls, endpoint: str):
282
+ def _connect(cls, endpoint: str, is_ipv6: bool = False):
267
283
  with cls._global_lock:
268
284
  if endpoint not in cls._socket_cache:
269
285
  sock = cls._ctx.socket(zmq.PUSH)
286
+ if is_ipv6:
287
+ sock.setsockopt(zmq.IPV6, 1)
270
288
  sock.connect(endpoint)
271
289
  cls._socket_cache[endpoint] = sock
272
290
  cls._socket_locks[endpoint] = threading.Lock()
273
291
  return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
274
292
 
293
+ @classmethod
294
+ def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
295
+ ip_address = bootstrap_info["rank_ip"]
296
+ port = bootstrap_info["rank_port"]
297
+ is_ipv6_address = is_valid_ipv6_address(ip_address)
298
+ sock, lock = cls._connect(
299
+ format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
300
+ )
301
+ return sock, lock
302
+
275
303
  def _register_kv_args(self):
276
304
  pass
277
305
 
@@ -439,7 +439,15 @@ class DecodePreallocQueue:
439
439
  else 0
440
440
  )
441
441
 
442
- allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max(
442
+ if self.scheduler.model_config.is_hybrid:
443
+ available_size = min(
444
+ self.token_to_kv_pool_allocator.full_available_size(),
445
+ self.token_to_kv_pool_allocator.swa_available_size(),
446
+ )
447
+ else:
448
+ available_size = self.token_to_kv_pool_allocator.available_size()
449
+
450
+ allocatable_tokens = available_size - max(
443
451
  # preserve some space for future decode
444
452
  self.num_reserved_decode_tokens
445
453
  * (
@@ -17,6 +17,7 @@ from fastapi import FastAPI, HTTPException
17
17
  from fastapi.responses import ORJSONResponse, Response, StreamingResponse
18
18
 
19
19
  from sglang.srt.disaggregation.utils import PDRegistryRequest
20
+ from sglang.srt.utils import maybe_wrap_ipv6_address
20
21
 
21
22
  AIOHTTP_STREAM_READ_CHUNK_SIZE = (
22
23
  1024 * 64
@@ -271,7 +272,7 @@ async def handle_generate_request(request_data: dict):
271
272
 
272
273
  # Parse and transform prefill_server for bootstrap data
273
274
  parsed_url = urllib.parse.urlparse(prefill_server)
274
- hostname = parsed_url.hostname
275
+ hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
275
276
  modified_request = request_data.copy()
276
277
 
277
278
  batch_size = _get_request_batch_size(modified_request)
@@ -309,7 +310,7 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str):
309
310
 
310
311
  # Parse and transform prefill_server for bootstrap data
311
312
  parsed_url = urllib.parse.urlparse(prefill_server)
312
- hostname = parsed_url.hostname
313
+ hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
313
314
  modified_request = request_data.copy()
314
315
  modified_request.update(
315
316
  {