sglang 0.4.0__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.
Files changed (43)
  1. sglang/__init__.py +1 -1
  2. sglang/srt/constrained/outlines_backend.py +5 -0
  3. sglang/srt/constrained/xgrammar_backend.py +5 -5
  4. sglang/srt/layers/attention/__init__.py +5 -2
  5. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  6. sglang/srt/layers/attention/flashinfer_backend.py +20 -5
  7. sglang/srt/layers/attention/torch_native_backend.py +22 -8
  8. sglang/srt/layers/attention/triton_backend.py +22 -8
  9. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  10. sglang/srt/layers/ep_moe/__init__.py +0 -0
  11. sglang/srt/layers/ep_moe/kernels.py +349 -0
  12. sglang/srt/layers/ep_moe/layer.py +661 -0
  13. sglang/srt/layers/quantization/__init__.py +2 -2
  14. sglang/srt/layers/quantization/fp8.py +559 -0
  15. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  16. sglang/srt/layers/radix_attention.py +4 -2
  17. sglang/srt/layers/sampler.py +2 -0
  18. sglang/srt/layers/torchao_utils.py +23 -45
  19. sglang/srt/managers/schedule_batch.py +1 -0
  20. sglang/srt/managers/scheduler.py +69 -65
  21. sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
  22. sglang/srt/mem_cache/memory_pool.py +5 -1
  23. sglang/srt/model_executor/cuda_graph_runner.py +15 -1
  24. sglang/srt/model_executor/model_runner.py +11 -4
  25. sglang/srt/model_parallel.py +1 -5
  26. sglang/srt/models/commandr.py +2 -2
  27. sglang/srt/models/deepseek_v2.py +87 -7
  28. sglang/srt/models/grok.py +0 -5
  29. sglang/srt/models/llama.py +0 -5
  30. sglang/srt/models/mixtral.py +12 -9
  31. sglang/srt/models/phi3_small.py +0 -5
  32. sglang/srt/models/qwen2_moe.py +0 -5
  33. sglang/srt/models/torch_native_llama.py +0 -5
  34. sglang/srt/sampling/sampling_batch_info.py +9 -8
  35. sglang/srt/server.py +3 -3
  36. sglang/srt/server_args.py +43 -4
  37. sglang/srt/utils.py +50 -0
  38. sglang/version.py +1 -1
  39. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  40. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/RECORD +43 -38
  41. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  42. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  43. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py CHANGED
@@ -21,6 +21,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import PretrainedConfig
+ from vllm import _custom_ops as ops
  from vllm.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
@@ -30,6 +31,7 @@ from vllm.distributed import (
  from vllm.model_executor.layers.rotary_embedding import get_rope

  from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.ep_moe.layer import EPMoE
  from sglang.srt.layers.fused_moe_triton import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -112,12 +114,12 @@ class DeepseekV2MoE(nn.Module):
                  "Only silu is supported for now."
              )

-         self.experts = FusedMoE(
+         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+         self.experts = MoEImpl(
              num_experts=config.n_routed_experts,
              top_k=config.num_experts_per_tok,
              hidden_size=config.hidden_size,
              intermediate_size=config.moe_intermediate_size,
-             reduce_results=False,
              renormalize=config.norm_topk_prob,
              quant_config=quant_config,
              use_grouped_topk=True,
@@ -453,7 +455,7 @@ class DeepseekV2AttentionMLA(nn.Module):
              mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
              self.scaling = self.scaling * mscale * mscale

-         self.attn = RadixAttention(
+         self.attn_mqa = RadixAttention(
              self.num_local_heads,
              self.kv_lora_rank + self.qk_rope_head_dim,
              self.scaling,
@@ -462,6 +464,15 @@ class DeepseekV2AttentionMLA(nn.Module):
              v_head_dim=self.kv_lora_rank,
          )

+         self.attn_mha = RadixAttention(
+             self.num_local_heads,
+             self.qk_nope_head_dim + self.qk_rope_head_dim,
+             self.scaling,
+             num_kv_heads=self.num_local_heads,
+             layer_id=layer_id,
+             v_head_dim=self.v_head_dim,
+         )
+
          self.w_kc = None
          self.w_vc = None
          self.w_scale = None
@@ -471,6 +482,63 @@ class DeepseekV2AttentionMLA(nn.Module):
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         # Use normal computation for prefill and use weight absorption for extend/decode
+         if (
+             forward_batch.forward_mode.is_extend()
+             and forward_batch.extend_prefix_lens.sum() == 0
+         ):
+             return self.forward_normal(positions, hidden_states, forward_batch)
+         else:
+             return self.forward_absorb(positions, hidden_states, forward_batch)
+
+     def forward_normal(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         if self.q_lora_rank is not None:
+             q = self.q_a_proj(hidden_states)[0]
+             q = self.q_a_layernorm(q)
+             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+         else:
+             q = self.q_proj(hidden_states)[0].view(
+                 -1, self.num_local_heads, self.qk_head_dim
+             )
+         _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+         latent_cache = latent_cache.unsqueeze(1)
+         kv_a = self.kv_a_layernorm(kv_a.contiguous())
+         kv = self.kv_b_proj(kv_a)[0]
+         kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+         k_nope = kv[..., : self.qk_nope_head_dim]
+         v = kv[..., self.qk_nope_head_dim :]
+         k_pe = latent_cache[:, :, self.kv_lora_rank :]
+         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+         q[..., self.qk_nope_head_dim :] = q_pe
+         k = torch.empty_like(q)
+         k[..., : self.qk_nope_head_dim] = k_nope
+         k[..., self.qk_nope_head_dim :] = k_pe
+
+         latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+         latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+         # Save latent cache
+         forward_batch.token_to_kv_pool.set_kv_buffer(
+             self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+         )
+         attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+         attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+     def forward_absorb(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
      ) -> torch.Tensor:
          q_len = hidden_states.shape[0]
          q_input = hidden_states.new_empty(
@@ -508,7 +576,7 @@ class DeepseekV2AttentionMLA(nn.Module):
          q_input[..., self.kv_lora_rank :] = q_pe
          k_input[..., self.kv_lora_rank :] = k_pe

-         attn_output = self.attn(q_input, k_input, v_input, forward_batch)
+         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
          attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

          if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -767,7 +835,8 @@ class DeepseekV2ForCausalLM(nn.Module):

          # Params for weights, fp8 weight scales, fp8 activation scales
          # (param_name, weight_name, expert_id, shard_id)
-         expert_params_mapping = FusedMoE.make_expert_params_mapping(
+         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+         expert_params_mapping = MoEImpl.make_expert_params_mapping(
              ckpt_gate_proj_name="gate_proj",
              ckpt_down_proj_name="down_proj",
              ckpt_up_proj_name="up_proj",
@@ -828,14 +897,25 @@ class DeepseekV2ForCausalLM(nn.Module):
          if not global_server_args_dict["disable_mla"]:
              for layer_id in range(self.config.num_hidden_layers):
                  self_attn = self.model.layers[layer_id].self_attn
-                 w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                 if hasattr(self_attn.kv_b_proj, "qweight"):
+                     # AWQ compatible
+                     w = ops.awq_dequantize(
+                         self_attn.kv_b_proj.qweight,
+                         self_attn.kv_b_proj.scales,
+                         self_attn.kv_b_proj.qzeros,
+                         0,
+                         0,
+                         0,
+                     ).T
+                 else:
+                     w = self_attn.kv_b_proj.weight
+                 w_kc, w_vc = w.unflatten(
                      0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                  ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                  self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                  self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                  if hasattr(self_attn.kv_b_proj, "weight_scale"):
                      self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                 del self_attn.kv_b_proj


  EntryClass = DeepseekV2ForCausalLM
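
For reference, a toy check of why the two attention paths added above agree on the no-RoPE part of the attention logits: forward_normal expands the latent cache into full keys before the dot product, while forward_absorb folds w_kc into the query and attends over the latent directly, which is just matmul associativity. The sizes below are illustrative, and the RoPE half and the matching w_vc absorption on the value side are left out.

    import torch

    torch.manual_seed(0)
    d_nope, r = 128, 512  # illustrative stand-ins for qk_nope_head_dim and kv_lora_rank
    w_kc = torch.randn(d_nope, r, dtype=torch.float64)  # per-head slice that expands a latent into k_nope
    q = torch.randn(d_nope, dtype=torch.float64)        # nope part of one query head
    c = torch.randn(r, dtype=torch.float64)             # one compressed latent KV entry

    score_normal = q @ (w_kc @ c)    # forward_normal: expand the latent into a key, then dot with q
    score_absorb = (w_kc.T @ q) @ c  # forward_absorb: fold w_kc into the query, attend over the latent
    assert torch.allclose(score_normal, score_absorb)
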
sglang/srt/models/grok.py CHANGED
@@ -35,12 +35,10 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.torchao_utils import apply_torchao_config_
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.loader import DefaultModelLoader
  from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -290,7 +288,6 @@ class Grok1ForCausalLM(nn.Module):
          super().__init__()
          self.config = config
          self.quant_config = quant_config
-         self.torchao_config = global_server_args_dict["torchao_config"]
          self.model = Grok1Model(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
@@ -374,8 +371,6 @@ class Grok1ForCausalLM(nn.Module):
                  )
                  weight_loader(param, loaded_weight)

-         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-

  class Grok1ModelForCausalLM(Grok1ForCausalLM):
      """An alias for backward-compatbility."""
sglang/srt/models/llama.py CHANGED
@@ -36,12 +36,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
  from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.torchao_utils import apply_torchao_config_
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.utils import make_layers
@@ -304,7 +302,6 @@ class LlamaForCausalLM(nn.Module):
          super().__init__()
          self.config = config
          self.quant_config = quant_config
-         self.torchao_config = global_server_args_dict["torchao_config"]
          self.model = LlamaModel(config, quant_config=quant_config)
          # Llama 3.2 1B Insturct set tie_word_embeddings to True
          # Llama 3.1 8B Insturct set tie_word_embeddings to False
@@ -424,8 +421,6 @@ class LlamaForCausalLM(nn.Module):
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)

-         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
      def get_weights_by_name(
          self, name: str, truncate_size: int = 100, tp_size: int = 1
      ) -> Optional[torch.Tensor]:
sglang/srt/models/mixtral.py CHANGED
@@ -21,9 +21,13 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import MixtralConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
+ from vllm.distributed import (
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_reduce,
+ )
  from vllm.model_executor.layers.rotary_embedding import get_rope

+ from sglang.srt.layers.ep_moe.layer import EPMoE
  from sglang.srt.layers.fused_moe_triton import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -34,7 +38,6 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.torchao_utils import apply_torchao_config_
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
@@ -65,6 +68,7 @@ class MixtralMoE(nn.Module):
          prefix: str = "",
      ):
          super().__init__()
+         self.tp_size = get_tensor_model_parallel_world_size()
          self.hidden_size = hidden_size

          # Gate always runs at half / full precision for now.
@@ -76,14 +80,13 @@ class MixtralMoE(nn.Module):
              quant_config=None,
              prefix=f"{prefix}.gate",
          )
-
-         self.experts = FusedMoE(
+         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+         self.experts = MoEImpl(
              num_experts=num_experts,
              top_k=top_k,
              hidden_size=hidden_size,
              intermediate_size=intermediate_size,
              params_dtype=params_dtype,
-             reduce_results=True,
              renormalize=True,
              quant_config=quant_config,
              tp_size=tp_size,
@@ -97,6 +100,8 @@ class MixtralMoE(nn.Module):
          # router_logits: (num_tokens, n_experts)
          router_logits, _ = self.gate(hidden_states)
          final_hidden_states = self.experts(hidden_states, router_logits)
+         if self.tp_size > 1:
+             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
          return final_hidden_states.view(orig_shape)


@@ -295,7 +300,6 @@ class MixtralForCausalLM(nn.Module):
          super().__init__()
          self.config = config
          self.quant_config = quant_config
-         self.torchao_config = global_server_args_dict["torchao_config"]
          self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
@@ -322,7 +326,8 @@ class MixtralForCausalLM(nn.Module):

          # Params for weights, fp8 weight scales, fp8 activation scales
          # (param_name, weight_name, expert_id, shard_id)
-         expert_params_mapping = FusedMoE.make_expert_params_mapping(
+         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+         expert_params_mapping = MoEImpl.make_expert_params_mapping(
              ckpt_gate_proj_name="w1",
              ckpt_down_proj_name="w2",
              ckpt_up_proj_name="w3",
@@ -387,7 +392,5 @@ class MixtralForCausalLM(nn.Module):
                  )
                  weight_loader(param, loaded_weight)

-         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-

  EntryClass = MixtralForCausalLM
sglang/srt/models/phi3_small.py CHANGED
@@ -17,13 +17,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
  from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.torchao_utils import apply_torchao_config_
  from sglang.srt.layers.vocab_parallel_embedding import (
      DEFAULT_VOCAB_PADDING_SIZE,
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.utils import make_layers
@@ -348,7 +346,6 @@ class Phi3SmallForCausalLM(nn.Module):
              quant_config=quant_config,
              prefix="model",
          )
-         self.torchao_config = global_server_args_dict["torchao_config"]
          self.vocab_size = config.vocab_size
          self.mup_width_multiplier = config.mup_width_multiplier
          self.lm_head = ParallelLMHead(
@@ -441,7 +438,5 @@ class Phi3SmallForCausalLM(nn.Module):
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)

-         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-

  EntryClass = Phi3SmallForCausalLM
sglang/srt/models/qwen2_moe.py CHANGED
@@ -40,12 +40,10 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.torchao_utils import apply_torchao_config_
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader

@@ -352,7 +350,6 @@ class Qwen2MoeForCausalLM(nn.Module):
          super().__init__()
          self.config = config
          self.quant_config = quant_config
-         self.torchao_config = global_server_args_dict["torchao_config"]
          self.model = Qwen2MoeModel(config, quant_config)
          self.lm_head = ParallelLMHead(
              config.vocab_size, config.hidden_size, quant_config=quant_config
@@ -445,7 +442,5 @@ class Qwen2MoeForCausalLM(nn.Module):
                  )
                  weight_loader(param, loaded_weight)

-         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-

  EntryClass = Qwen2MoeForCausalLM
sglang/srt/models/torch_native_llama.py CHANGED
@@ -58,12 +58,10 @@ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.torchao_utils import apply_torchao_config_
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader

@@ -392,7 +390,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
          super().__init__()
          self.config = config
          self.quant_config = quant_config
-         self.torchao_config = global_server_args_dict["torchao_config"]
          self.supports_torch_tp = True
          self.model = LlamaModel(config, quant_config=quant_config)
          if self.config.tie_word_embeddings:
@@ -503,8 +500,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)

-         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-

  class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
      pass
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -158,22 +158,23 @@ class SamplingBatchInfo:
              return

          # find a grammar from the list
-         grammar = next(grammar for grammar in self.grammars if grammar)
+         first_grammar = next(grammar for grammar in self.grammars if grammar)

          # maybe we can reuse the existing mask?
-         self.vocab_mask = grammar.allocate_vocab_mask(
+         self.vocab_mask = first_grammar.allocate_vocab_mask(
              vocab_size=self.vocab_size,
              batch_size=len(self.temperatures),
              device=self.device,
          )
-         self.apply_mask = type(grammar).apply_vocab_mask  # force to use static method
+         self.apply_mask = first_grammar.apply_vocab_mask  # force to use static method

+         # Apply the mask
          for i, grammar in enumerate(self.grammars):
-             if grammar is not None:
-                 try:
-                     grammar.fill_vocab_mask(self.vocab_mask, i)
-                 except RuntimeError:
-                     continue
+             if grammar and not grammar.finished:
+                 grammar.fill_vocab_mask(self.vocab_mask, i)
+
+         # Move the mask to the device if needed
+         self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)

      def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
          self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
sglang/srt/server.py CHANGED
@@ -329,7 +329,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
          )


- @app.api_route("/encode", methods=["POST", "PUT"])
+ @app.api_route("/classify", methods=["POST", "PUT"])
  @time_func_latency
  async def classify_request(obj: EmbeddingReqInput, request: Request):
      """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
@@ -462,8 +462,8 @@ def launch_engine(
      if server_args.node_rank >= 1:
          # For other nodes, they do not need to run tokenizer or detokenizer,
          # so they can just wait here.
-         while True:
-             pass
+         for proc in scheduler_procs:
+             proc.join()
      else:
          # Launch the data parallel controller
          reader, writer = mp.Pipe(duplex=False)
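
A minimal client-side sketch of the renamed route, assuming a server is already running on the default port 30000 with a reward model loaded; the "text" payload field mirrors EmbeddingReqInput but is an assumption here, not something shown in this diff.

    import requests

    # Assumes a running sglang server on the default port with a reward model;
    # the "text" field name is an assumption about EmbeddingReqInput.
    resp = requests.post(
        "http://127.0.0.1:30000/classify",
        json={"text": "The quick brown fox jumps over the lazy dog."},
    )
    print(resp.json())  # same response schema as the embedding route
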
sglang/srt/server_args.py CHANGED
@@ -20,9 +20,12 @@ import random
  import tempfile
  from typing import List, Optional

+ import torch
+
  from sglang.srt.hf_transformers_utils import check_gguf_file
  from sglang.srt.utils import (
      get_amdgpu_memory_capacity,
+     get_hpu_memory_capacity,
      get_nvgpu_memory_capacity,
      is_flashinfer_available,
      is_hip,
@@ -91,6 +94,8 @@ class ServerArgs:
      # Data parallelism
      dp_size: int = 1
      load_balance_method: str = "round_robin"
+     # Expert parallelism
+     ep_size: int = 1

      # Multi-node distributed serving
      dist_init_addr: Optional[str] = None
@@ -128,6 +133,7 @@ class ServerArgs:
      disable_overlap_schedule: bool = False
      enable_mixed_chunk: bool = False
      enable_dp_attention: bool = False
+     enable_ep_moe: bool = False
      enable_torch_compile: bool = False
      torch_compile_max_bs: int = 32
      cuda_graph_max_bs: Optional[int] = None
@@ -151,8 +157,13 @@

          if is_hip():
              gpu_mem = get_amdgpu_memory_capacity()
-         else:
+         elif torch.cuda.is_available():
              gpu_mem = get_nvgpu_memory_capacity()
+         elif self.device == "hpu":
+             gpu_mem = get_hpu_memory_capacity()
+         else:
+             # GPU memory is not known yet or no GPU is available.
+             gpu_mem = None

          # Set mem fraction static, which depends on the tensor parallelism size
          if self.mem_fraction_static is None:
@@ -169,19 +180,27 @@

          # Set chunked prefill size, which depends on the gpu memory capacity
          if self.chunked_prefill_size is None:
-             if gpu_mem < 25_000:
+             if gpu_mem is not None and gpu_mem < 25_000:
                  self.chunked_prefill_size = 2048
              else:
                  self.chunked_prefill_size = 8192

          # Set cuda graph max batch size
          if self.cuda_graph_max_bs is None:
-             if gpu_mem < 25_000:
-                 self.cuda_graph_max_bs = 8
+             # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
+             if gpu_mem is not None and gpu_mem < 25_000:
+                 if self.tp_size < 4:
+                     self.cuda_graph_max_bs = 8
+                 else:
+                     self.cuda_graph_max_bs = 80
              else:
                  self.cuda_graph_max_bs = 160

          # Choose kernel backends
+         if self.device == "hpu":
+             self.attention_backend = "torch_native"
+             self.sampling_backend = "pytorch"
+
          if self.attention_backend is None:
              self.attention_backend = (
                  "flashinfer" if is_flashinfer_available() else "triton"
@@ -211,6 +230,12 @@ class ServerArgs:
                  "Data parallel size is adjusted to be the same as tensor parallel size. "
                  "Overlap scheduler is disabled."
              )
+         # Expert parallelism
+         if self.enable_ep_moe:
+             self.ep_size = self.tp_size
+             logger.info(
+                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+             )

          # GGUF
          if (
@@ -521,6 +546,14 @@ class ServerArgs:
                  "shortest_queue",
              ],
          )
+         # Expert parallelism
+         parser.add_argument(
+             "--expert-parallel-size",
+             "--ep-size",
+             type=int,
+             default=ServerArgs.ep_size,
+             help="The expert parallelism size.",
+         )

          # Multi-node distributed serving
          parser.add_argument(
@@ -676,6 +709,11 @@ class ServerArgs:
              action="store_true",
              help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
          )
+         parser.add_argument(
+             "--enable-ep-moe",
+             action="store_true",
+             help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
+         )
          parser.add_argument(
              "--enable-torch-compile",
              action="store_true",
@@ -755,6 +793,7 @@ class ServerArgs:
      def from_cli_args(cls, args: argparse.Namespace):
          args.tp_size = args.tensor_parallel_size
          args.dp_size = args.data_parallel_size
+         args.ep_size = args.expert_parallel_size
          attrs = [attr.name for attr in dataclasses.fields(cls)]
          return cls(**{attr: getattr(args, attr) for attr in attrs})

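A sketch of how the new expert-parallel switches fit together, assuming an environment where ServerArgs can be constructed (its post-init hook probes the local device); the model path below is only a placeholder.

    from sglang.srt.server_args import ServerArgs

    # Placeholder model path; assumes ServerArgs' post-init checks can run on this machine.
    args = ServerArgs(
        model_path="deepseek-ai/DeepSeek-V2-Lite",
        tp_size=2,
        enable_ep_moe=True,
    )
    print(args.ep_size)  # expected: 2, mirrored from tp_size when EP MoE is enabled

    # Equivalent CLI form using the flags added above:
    #   python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-Lite --tp-size 2 --enable-ep-moe
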
sglang/srt/utils.py CHANGED
@@ -201,6 +201,18 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
          total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
          free_gpu_memory = total_gpu_memory - used_memory

+     elif device == "hpu":
+         num_gpus = torch.hpu.device_count()
+         assert gpu_id < num_gpus
+
+         if torch.hpu.current_device() != gpu_id:
+             print(
+                 f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
+                 "which may cause useless memory allocation for torch HPU context.",
+             )
+
+         free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()
+
      if distributed:
          tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
              torch.device(device, gpu_id)
@@ -939,6 +951,37 @@ def get_nvgpu_memory_capacity():
          )


+ def get_hpu_memory_capacity():
+     try:
+         # Run hl-smi and capture the output
+         result = subprocess.run(
+             ["hl-smi --query | grep 'Total'"],
+             stdout=subprocess.PIPE,
+             stderr=subprocess.PIPE,
+             shell=True,
+             text=True,
+         )
+
+         if result.returncode != 0:
+             raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")
+
+         # Parse the output to extract memory values in MiB
+         memory_values = [
+             float(mem.split(" ")[-2]) for mem in result.stdout.strip().split("\n")
+         ]
+
+         if not memory_values:
+             raise ValueError("No GPU memory values found.")
+
+         # Return the minimum memory value
+         return min(memory_values)
+
+     except FileNotFoundError:
+         raise RuntimeError(
+             "hl-smi not found. Ensure Habana drivers are installed and accessible."
+         )
+
+
  # Copy from pytorch and OpenRLHF to allow creating multiple main groups.
  # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
  # https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
@@ -1062,6 +1105,13 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
      return major, minor


+ def get_compiler_backend() -> str:
+     if hasattr(torch, "hpu") and torch.hpu.is_available():
+         return "hpu_backend"
+
+     return "inductor"
+
+
  sglang_lib = Library("sglang", "FRAGMENT")  # noqa

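
One plausible consumer of the new get_compiler_backend helper, wrapping torch.compile so that Habana devices get "hpu_backend" and everything else falls back to "inductor"; the compile_module wrapper itself is hypothetical and not part of this diff.

    import torch

    from sglang.srt.utils import get_compiler_backend

    def compile_module(module: torch.nn.Module) -> torch.nn.Module:
        # get_compiler_backend() returns "hpu_backend" on Habana devices and "inductor" otherwise.
        return torch.compile(module, backend=get_compiler_backend())
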
sglang/version.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.0"
+ __version__ = "0.4.0.post1"