sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (98)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/configs/internvl.py +3 -0
  3. sglang/srt/configs/model_config.py +7 -0
  4. sglang/srt/constrained/base_grammar_backend.py +10 -2
  5. sglang/srt/constrained/xgrammar_backend.py +7 -5
  6. sglang/srt/conversation.py +16 -1
  7. sglang/srt/debug_utils/__init__.py +0 -0
  8. sglang/srt/debug_utils/dump_comparator.py +131 -0
  9. sglang/srt/debug_utils/dumper.py +108 -0
  10. sglang/srt/debug_utils/text_comparator.py +172 -0
  11. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  12. sglang/srt/disaggregation/mooncake/conn.py +16 -0
  13. sglang/srt/disaggregation/prefill.py +13 -1
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +13 -1
  16. sglang/srt/entrypoints/openai/protocol.py +3 -1
  17. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  18. sglang/srt/entrypoints/openai/serving_chat.py +132 -79
  19. sglang/srt/function_call/ebnf_composer.py +10 -3
  20. sglang/srt/function_call/function_call_parser.py +2 -0
  21. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  22. sglang/srt/function_call/qwen3_coder_detector.py +1 -0
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  24. sglang/srt/layers/attention/vision.py +56 -8
  25. sglang/srt/layers/layernorm.py +26 -1
  26. sglang/srt/layers/logits_processor.py +14 -3
  27. sglang/srt/layers/moe/ep_moe/layer.py +323 -242
  28. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  33. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  34. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  35. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  36. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  37. sglang/srt/layers/moe/topk.py +90 -24
  38. sglang/srt/layers/multimodal.py +11 -8
  39. sglang/srt/layers/quantization/fp8.py +25 -247
  40. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  41. sglang/srt/layers/quantization/modelopt_quant.py +27 -10
  42. sglang/srt/layers/quantization/unquant.py +24 -76
  43. sglang/srt/layers/quantization/w4afp8.py +68 -17
  44. sglang/srt/lora/lora_registry.py +93 -29
  45. sglang/srt/managers/cache_controller.py +9 -7
  46. sglang/srt/managers/data_parallel_controller.py +4 -0
  47. sglang/srt/managers/io_struct.py +12 -0
  48. sglang/srt/managers/mm_utils.py +154 -35
  49. sglang/srt/managers/multimodal_processor.py +3 -14
  50. sglang/srt/managers/schedule_batch.py +14 -8
  51. sglang/srt/managers/scheduler.py +64 -1
  52. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  53. sglang/srt/managers/tokenizer_manager.py +80 -15
  54. sglang/srt/managers/tp_worker.py +8 -0
  55. sglang/srt/mem_cache/hiradix_cache.py +5 -2
  56. sglang/srt/model_executor/model_runner.py +83 -27
  57. sglang/srt/models/deepseek_v2.py +75 -84
  58. sglang/srt/models/glm4_moe.py +1035 -0
  59. sglang/srt/models/glm4_moe_nextn.py +167 -0
  60. sglang/srt/models/interns1.py +328 -0
  61. sglang/srt/models/internvl.py +143 -47
  62. sglang/srt/models/llava.py +9 -5
  63. sglang/srt/models/minicpmo.py +4 -1
  64. sglang/srt/models/qwen2_moe.py +2 -2
  65. sglang/srt/models/qwen3_moe.py +17 -71
  66. sglang/srt/multimodal/processors/base_processor.py +20 -6
  67. sglang/srt/multimodal/processors/clip.py +2 -2
  68. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  69. sglang/srt/multimodal/processors/gemma3.py +2 -2
  70. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  71. sglang/srt/multimodal/processors/internvl.py +21 -8
  72. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  73. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  74. sglang/srt/multimodal/processors/llava.py +4 -4
  75. sglang/srt/multimodal/processors/minicpm.py +2 -3
  76. sglang/srt/multimodal/processors/mlama.py +2 -2
  77. sglang/srt/multimodal/processors/mllama4.py +18 -111
  78. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  79. sglang/srt/multimodal/processors/pixtral.py +2 -2
  80. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  81. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  82. sglang/srt/multimodal/processors/vila.py +3 -1
  83. sglang/srt/poll_based_barrier.py +31 -0
  84. sglang/srt/reasoning_parser.py +2 -1
  85. sglang/srt/server_args.py +65 -6
  86. sglang/srt/two_batch_overlap.py +8 -3
  87. sglang/srt/utils.py +96 -1
  88. sglang/srt/weight_sync/utils.py +119 -0
  89. sglang/test/runners.py +4 -0
  90. sglang/test/test_utils.py +118 -5
  91. sglang/utils.py +19 -0
  92. sglang/version.py +1 -1
  93. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
  94. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
  95. sglang/srt/debug_utils.py +0 -74
  96. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
  97. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
  98. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/hybrid_attn_backend.py
@@ -0,0 +1,100 @@
+ from typing import TYPE_CHECKING, Optional, Union
+
+ import torch
+
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+
+
+ class HybridAttnBackend(AttentionBackend):
+     """Support different backends for prefill and decode."""
+
+     def __init__(
+         self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
+     ):
+         self.prefill_backend = prefill_backend
+         self.decode_backend = decode_backend
+
+     def init_forward_metadata(self, forward_batch: ForwardBatch):
+         if forward_batch.forward_mode.is_decode():
+             self.decode_backend.init_forward_metadata(forward_batch)
+         else:
+             self.prefill_backend.init_forward_metadata(forward_batch)
+
+     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+         self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+
+     def init_forward_metadata_capture_cuda_graph(
+         self,
+         bs: int,
+         num_tokens: int,
+         req_pool_indices: torch.Tensor,
+         seq_lens: torch.Tensor,
+         encoder_lens: Optional[torch.Tensor],
+         forward_mode: ForwardMode,
+         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+     ):
+         self.decode_backend.init_forward_metadata_capture_cuda_graph(
+             bs,
+             num_tokens,
+             req_pool_indices,
+             seq_lens,
+             encoder_lens,
+             forward_mode,
+             spec_info,
+         )
+
+     def init_forward_metadata_replay_cuda_graph(
+         self,
+         bs: int,
+         req_pool_indices: torch.Tensor,
+         seq_lens: torch.Tensor,
+         seq_lens_sum: int,
+         encoder_lens: Optional[torch.Tensor],
+         forward_mode: ForwardMode,
+         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+         seq_lens_cpu: Optional[torch.Tensor],
+     ):
+         self.decode_backend.init_forward_metadata_replay_cuda_graph(
+             bs,
+             req_pool_indices,
+             seq_lens,
+             seq_lens_sum,
+             encoder_lens,
+             forward_mode,
+             spec_info,
+             seq_lens_cpu,
+         )
+
+     def get_cuda_graph_seq_len_fill_value(self):
+         return self.decode_backend.get_cuda_graph_seq_len_fill_value()
+
+     def forward_decode(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache: bool = True,
+         **kwargs,
+     ):
+         return self.decode_backend.forward_decode(
+             q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+         )
+
+     def forward_extend(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache: bool = True,
+         **kwargs,
+     ):
+         return self.prefill_backend.forward_extend(
+             q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+         )
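The routing contract of the new backend is straightforward: extend/prefill batches go to prefill_backend, decode batches go to decode_backend, and every CUDA-graph hook is forwarded to the decode backend. A framework-free toy sketch of that dispatch (the stub classes below are illustrative stand-ins, not SGLang APIs):

class StubBackend:
    """Stand-in for an AttentionBackend implementation (illustrative only)."""

    def __init__(self, name: str):
        self.name = name

    def init_forward_metadata(self, forward_batch) -> str:
        return f"{self.name} backend handled the batch"


class StubBatch:
    """Minimal object exposing forward_mode.is_decode(), as the wrapper expects."""

    def __init__(self, decode: bool):
        self._decode = decode
        self.forward_mode = self

    def is_decode(self) -> bool:
        return self._decode


def route(prefill_backend, decode_backend, forward_batch):
    # Same branch as HybridAttnBackend.init_forward_metadata.
    if forward_batch.forward_mode.is_decode():
        return decode_backend.init_forward_metadata(forward_batch)
    return prefill_backend.init_forward_metadata(forward_batch)


print(route(StubBackend("prefill"), StubBackend("decode"), StubBatch(decode=False)))
print(route(StubBackend("prefill"), StubBackend("decode"), StubBatch(decode=True)))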
sglang/srt/layers/attention/vision.py
@@ -3,7 +3,7 @@ from __future__ import annotations
  import dataclasses
  import functools
  import math
- from functools import lru_cache
+ from functools import lru_cache, partial
  from typing import Any, Optional, Tuple, Union

  import torch
@@ -18,11 +18,16 @@ _is_cuda = is_cuda()
  if _is_cuda:
      from sgl_kernel.flash_attn import flash_attn_varlen_func

- from sglang.srt.distributed import parallel_state
+ from sglang.srt.distributed import (
+     parallel_state,
+     split_tensor_along_last_dim,
+     tensor_model_parallel_all_gather,
+ )
  from sglang.srt.distributed import utils as dist_utils
  from sglang.srt.layers.attention.triton_ops.prefill_attention import (
      context_attention_fwd,
  )
+ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
      ColumnParallelLinear,
      QKVParallelLinear,
@@ -349,25 +354,44 @@ class VisionAttention(nn.Module):
          flatten_batch: bool = False,
          prefix: str = "",
          proj_bias: bool = True,
+         num_dummy_heads: int = 0,
+         qkv_bias: bool = True,
+         qk_normalization: bool = False,
+         layer_norm_eps: float = 1e-06,
          **kwargs,
      ):
          super().__init__()
          world_size = parallel_state.get_tensor_model_parallel_world_size()
+         self.tp_size = world_size
+         self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
          self.dropout = dropout
          self.head_size = embed_dim // num_heads
          self.hidden_size_per_attention_head = dist_utils.divide(
              projection_size, num_heads
          )
          self.num_attention_heads_per_partition = dist_utils.divide(
-             num_heads, world_size
+             num_dummy_heads + num_heads, world_size
          )
          self.num_attention_kv_heads_per_partition = dist_utils.divide(
-             num_heads, world_size
+             num_dummy_heads + num_heads, world_size
          )

          self.q_size = self.num_attention_heads_per_partition * self.head_size
          self.kv_size = self.num_attention_kv_heads_per_partition * self.head_size

+         self.qk_normalization = qk_normalization
+
+         # Additional dummy heads are used to enable TP for common GPU counts.
+         self.dummy_dim = (num_dummy_heads + num_heads) * self.head_size
+
+         if self.qk_normalization:
+             self.q_norm = RMSNorm(
+                 self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
+             )
+             self.k_norm = RMSNorm(
+                 self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
+             )
+
          if global_server_args_dict["mm_attention_backend"] is None:
              if qkv_backend is None:
                  qkv_backend = "sdpa"
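The dummy heads exist so that a head count that does not divide the tensor-parallel world size can still be sharded evenly; the extra heads are padding, but every rank ends up with a shard of the same size. A small arithmetic sketch (the 25-head / 8-rank numbers are illustrative, not taken from this diff):

def heads_per_partition(num_heads: int, num_dummy_heads: int, tp_size: int) -> int:
    # Mirrors dist_utils.divide: the padded head count must split evenly.
    total = num_dummy_heads + num_heads
    assert total % tp_size == 0, f"{total} heads do not divide tp_size={tp_size}"
    return total // tp_size


# 25 real heads cannot be sharded across 8 ranks:
#   heads_per_partition(25, 0, 8) -> AssertionError
# Padding with 7 dummy heads restores divisibility:
print(heads_per_partition(25, 7, 8))  # 4 heads per rank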
@@ -391,26 +415,46 @@
              self.qkv_proj = QKVParallelLinear(
                  hidden_size=embed_dim,
                  head_size=self.head_size,
-                 total_num_heads=num_heads,
-                 total_num_kv_heads=num_heads,
+                 total_num_heads=num_dummy_heads + num_heads,
+                 total_num_kv_heads=num_dummy_heads + num_heads,
+                 bias=qkv_bias,
                  quant_config=quant_config,
                  prefix=add_prefix("qkv_proj", prefix),
              )
          else:
              self.qkv_proj = ColumnParallelLinear(
                  input_size=embed_dim,
-                 output_size=3 * projection_size,
+                 output_size=3 * self.dummy_dim,
+                 bias=qkv_bias,
                  quant_config=quant_config,
                  prefix=add_prefix("qkv_proj", prefix),
              )
          self.proj = RowParallelLinear(
-             input_size=embed_dim,
+             input_size=self.dummy_dim,
              output_size=embed_dim,
              bias=proj_bias,
              quant_config=quant_config,
              prefix=add_prefix("proj", prefix),
          )

+     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
+         """apply qk norm for internvl vit attn"""
+         q = q.flatten(1, 2)
+         k = k.flatten(1, 2)
+
+         if self.tp_size > 1:
+             q = tensor_model_parallel_all_gather(q.contiguous())
+             k = tensor_model_parallel_all_gather(k.contiguous())
+         q = self.q_norm(q)
+         k = self.k_norm(k)
+         if self.tp_size > 1:
+             splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
+             q = splitter(q)[self.tp_rank]
+             k = splitter(k)[self.tp_rank]
+         q = q.unflatten(-1, (-1, self.head_size))
+         k = k.unflatten(-1, (-1, self.head_size))
+         return q, k
+
      def forward(
          self,
          x: torch.Tensor,
@@ -489,6 +533,10 @@ class VisionAttention(nn.Module):
          assert k.dim() == 3, k.dim()
          assert v.dim() == 3, v.dim()

+         # internvl
+         if self.qk_normalization:
+             q, k = self._apply_qk_norm(q, k)
+
          output = self.qkv_backend.forward(
              q=q,
              k=k,
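The gather/normalize/split sequence in _apply_qk_norm exists because the RMS statistic has to be taken over the full, un-sharded hidden dimension while q and k live sharded per rank. A single-process sketch of the shape bookkeeping, emulating the all-gather with torch.cat and split_tensor_along_last_dim with torch.chunk (tp_size and the tensor sizes are illustrative):

import torch

tp_size, tp_rank = 2, 0
seq_len, heads_per_rank, head_size = 4, 3, 8

# Per-rank q as produced by the sharded QKV projection: (seq, heads, head_size).
q_local = torch.randn(seq_len, heads_per_rank, head_size)

# 1. Flatten heads into the hidden dimension, as _apply_qk_norm does.
q_flat = q_local.flatten(1, 2)                 # (seq, heads_per_rank * head_size)

# 2. All-gather concatenates every rank's shard along the last dim; we fake
#    two identical ranks to keep the example single-process.
q_full = torch.cat([q_flat, q_flat], dim=-1)   # (seq, tp_size * heads_per_rank * head_size)

# 3. q_norm (RMSNorm over the full hidden dim) would be applied to q_full here.

# 4. Split back, keep this rank's slice, and restore the per-head layout.
q_back = torch.chunk(q_full, tp_size, dim=-1)[tp_rank]
q_back = q_back.unflatten(-1, (-1, head_size))  # (seq, heads_per_rank, head_size)

assert q_back.shape == q_local.shape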
sglang/srt/layers/layernorm.py
@@ -61,10 +61,15 @@ class RMSNorm(CustomOp):
          self,
          hidden_size: int,
          eps: float = 1e-6,
+         var_hidden_size: Optional[int] = None,
      ) -> None:
          super().__init__()
          self.weight = nn.Parameter(torch.ones(hidden_size))
          self.variance_epsilon = eps
+         self.hidden_size = hidden_size
+         self.variance_size_override = (
+             None if var_hidden_size == hidden_size else var_hidden_size
+         )
          if _use_aiter:
              self._forward_method = self.forward_aiter

@@ -73,6 +78,8 @@ class RMSNorm(CustomOp):
          x: torch.Tensor,
          residual: Optional[torch.Tensor] = None,
      ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+         if self.variance_size_override is not None:
+             return self.forward_native(x, residual)
          if residual is not None:
              fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
              return x, residual
@@ -138,7 +145,25 @@ class RMSNorm(CustomOp):
              x = x + residual.to(torch.float32)
              residual = x.to(orig_dtype)

-         variance = x.pow(2).mean(dim=-1, keepdim=True)
+         hidden_size = x.shape[-1]
+         if hidden_size != self.hidden_size:
+             raise ValueError(
+                 "Expected hidden_size to be "
+                 f"{self.hidden_size}, but found: {hidden_size}"
+             )
+
+         if self.variance_size_override is None:
+             x_var = x
+         else:
+             if hidden_size < self.variance_size_override:
+                 raise ValueError(
+                     "Expected hidden_size to be at least "
+                     f"{self.variance_size_override}, but found: {hidden_size}"
+                 )
+
+             x_var = x[..., : self.variance_size_override]
+
+         variance = x_var.pow(2).mean(dim=-1, keepdim=True)
          x = x * torch.rsqrt(variance + self.variance_epsilon)
          x = (x * self.weight).to(orig_dtype)
          if residual is None:
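With var_hidden_size set, only the leading var_hidden_size channels feed the variance, while the weight and normalization still apply to the full (padded) width; this is what lets the q_norm/k_norm above be sized to dummy_dim yet compute their statistics over the real embed_dim. A small numerical sketch of that override path in forward_native (sizes are illustrative):

import torch

hidden_size, var_hidden_size = 12, 8  # e.g. padded dummy_dim vs. real embed_dim
eps = 1e-6

x = torch.randn(2, hidden_size)
weight = torch.ones(hidden_size)

# Variance from only the first var_hidden_size channels...
variance = x[..., :var_hidden_size].float().pow(2).mean(dim=-1, keepdim=True)
# ...applied to every channel, matching the override branch above.
y = x.float() * torch.rsqrt(variance + eps) * weight
print(y.shape)  # torch.Size([2, 12])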
sglang/srt/layers/logits_processor.py
@@ -170,8 +170,6 @@ class LogitsMetadata:
          )

      def compute_dp_attention_metadata(self):
-         # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
-         # we may use a smaller buffer in draft extend.

          cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
          dp_rank = get_attention_dp_rank()
@@ -186,6 +184,19 @@
          self.dp_local_start_pos = dp_local_start_pos
          self.dp_local_num_tokens = dp_local_num_tokens

+         if self.global_num_tokens_for_logprob_cpu is not None:
+             # create a smaller buffer to reduce peak memory usage
+             self.gathered_buffer = torch.empty(
+                 (
+                     sum(self.global_num_tokens_for_logprob_cpu),
+                     self.gathered_buffer.shape[1],
+                 ),
+                 dtype=self.gathered_buffer.dtype,
+                 device=self.gathered_buffer.device,
+             )
+         else:
+             self.gathered_buffer = torch.empty_like(self.gathered_buffer)
+

  class LogitsProcessor(nn.Module):
      def __init__(
@@ -430,7 +441,7 @@ class LogitsProcessor(nn.Module):
          if self.do_tensor_parallel_all_gather_dp_attn:
              logits_metadata.compute_dp_attention_metadata()
              hidden_states, local_hidden_states = (
-                 torch.empty_like(logits_metadata.gathered_buffer),
+                 logits_metadata.gathered_buffer,
                  hidden_states,
              )
              dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
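Together, the two logits_processor hunks replace a worst-case scratch allocation with a buffer sized to the tokens that actually need logprobs across DP ranks, and then gather directly into it instead of an empty_like copy, which trims peak memory. A sketch of the new sizing rule (values are illustrative):

import torch

hidden_dim = 4096
# Per-DP-rank token counts that need logprobs in this batch (illustrative).
global_num_tokens_for_logprob_cpu = [3, 1, 5, 2]

# New behavior: allocate exactly sum(per-rank counts) rows for the gather,
# rather than a buffer shaped like the original gathered_buffer.
gathered = torch.empty(
    (sum(global_num_tokens_for_logprob_cpu), hidden_dim), dtype=torch.float16
)
print(gathered.shape)  # torch.Size([11, 4096])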