sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +2 -0
  3. sglang/srt/configs/longcat_flash.py +104 -0
  4. sglang/srt/configs/model_config.py +14 -1
  5. sglang/srt/connector/__init__.py +1 -1
  6. sglang/srt/connector/base_connector.py +1 -2
  7. sglang/srt/connector/redis.py +2 -2
  8. sglang/srt/connector/serde/__init__.py +1 -1
  9. sglang/srt/connector/serde/safe_serde.py +4 -3
  10. sglang/srt/disaggregation/ascend/conn.py +75 -0
  11. sglang/srt/disaggregation/launch_lb.py +0 -13
  12. sglang/srt/disaggregation/mini_lb.py +33 -8
  13. sglang/srt/disaggregation/prefill.py +1 -1
  14. sglang/srt/distributed/parallel_state.py +27 -15
  15. sglang/srt/entrypoints/engine.py +19 -12
  16. sglang/srt/entrypoints/http_server.py +174 -34
  17. sglang/srt/entrypoints/openai/protocol.py +60 -0
  18. sglang/srt/eplb/eplb_manager.py +26 -2
  19. sglang/srt/eplb/expert_distribution.py +29 -2
  20. sglang/srt/hf_transformers_utils.py +10 -0
  21. sglang/srt/layers/activation.py +12 -0
  22. sglang/srt/layers/attention/ascend_backend.py +240 -109
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  24. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  25. sglang/srt/layers/layernorm.py +28 -3
  26. sglang/srt/layers/linear.py +3 -2
  27. sglang/srt/layers/logits_processor.py +1 -1
  28. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  29. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  30. sglang/srt/layers/moe/ep_moe/layer.py +14 -13
  31. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  32. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  34. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  37. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  38. sglang/srt/layers/moe/topk.py +35 -12
  39. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  40. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  41. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  42. sglang/srt/layers/quantization/mxfp4.py +9 -4
  43. sglang/srt/layers/quantization/utils.py +13 -0
  44. sglang/srt/layers/quantization/w4afp8.py +30 -25
  45. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  46. sglang/srt/layers/rotary_embedding.py +28 -1
  47. sglang/srt/layers/sampler.py +29 -5
  48. sglang/srt/managers/cache_controller.py +62 -96
  49. sglang/srt/managers/detokenizer_manager.py +9 -2
  50. sglang/srt/managers/io_struct.py +27 -0
  51. sglang/srt/managers/mm_utils.py +5 -1
  52. sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
  53. sglang/srt/managers/scheduler.py +39 -2
  54. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  55. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  56. sglang/srt/managers/tokenizer_manager.py +86 -39
  57. sglang/srt/mem_cache/chunk_cache.py +1 -1
  58. sglang/srt/mem_cache/hicache_storage.py +20 -3
  59. sglang/srt/mem_cache/hiradix_cache.py +94 -71
  60. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  61. sglang/srt/mem_cache/memory_pool.py +4 -0
  62. sglang/srt/mem_cache/memory_pool_host.py +4 -4
  63. sglang/srt/mem_cache/radix_cache.py +5 -4
  64. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  65. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  66. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
  67. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
  68. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  69. sglang/srt/model_executor/model_runner.py +5 -4
  70. sglang/srt/model_loader/loader.py +15 -24
  71. sglang/srt/model_loader/utils.py +12 -0
  72. sglang/srt/models/deepseek_v2.py +31 -10
  73. sglang/srt/models/gpt_oss.py +5 -18
  74. sglang/srt/models/llama_eagle3.py +4 -0
  75. sglang/srt/models/longcat_flash.py +1026 -0
  76. sglang/srt/models/longcat_flash_nextn.py +699 -0
  77. sglang/srt/models/qwen2.py +26 -3
  78. sglang/srt/models/qwen2_5_vl.py +65 -41
  79. sglang/srt/models/qwen2_moe.py +22 -2
  80. sglang/srt/models/transformers.py +1 -1
  81. sglang/srt/multimodal/processors/base_processor.py +4 -2
  82. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  83. sglang/srt/server_args.py +112 -55
  84. sglang/srt/speculative/eagle_worker.py +28 -8
  85. sglang/srt/utils.py +4 -0
  86. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  87. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  88. sglang/version.py +1 -1
  89. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
  90. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
  91. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
  92. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
  93. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2.py CHANGED
@@ -16,7 +16,7 @@
  # Modify details for the adaptation of Qwen2 model.
  """Inference-only Qwen2 model compatible with HuggingFace weights."""
  import logging
- from typing import Any, Dict, Iterable, Optional, Tuple, Union
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

  import torch
  from torch import nn
@@ -431,7 +431,6 @@ class Qwen2ForCausalLM(nn.Module):
  quant_config=quant_config,
  prefix=add_prefix("lm_head", prefix),
  )
-
  else:
  # ranks other than the last rank will have a placeholder layer
  self.lm_head = PPMissingLayer()
@@ -452,6 +451,8 @@ class Qwen2ForCausalLM(nn.Module):

  self.logits_processor = LogitsProcessor(config)
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+ # For EAGLE3 support
+ self.capture_aux_hidden_states = False

  def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
  return self.model.get_input_embedding(input_ids)
@@ -476,11 +477,18 @@ class Qwen2ForCausalLM(nn.Module):
  input_embeds,
  pp_proxy_tensors=pp_proxy_tensors,
  )
+ aux_hidden_states = None
+ if self.capture_aux_hidden_states:
+ hidden_states, aux_hidden_states = hidden_states

  if self.pp_group.is_last_rank:
  if not get_embedding:
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head, forward_batch
+ input_ids,
+ hidden_states,
+ self.lm_head,
+ forward_batch,
+ aux_hidden_states,
  )
  else:
  return self.pooler(hidden_states, forward_batch)
@@ -619,5 +627,20 @@ class Qwen2ForCausalLM(nn.Module):
  def load_kv_cache_scales(self, quantization_param_path: str) -> None:
  self.model.load_kv_cache_scales(quantization_param_path)

+ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+ if not self.pp_group.is_last_rank:
+ return
+
+ self.capture_aux_hidden_states = True
+ if layer_ids is None:
+ num_layers = self.config.num_hidden_layers
+ self.model.layers_to_capture = [
+ 2,
+ num_layers // 2,
+ num_layers - 3,
+ ] # Specific layers for EAGLE3 support
+ else:
+ self.model.layers_to_capture = [val + 1 for val in layer_ids]
+

  EntryClass = Qwen2ForCausalLM
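Note: the `set_eagle3_layers_to_capture` hook added above makes the backbone return a `(hidden_states, aux_hidden_states)` tuple once capture is enabled, and the extra states are forwarded to the logits processor. A minimal, self-contained sketch of that capture pattern in plain PyTorch (toy module names, not the sglang classes):

```python
# Toy illustration of the EAGLE3 aux-hidden-state capture pattern; module names
# and the capture point are assumptions for the sketch, not sglang APIs.
from typing import List, Optional

import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    def __init__(self, num_layers: int = 8, hidden: int = 16):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_layers)])
        self.layers_to_capture: List[int] = []

    def forward(self, x: torch.Tensor):
        aux = []
        for i, layer in enumerate(self.layers):
            if i in self.layers_to_capture:
                aux.append(x)  # capture the hidden state entering selected layers
            x = layer(x)
        if self.layers_to_capture:
            return x, torch.cat(aux, dim=-1)
        return x


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyBackbone()
        self.capture_aux_hidden_states = False

    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
        self.capture_aux_hidden_states = True
        n = len(self.model.layers)
        # Same default choice as the diff: an early, a middle, and a late layer.
        self.model.layers_to_capture = [2, n // 2, n - 3] if layer_ids is None else [i + 1 for i in layer_ids]


m = ToyModel()
m.set_eagle3_layers_to_capture()
out, aux = m.model(torch.randn(1, 16))
print(out.shape, aux.shape)  # aux concatenates the three captured states along the last dim
```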
sglang/srt/models/qwen2_5_vl.py CHANGED
@@ -31,7 +31,6 @@ import torch.nn as nn
  import torch.nn.functional as F
  from einops import rearrange
  from transformers.activations import ACT2FN
- from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
  from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
  Qwen2_5_VLConfig,
  Qwen2_5_VLVisionConfig,
@@ -43,7 +42,12 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (

  from sglang.srt.hf_transformers_utils import get_processor
  from sglang.srt.layers.attention.vision import VisionAttention
- from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+ ColumnParallelLinear,
+ MergedColumnParallelLinear,
+ RowParallelLinear,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -62,7 +66,6 @@ logger = logging.getLogger(__name__)


  class Qwen2_5_VLMLP(nn.Module):
-
  def __init__(
  self,
  in_features: int,
@@ -73,19 +76,12 @@
  prefix: str = "",
  ):
  super().__init__()
- self.gate_proj = ColumnParallelLinear(
- in_features,
- hidden_features,
- bias=bias,
- quant_config=quant_config,
- prefix=add_prefix("gate_proj", prefix),
- )
- self.up_proj = ColumnParallelLinear(
- in_features,
- hidden_features,
+ self.gate_up_proj = MergedColumnParallelLinear(
+ input_size=in_features,
+ output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
  bias=bias,
  quant_config=quant_config,
- prefix=add_prefix("up_proj", prefix),
+ prefix=add_prefix("gate_up_proj", prefix),
  )
  self.down_proj = RowParallelLinear(
  hidden_features,
@@ -97,12 +93,11 @@
  self.act = ACT2FN[hidden_act]

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- x_parallel_gate, _ = self.gate_proj(x)
- x_parallel_gate = self.act(x_parallel_gate)
- x_parallel_up, _ = self.up_proj(x)
- x_parallel = x_parallel_gate * x_parallel_up
- x, _ = self.down_proj(x_parallel)
- return x
+ gate_up, _ = self.gate_up_proj(x)
+ gate, up = gate_up.chunk(2, dim=-1)
+ x = self.act(gate) * up
+ x_down, _ = self.down_proj(x)
+ return x_down


  class Qwen2_5_VisionBlock(nn.Module):
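Note: merging `gate_proj` and `up_proj` into a single `MergedColumnParallelLinear` issues one GEMM whose output is the two projections concatenated along the last dimension, which is why the new forward recovers them with `chunk(2, dim=-1)`. A small equivalence check using plain `nn.Linear` stand-ins (SiLU is chosen as the activation for illustration; the real layers also handle tensor-parallel sharding and quantization):

```python
# Illustration: one fused projection split with chunk(2, dim=-1) matches two
# separate gate/up projections.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
in_features, hidden_features = 8, 16

gate = nn.Linear(in_features, hidden_features, bias=False)
up = nn.Linear(in_features, hidden_features, bias=False)

# Build the fused weight by stacking [gate_proj, up_proj] rows, as the merged layer does.
fused = nn.Linear(in_features, 2 * hidden_features, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([gate.weight, up.weight], dim=0))

x = torch.randn(4, in_features)

# Separate path (old code)
y_separate = F.silu(gate(x)) * up(x)

# Fused path (new code)
g, u = fused(x).chunk(2, dim=-1)
y_fused = F.silu(g) * u

print(torch.allclose(y_separate, y_fused, atol=1e-6))  # True
```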
@@ -122,8 +117,8 @@ class Qwen2_5_VisionBlock(nn.Module):
  super().__init__()
  if norm_layer is None:
  norm_layer = partial(nn.LayerNorm, eps=1e-6)
- self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
- self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
+ self.norm1 = RMSNorm(dim, eps=1e-6)
+ self.norm2 = RMSNorm(dim, eps=1e-6)

  if attn_implementation is None:
  softmax_in_single_precision = False
@@ -174,18 +169,29 @@ class Qwen2_5_VisionBlock(nn.Module):
  cu_seqlens: torch.Tensor,
  position_embeddings: torch.Tensor,
  ) -> torch.Tensor:
- hidden_states = self.norm1(x)
- hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
+ S, B, H = x.shape
+ # norm1: flatten to 2D -> [S*B, H], then reshape back
+ x2d = x.reshape(-1, H)
+ hidden_states = self.norm1(x2d).reshape(S, B, H)
+
+ # Attention expects [B, S, H]
+ hidden_states = rearrange(hidden_states, "s b h -> b s h")
  attn = self.attn(
  hidden_states,
  cu_seqlens=cu_seqlens,
  position_embeddings=position_embeddings,
  )
- attn = rearrange(attn, "b s ... -> s b ...")
- x = x + attn
- norm2 = self.norm2(x)
- mlp = self.mlp(norm2)
- x = x + mlp
+ attn = rearrange(attn, "b s h -> s b h")
+
+ # norm2 with fused residual-add: also 2D
+ attn2d = attn.reshape(-1, H)
+ x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
+ x_norm = x_norm_2d.reshape(S, B, H)
+ x_after_add = x_after_add_2d.reshape(S, B, H)
+
+ # MLP and final residual
+ mlp_out = self.mlp(x_norm)
+ x = x_after_add + mlp_out
  return x
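Note: the rewritten block forward assumes sglang's `RMSNorm` accepts an optional `residual` argument and returns the pair `(rmsnorm(x + residual), x + residual)`, so the residual add can be fused with the norm kernel. A simplified reference of that contract in plain PyTorch (not the fused kernel; dtype and shape handling differ):

```python
# Reference for the (normed, residual_out) contract used by norm2 above.
from typing import Optional

import torch


def rms_norm_with_residual(x: torch.Tensor, weight: torch.Tensor,
                           residual: Optional[torch.Tensor] = None, eps: float = 1e-6):
    if residual is not None:
        x = x + residual  # fused residual add
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    normed = x * torch.rsqrt(variance + eps) * weight
    return (normed, x) if residual is not None else normed


h = 16
x = torch.randn(4, h)
attn_out = torch.randn(4, h)
w = torch.ones(h)

normed, x_after_add = rms_norm_with_residual(x, w, residual=attn_out)
# Equivalent unfused computation:
assert torch.allclose(x_after_add, x + attn_out)
assert torch.allclose(normed, rms_norm_with_residual(x + attn_out, w))
```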
@@ -201,7 +207,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
  ) -> None:
  super().__init__()
  self.hidden_size = context_dim * (spatial_merge_size**2)
- self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
+ self.ln_q = RMSNorm(context_dim, eps=1e-6)
  self.mlp = nn.ModuleList(
  [
  ColumnParallelLinear(
@@ -223,11 +229,13 @@
  )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- x = self.ln_q(x)
- x = x.view(-1, self.hidden_size)
-
+ # x expected shape: [S, B, context_dim]
+ S, B, D = x.shape
+ x2d = x.reshape(-1, D)
+ x2d = self.ln_q(x2d) # RMSNorm expects 2D
+ x2d = x2d.view(-1, self.hidden_size) # group into spatial_merge_unit
  mlp_fc1, mlp_act, mlp_fc2 = self.mlp
- x_parallel, _ = mlp_fc1(x)
+ x_parallel, _ = mlp_fc1(x2d)
  x_parallel = mlp_act(x_parallel)
  out, _ = mlp_fc2(x_parallel)
  return out
@@ -340,7 +348,7 @@ class Qwen2_5_VisionTransformer(nn.Module):

  @property
  def device(self) -> torch.device:
- return self.blocks[0].mlp.gate_proj.weight.device
+ return self.patch_embed.proj.weight.device

  def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
  pos_ids = []
@@ -394,6 +402,12 @@
  )
  cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

+ # Move window_index to the same device as x before using it to index x
+ window_index = window_index.to(device=x.device)
+
+ # Ensure rotary_pos_emb is on the same device/dtype as x
+ rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype)
+
  seq_len, _ = x.size()

  x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
@@ -406,12 +420,19 @@
  rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
  emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
  position_embeddings = (emb.cos(), emb.sin())
+ # After building position_embeddings, make sure both cos and sin are on the same device/dtype as the attention input
+ position_embeddings = (
+ position_embeddings[0].to(x.device, x.dtype),
+ position_embeddings[1].to(x.device, x.dtype),
+ )

- # compute cu_seqlens
+ # compute cu_seqlens - move cu_seqlens to GPU and make it int32
  cu_seqlens = torch.cat(
  [
- torch.tensor([0], device=grid_thw.device),
- (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
+ torch.tensor([0], device=x.device, dtype=torch.int32),
+ (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2])
+ .cumsum(dim=0)
+ .to(device=x.device, dtype=torch.int32),
  ]
  )
  cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
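Note: `cu_seqlens` is the cumulative-sequence-length layout used by varlen attention kernels; consecutive differences give each image's patch count `t * h * w` from `grid_thw`, and the change above pins the tensor to the activation's device as int32, which is what those kernels expect. A small illustration of the layout (the extra leading-zero pad applied in the file itself is omitted here):

```python
# Illustration: cumulative sequence lengths (cu_seqlens) from grid_thw rows (t, h, w).
import torch

grid_thw = torch.tensor([[1, 4, 4],    # image 0 -> 16 patches
                         [1, 2, 6]])   # image 1 -> 12 patches

seqlens = grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), seqlens.cumsum(dim=0).to(torch.int32)]
)
print(cu_seqlens.tolist())  # [0, 16, 28]
# Per-image patch counts are recovered as consecutive differences:
print((cu_seqlens[1:] - cu_seqlens[:-1]).tolist())  # [16, 12]
```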
@@ -442,9 +463,8 @@ cached_get_processor = lru_cache(get_processor)
  class Qwen2_5_VLForConditionalGeneration(nn.Module):
  # BitandBytes specific attributes
  default_bitsandbytes_target_modules = [
- ".gate_proj.",
+ ".gate_up_proj.",
  ".down_proj.",
- ".up_proj.",
  ".q_proj.",
  ".k_proj.",
  ".v_proj.",
@@ -591,7 +611,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
  for param_name, weight_name, shard_id in stacked_params_mapping:
  if weight_name not in name:
  continue
- if "visual" in name:
+ if (
+ "visual" in name
+ and "up_proj" not in name
+ and "gate_proj" not in name
+ ):
  continue
  name = name.replace(weight_name, param_name)

sglang/srt/models/qwen2_moe.py CHANGED
@@ -17,7 +17,7 @@
  """Inference-only Qwen2MoE model compatible with HuggingFace weights."""

  import logging
- from typing import Any, Dict, Iterable, Optional, Tuple, Union
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

  import torch
  import torch.nn.functional as F
@@ -536,6 +536,8 @@ class Qwen2MoeForCausalLM(nn.Module):
  use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
  )
  self.logits_processor = LogitsProcessor(config)
+ # For EAGLE3 support
+ self.capture_aux_hidden_states = False

  @torch.no_grad()
  def forward(
@@ -553,9 +555,12 @@
  input_embeds,
  pp_proxy_tensors=pp_proxy_tensors,
  )
+ aux_hidden_states = None
+ if self.capture_aux_hidden_states:
+ hidden_states, aux_hidden_states = hidden_states
  if self.pp_group.is_last_rank:
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
  )
  else:
  return hidden_states
@@ -705,5 +710,20 @@
  num_groups=None,
  )

+ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+ if not self.pp_group.is_last_rank:
+ return
+
+ self.capture_aux_hidden_states = True
+ if layer_ids is None:
+ num_layers = self.config.num_hidden_layers
+ self.model.layers_to_capture = [
+ 2,
+ num_layers // 2,
+ num_layers - 3,
+ ] # Specific layers for EAGLE3 support
+ else:
+ self.model.layers_to_capture = [val + 1 for val in layer_ids]
+

  EntryClass = Qwen2MoeForCausalLM
sglang/srt/models/transformers.py CHANGED
@@ -213,7 +213,7 @@ class TransformersForCausalLM(nn.Module):
  """
  tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}

- if not tp_plan and self.tp_size > 1:
+ if not tp_plan and tp_size > 1:
  raise ValueError(
  f"{type(self.model)} does not support tensor parallel yet!"
  )
sglang/srt/multimodal/processors/base_processor.py CHANGED
@@ -13,7 +13,9 @@ from PIL import Image
  from transformers import BaseImageProcessorFast

  from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
- from sglang.srt.utils import load_audio, load_image, load_video, logger
+ from sglang.srt.utils import is_npu, load_audio, load_image, load_video, logger
+
+ _is_npu = is_npu()


  @dataclasses.dataclass
@@ -232,7 +234,7 @@ class BaseMultimodalProcessor(ABC):
  and isinstance(processor.image_processor, BaseImageProcessorFast)
  and not self.server_args.disable_fast_image_processor
  ):
- kwargs["device"] = "cuda"
+ kwargs["device"] = "cuda" if not _is_npu else "npu"
  result = processor.__call__(
  text=[input_text],
  padding=True,
sglang/srt/sampling/penaltylib/orchestrator.py CHANGED
@@ -1,7 +1,8 @@
  from __future__ import annotations

  import abc
- from typing import TYPE_CHECKING, Set, Type
+ import weakref
+ from typing import TYPE_CHECKING, Optional, Set, Type

  import torch

@@ -17,7 +18,7 @@ class BatchedPenalizerOrchestrator:
  penalizers: Set[Type["_BatchedPenalizer"]],
  ):
  self.vocab_size = vocab_size
- self.batch = batch
+ self._batch_ref = weakref.ref(batch)
  self.device = batch.device
  self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}

@@ -27,6 +28,17 @@ class BatchedPenalizerOrchestrator:
  is_required |= pen_is_required
  self.is_required = is_required

+ @property
+ def batch(self) -> ScheduleBatch | None:
+ return self._batch_ref()
+
+ @batch.setter
+ def batch(self, value: Optional[ScheduleBatch]):
+ if value is None:
+ self._batch_ref = lambda: None
+ else:
+ self._batch_ref = weakref.ref(value)
+
  def reqs(self):
  return self.batch.reqs
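Note: holding the batch through `weakref.ref` means the orchestrator no longer keeps the `ScheduleBatch` alive by itself; once the scheduler drops the batch, `self.batch` resolves to `None` and the memory can be reclaimed even if the orchestrator outlives it. A minimal sketch of the same property pattern (standalone placeholder `Batch` class, not the sglang types):

```python
# Minimal sketch of the weakref-backed `batch` property.
import weakref
from typing import Optional


class Batch:
    pass


class Orchestrator:
    def __init__(self, batch: Batch):
        self._batch_ref = weakref.ref(batch)

    @property
    def batch(self) -> Optional[Batch]:
        # Returns None once the referenced batch has been garbage collected.
        return self._batch_ref()

    @batch.setter
    def batch(self, value: Optional[Batch]):
        self._batch_ref = (lambda: None) if value is None else weakref.ref(value)


b = Batch()
orch = Orchestrator(b)
print(orch.batch is b)   # True while the scheduler still holds the batch
del b                    # scheduler drops its reference
print(orch.batch)        # None: the orchestrator did not keep the batch alive
```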
sglang/srt/server_args.py CHANGED
@@ -48,12 +48,87 @@ from sglang.srt.utils import (
  logger = logging.getLogger(__name__)


+ # Define constants
+ LOAD_FORMAT_CHOICES = [
+ "auto",
+ "pt",
+ "safetensors",
+ "npcache",
+ "dummy",
+ "sharded_state",
+ "gguf",
+ "bitsandbytes",
+ "layered",
+ "remote",
+ ]
+
+ QUANTIZATION_CHOICES = [
+ "awq",
+ "fp8",
+ "gptq",
+ "marlin",
+ "gptq_marlin",
+ "awq_marlin",
+ "bitsandbytes",
+ "gguf",
+ "modelopt",
+ "modelopt_fp4",
+ "petit_nvfp4",
+ "w8a8_int8",
+ "w8a8_fp8",
+ "moe_wna16",
+ "qoq",
+ "w4afp8",
+ "mxfp4",
+ ]
+
+ ATTENTION_BACKEND_CHOICES = [
+ # Common
+ "triton",
+ "torch_native",
+ # NVIDIA specific
+ "cutlass_mla",
+ "fa3",
+ "flashinfer",
+ "flashmla",
+ "trtllm_mla",
+ "trtllm_mha",
+ "dual_chunk_flash_attn",
+ # AMD specific
+ "aiter",
+ "wave",
+ # Other platforms
+ "intel_amx",
+ "ascend",
+ ]
+
+ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
+
+
+ # Allow external code to add more choices
+ def add_load_format_choices(choices):
+ LOAD_FORMAT_CHOICES.extend(choices)
+
+
+ def add_quantization_method_choices(choices):
+ QUANTIZATION_CHOICES.extend(choices)
+
+
+ def add_attention_backend_choices(choices):
+ ATTENTION_BACKEND_CHOICES.extend(choices)
+
+
+ def add_disagg_transfer_backend_choices(choices):
+ DISAGG_TRANSFER_BACKEND_CHOICES.extend(choices)
+
+
  @dataclasses.dataclass
  class ServerArgs:
  # Model and tokenizer
  model_path: str
  tokenizer_path: Optional[str] = None
  tokenizer_mode: str = "auto"
+ tokenizer_worker_num: int = 1
  skip_tokenizer_init: bool = False
  load_format: str = "auto"
  model_loader_extra_config: str = "{}"
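Note: hoisting the choice lists to module-level constants with `add_*_choices` helpers lets out-of-tree integrations register extra backends or formats before the argument parser is built. A usage sketch (the backend name is made up; registering it only extends the allowed CLI choices, the backend itself must still be implemented elsewhere):

```python
# Hypothetical plugin registering an extra attention backend choice before the
# CLI parser is constructed; "my_custom_backend" is an illustrative name only.
from sglang.srt import server_args

server_args.add_attention_backend_choices(["my_custom_backend"])

# After registration, --attention-backend my_custom_backend passes argparse
# validation; implementing the backend is a separate concern.
print("my_custom_backend" in server_args.ATTENTION_BACKEND_CHOICES)  # True
```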
@@ -200,6 +275,7 @@ class ServerArgs:
  eplb_algorithm: str = "auto"
  eplb_rebalance_num_iterations: int = 1000
  eplb_rebalance_layers_per_chunk: Optional[int] = None
+ eplb_min_rebalancing_utilization_threshold: float = 1.0
  expert_distribution_recorder_mode: Optional[
  Literal["stat", "stat_approx", "per_pass", "per_token"]
  ] = None
@@ -212,7 +288,7 @@ class ServerArgs:
  enable_hierarchical_cache: bool = False
  hicache_ratio: float = 2.0
  hicache_size: int = 0
- hicache_write_policy: str = "write_through_selective"
+ hicache_write_policy: str = "write_through"
  hicache_io_backend: str = "kernel"
  hicache_mem_layout: str = "layer_first"
  hicache_storage_backend: Optional[str] = None
@@ -673,6 +749,15 @@ class ServerArgs:
  )
  self.speculative_num_draft_tokens = self.speculative_num_steps + 1

+ if (
+ self.speculative_eagle_topk > 1
+ and self.page_size > 1
+ and self.attention_backend != "flashinfer"
+ ):
+ raise ValueError(
+ "speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
+ )
+
  # The token generated from the verify step is counted.
  # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
  # assert self.speculative_num_steps < self.speculative_num_draft_tokens
@@ -743,6 +828,12 @@
  default=ServerArgs.tokenizer_path,
  help="The path of the tokenizer.",
  )
+ parser.add_argument(
+ "--tokenizer-worker-num",
+ type=int,
+ default=ServerArgs.tokenizer_worker_num,
+ help="The worker num of the tokenizer manager.",
+ )
  parser.add_argument(
  "--tokenizer-mode",
  type=str,
@@ -761,18 +852,7 @@
  "--load-format",
  type=str,
  default=ServerArgs.load_format,
- choices=[
- "auto",
- "pt",
- "safetensors",
- "npcache",
- "dummy",
- "sharded_state",
- "gguf",
- "bitsandbytes",
- "layered",
- "remote",
- ],
+ choices=LOAD_FORMAT_CHOICES,
  help="The format of the model weights to load. "
  '"auto" will try to load the weights in the safetensors format '
  "and fall back to the pytorch bin format if safetensors format "
@@ -891,25 +971,7 @@
  "--quantization",
  type=str,
  default=ServerArgs.quantization,
- choices=[
- "awq",
- "fp8",
- "gptq",
- "marlin",
- "gptq_marlin",
- "awq_marlin",
- "bitsandbytes",
- "gguf",
- "modelopt",
- "modelopt_fp4",
- "petit_nvfp4",
- "w8a8_int8",
- "w8a8_fp8",
- "moe_wna16",
- "qoq",
- "w4afp8",
- "mxfp4",
- ],
+ choices=QUANTIZATION_CHOICES,
  help="The quantization method.",
  )
  parser.add_argument(
@@ -1359,43 +1421,24 @@
  )

  # Kernel backend
- ATTN_BACKENDS = [
- # Common
- "triton",
- "torch_native",
- # NVIDIA specific
- "cutlass_mla",
- "fa3",
- "flashinfer",
- "flashmla",
- "trtllm_mla",
- "trtllm_mha",
- "dual_chunk_flash_attn",
- # AMD specific
- "aiter",
- "wave",
- # Other platforms
- "intel_amx",
- "ascend",
- ]
  parser.add_argument(
  "--attention-backend",
  type=str,
- choices=ATTN_BACKENDS,
+ choices=ATTENTION_BACKEND_CHOICES,
  default=ServerArgs.attention_backend,
  help="Choose the kernels for attention layers.",
  )
  parser.add_argument(
  "--prefill-attention-backend",
  type=str,
- choices=ATTN_BACKENDS,
+ choices=ATTENTION_BACKEND_CHOICES,
  default=ServerArgs.prefill_attention_backend,
  help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
  )
  parser.add_argument(
  "--decode-attention-backend",
  type=str,
- choices=ATTN_BACKENDS,
+ choices=ATTENTION_BACKEND_CHOICES,
  default=ServerArgs.decode_attention_backend,
  help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
  )
@@ -1560,6 +1603,12 @@
  default=ServerArgs.eplb_rebalance_layers_per_chunk,
  help="Number of layers to rebalance per forward pass.",
  )
+ parser.add_argument(
+ "--eplb-min-rebalancing-utilization-threshold",
+ type=float,
+ default=ServerArgs.eplb_min_rebalancing_utilization_threshold,
+ help="Minimum threshold for GPU average utilization to trigger EPLB rebalancing. Must be in the range [0.0, 1.0].",
+ )
  parser.add_argument(
  "--expert-distribution-recorder-mode",
  type=str,
@@ -1959,7 +2008,7 @@
  "--disaggregation-transfer-backend",
  type=str,
  default=ServerArgs.disaggregation_transfer_backend,
- choices=["mooncake", "nixl", "ascend"],
+ choices=DISAGG_TRANSFER_BACKEND_CHOICES,
  help="The backend for disaggregation transfer. Default is mooncake.",
  )
  parser.add_argument(
@@ -2134,6 +2183,9 @@
  self.chunked_prefill_size % self.page_size == 0
  ), "chunked_prefill_size must be divisible by page_size"

+ # Check multi tokenizer
+ assert self.tokenizer_worker_num > 0, "Tokenizer worker num must >= 1"
+
  def check_lora_server_args(self):
  assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"

@@ -2377,6 +2429,9 @@ class PortArgs:
  # The ipc filename for Scheduler to send metrics
  metrics_ipc_name: str

+ # The ipc filename for Tokenizer and worker tokenizer
+ tokenizer_worker_ipc_name: Optional[str]
+
  @staticmethod
  def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
  if server_args.nccl_port is None:
@@ -2400,6 +2455,7 @@
  nccl_port=nccl_port,
  rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
  metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+ tokenizer_worker_ipc_name=None,
  )
  else:
  # DP attention. Use TCP + port to handle both single-node and multi-node.
@@ -2433,6 +2489,7 @@
  nccl_port=nccl_port,
  rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
  metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
+ tokenizer_worker_ipc_name=None,
  )