sglang-0.5.4-py3-none-any.whl → sglang-0.5.4.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sglang/bench_serving.py +56 -12
  2. sglang/launch_server.py +2 -0
  3. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
  4. sglang/srt/compilation/backend.py +1 -1
  5. sglang/srt/configs/model_config.py +5 -5
  6. sglang/srt/distributed/parallel_state.py +0 -7
  7. sglang/srt/entrypoints/engine.py +18 -15
  8. sglang/srt/entrypoints/grpc_server.py +0 -1
  9. sglang/srt/entrypoints/http_server.py +75 -94
  10. sglang/srt/environ.py +16 -2
  11. sglang/srt/eplb/expert_distribution.py +30 -0
  12. sglang/srt/function_call/function_call_parser.py +2 -0
  13. sglang/srt/function_call/minimax_m2.py +367 -0
  14. sglang/srt/layers/activation.py +6 -0
  15. sglang/srt/layers/attention/flashattention_backend.py +12 -2
  16. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  17. sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
  18. sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
  19. sglang/srt/layers/attention/utils.py +78 -0
  20. sglang/srt/layers/communicator.py +1 -0
  21. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  22. sglang/srt/layers/layernorm.py +19 -4
  23. sglang/srt/layers/logits_processor.py +5 -0
  24. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  25. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  26. sglang/srt/layers/moe/ep_moe/layer.py +79 -272
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  29. sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
  30. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  31. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  32. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  33. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  34. sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
  35. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  36. sglang/srt/layers/moe/topk.py +4 -4
  37. sglang/srt/layers/moe/utils.py +3 -4
  38. sglang/srt/layers/quantization/__init__.py +3 -5
  39. sglang/srt/layers/quantization/awq.py +0 -3
  40. sglang/srt/layers/quantization/base_config.py +7 -0
  41. sglang/srt/layers/quantization/fp8.py +68 -63
  42. sglang/srt/layers/quantization/gguf.py +566 -0
  43. sglang/srt/layers/quantization/mxfp4.py +30 -38
  44. sglang/srt/layers/quantization/unquant.py +23 -45
  45. sglang/srt/layers/quantization/w4afp8.py +38 -2
  46. sglang/srt/layers/radix_attention.py +5 -2
  47. sglang/srt/layers/rotary_embedding.py +13 -1
  48. sglang/srt/layers/sampler.py +12 -1
  49. sglang/srt/managers/io_struct.py +3 -0
  50. sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
  51. sglang/srt/managers/scheduler.py +21 -15
  52. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  53. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  54. sglang/srt/managers/tokenizer_manager.py +11 -19
  55. sglang/srt/mem_cache/hicache_storage.py +7 -1
  56. sglang/srt/mem_cache/memory_pool.py +82 -0
  57. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  58. sglang/srt/model_executor/forward_batch_info.py +44 -3
  59. sglang/srt/model_executor/model_runner.py +1 -149
  60. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  61. sglang/srt/models/deepseek_v2.py +147 -44
  62. sglang/srt/models/glm4_moe.py +322 -354
  63. sglang/srt/models/glm4_moe_nextn.py +4 -14
  64. sglang/srt/models/glm4v_moe.py +29 -196
  65. sglang/srt/models/minimax_m2.py +922 -0
  66. sglang/srt/models/nvila.py +355 -0
  67. sglang/srt/models/nvila_lite.py +184 -0
  68. sglang/srt/models/qwen2.py +22 -1
  69. sglang/srt/models/qwen3.py +34 -4
  70. sglang/srt/models/qwen3_moe.py +2 -4
  71. sglang/srt/multimodal/processors/base_processor.py +1 -0
  72. sglang/srt/multimodal/processors/glm4v.py +1 -1
  73. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  74. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  75. sglang/srt/parser/reasoning_parser.py +28 -1
  76. sglang/srt/server_args.py +365 -186
  77. sglang/srt/single_batch_overlap.py +2 -7
  78. sglang/srt/utils/common.py +87 -42
  79. sglang/srt/utils/hf_transformers_utils.py +7 -3
  80. sglang/test/test_deterministic.py +235 -12
  81. sglang/test/test_deterministic_utils.py +2 -1
  82. sglang/version.py +1 -1
  83. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
  84. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
  85. sglang/srt/models/vila.py +0 -306
  86. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  87. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  88. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
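As context for the header above: this page compares two published wheels file by file. A rough, illustrative way to reproduce such a comparison locally is sketched below; it is not the tool that generated this page, and it assumes the two wheels (for example fetched with "pip download sglang==0.5.4 --no-deps" and "pip download sglang==0.5.4.post1 --no-deps") sit in the current directory.

    # Illustrative sketch: diff the Python sources of two locally downloaded wheels.
    import difflib
    import zipfile

    OLD = "sglang-0.5.4-py3-none-any.whl"
    NEW = "sglang-0.5.4.post1-py3-none-any.whl"

    def python_sources(wheel_path):
        # Map member name -> list of source lines for every .py file in the wheel.
        with zipfile.ZipFile(wheel_path) as zf:
            return {
                name: zf.read(name).decode("utf-8", errors="replace").splitlines()
                for name in zf.namelist()
                if name.endswith(".py")
            }

    old_files, new_files = python_sources(OLD), python_sources(NEW)
    for name in sorted(set(old_files) | set(new_files)):
        diff = difflib.unified_diff(
            old_files.get(name, []),
            new_files.get(name, []),
            fromfile=f"{OLD}/{name}",
            tofile=f"{NEW}/{name}",
            lineterm="",
        )
        for line in diff:
            print(line)
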
sglang/srt/models/glm4_moe.py
@@ -15,7 +15,7 @@
  """Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""

  import logging
- from typing import Any, Dict, Iterable, Optional, Tuple
+ from typing import Any, Dict, Iterable, Optional, Tuple, Union

  import torch
  import torch.nn.functional as F
@@ -27,10 +27,16 @@ from sglang.srt.distributed import (
  get_pp_group,
  get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
+ parallel_state,
  tensor_model_parallel_all_reduce,
  )
+ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+ use_symmetric_memory,
+ )
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
  from sglang.srt.layers.activation import SiluAndMul
- from sglang.srt.layers.amx_utils import PackWeightMethod
  from sglang.srt.layers.communicator import (
  LayerCommunicator,
  LayerScatterModes,
@@ -48,7 +54,10 @@ from sglang.srt.layers.linear import (
  RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.moe import get_moe_a2a_backend
+ from sglang.srt.layers.moe import (
+ get_moe_a2a_backend,
+ should_use_flashinfer_cutlass_moe_fp4_allgather,
+ )
  from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
  from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
  from sglang.srt.layers.moe.topk import TopK
@@ -56,23 +65,17 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.utils import PPMissingLayer
  from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.models.deepseek_v2 import (
- DeepseekV2DecoderLayer,
- DeepseekV2ForCausalLM,
- DeepseekV2Model,
- DeepseekV2MoE,
- )
  from sglang.srt.server_args import get_global_server_args
+ from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
  from sglang.srt.utils import (
- BumpAllocator,
- LazyValue,
  add_prefix,
  cpu_has_amx_support,
  get_bool_env_var,
@@ -80,8 +83,7 @@ from sglang.srt.utils import (
  is_cpu,
  is_cuda,
  is_hip,
- log_info_on_rank0,
- use_intel_amx_backend,
+ make_layers,
  )

  _is_hip = is_hip()
@@ -92,11 +94,6 @@ _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()
  _device_sm = get_device_sm()

- if _is_cuda:
- from sgl_kernel import dsv3_router_gemm
- elif _is_cpu and _is_cpu_amx_available:
- pass
-
  logger = logging.getLogger(__name__)


@@ -136,8 +133,7 @@ class Glm4MoeMLP(nn.Module):
  )
  if hidden_act != "silu":
  raise ValueError(
- f"Unsupported activation: {hidden_act}. "
- "Only silu is supported for now."
+ f"Unsupported activation: {hidden_act}. Only silu is supported for now."
  )
  self.act_fn = SiluAndMul()

@@ -146,7 +142,6 @@ class Glm4MoeMLP(nn.Module):
  x,
  forward_batch=None,
  should_allreduce_fusion=False,
- gemm_output_zero_allocator: BumpAllocator = None,
  ):
  if (self.tp_size == 1) and x.shape[0] == 0:
  return x
@@ -326,47 +321,21 @@ class Glm4MoeGate(nn.Module):
  self,
  config,
  prefix: str = "",
- is_nextn: bool = False,
  ):
  super().__init__()
- self.is_nextn = is_nextn
  self.weight = nn.Parameter(
  torch.empty((config.n_routed_experts, config.hidden_size))
  )
  self.e_score_correction_bias = nn.Parameter(
  torch.empty((config.n_routed_experts), dtype=torch.float32)
  )
- if _is_cpu and _is_cpu_amx_available:
- self.quant_method = PackWeightMethod(weight_names=["weight"])

  def forward(self, hidden_states):
- if use_intel_amx_backend(self):
- return torch.ops.sgl_kernel.weight_packed_linear(
- hidden_states,
- self.weight,
- None, # bias
- True, # is_vnni
- )
-
- # NOTE: For some unknown reason, router_gemm seems degrade accept length.
- if (
- _is_cuda
- and not self.is_nextn
- and hidden_states.shape[0] < 4
- and hidden_states.shape[1] == 7168
- and self.weight.shape[0] == 256
- and _device_sm >= 90
- ):
- logits = dsv3_router_gemm(hidden_states, self.weight).to(
- hidden_states.dtype
- )
- else:
- logits = F.linear(hidden_states, self.weight, None)
-
+ logits = F.linear(hidden_states, self.weight, None)
  return logits


- class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
+ class Glm4MoeSparseMoeBlock(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
@@ -374,18 +343,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
  alt_stream: Optional[torch.cuda.Stream] = None,
- is_nextn: bool = False,
  ):
  nn.Module.__init__(self)
+ self.top_k = config.num_experts_per_tok
  self.tp_size = get_tensor_model_parallel_world_size()
- self.ep_size = get_moe_expert_parallel_world_size()
  self.routed_scaling_factor = config.routed_scaling_factor
  self.n_shared_experts = config.n_shared_experts
- self.num_fused_shared_experts = (
- 0
- if get_global_server_args().disable_shared_experts_fusion
- else config.n_shared_experts
- )
  self.config = config
  self.layer_id = layer_id
  self.alt_stream = alt_stream
@@ -402,39 +365,31 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  "Only silu is supported for now."
  )

- self.gate = Glm4MoeGate(
- config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
- )
+ self.gate = Glm4MoeGate(config=config, prefix=add_prefix("gate", prefix))

  self.topk = TopK(
- top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+ top_k=self.top_k,
  renormalize=config.norm_topk_prob,
  use_grouped_topk=True,
  num_expert_group=config.n_group,
- num_fused_shared_experts=self.num_fused_shared_experts,
  topk_group=config.topk_group,
  correction_bias=self.gate.e_score_correction_bias,
  routed_scaling_factor=self.routed_scaling_factor,
  )

  self.experts = get_moe_impl_class(quant_config)(
- num_experts=config.n_routed_experts
- + self.num_fused_shared_experts
- + get_global_server_args().ep_num_redundant_experts,
- num_fused_shared_experts=self.num_fused_shared_experts,
- top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+ num_experts=config.n_routed_experts,
+ top_k=self.top_k,
+ layer_id=self.layer_id,
  hidden_size=config.hidden_size,
  intermediate_size=config.moe_intermediate_size,
- layer_id=self.layer_id,
  quant_config=quant_config,
  routed_scaling_factor=self.routed_scaling_factor,
  prefix=add_prefix("experts", prefix),
  )

- self.shared_experts_is_int8 = False
- self.shared_experts_is_fp8 = False
- # self.shared_experts_weight_block_size = None
- if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
+ # shared expert
+ if config.n_shared_experts is not None:
  intermediate_size = config.moe_intermediate_size * config.n_shared_experts
  self.shared_experts = Glm4MoeMLP(
  hidden_size=config.hidden_size,
@@ -443,21 +398,14 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  quant_config=quant_config,
  reduce_results=False,
  prefix=add_prefix("shared_experts", prefix),
- **(dict(tp_rank=0, tp_size=1) if self.ep_size > 1 else {}),
+ **(
+ dict(tp_rank=0, tp_size=1)
+ if get_moe_a2a_backend().is_deepep()
+ or get_moe_a2a_backend().is_mooncake()
+ or should_use_flashinfer_cutlass_moe_fp4_allgather()
+ else {}
+ ),
  )
- is_packed_weight = hasattr(
- self.shared_experts.gate_up_proj.quant_method, "quant_config"
- )
- self.shared_experts_is_int8 = (
- not is_packed_weight
- and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
- )
- self.shared_experts_is_fp8 = (
- not is_packed_weight
- and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
- )
-
- self.top_k = config.num_experts_per_tok

  if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
  # TODO: we will support tp < ep in the future
@@ -479,12 +427,46 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
  )

+ def get_moe_weights(self):
+ return [
+ x.data
+ for name, x in self.experts.named_parameters()
+ if name not in ["correction_bias"]
+ ]
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ forward_batch: Optional[ForwardBatch] = None,
+ should_allreduce_fusion: bool = False,
+ use_reduce_scatter: bool = False,
+ ) -> torch.Tensor:
+ if not self._enable_a2a_moe:
+ DUAL_STREAM_TOKEN_THRESHOLD = 1024
+ if (
+ self.alt_stream is not None
+ and hidden_states.shape[0] > 0
+ and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+ ):
+ return self.forward_normal_dual_stream(
+ hidden_states,
+ should_allreduce_fusion,
+ use_reduce_scatter,
+ )
+ else:
+ return self.forward_normal(
+ hidden_states,
+ should_allreduce_fusion,
+ use_reduce_scatter,
+ )
+ else:
+ return self.forward_deepep(hidden_states, forward_batch)
+
  def forward_normal_dual_stream(
  self,
  hidden_states: torch.Tensor,
  should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
- gemm_output_zero_allocator: BumpAllocator = None,
  ) -> torch.Tensor:

  current_stream = torch.cuda.current_stream()
@@ -498,28 +480,21 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  final_hidden_states = self.experts(hidden_states, topk_output)
  if not _is_cuda:
  final_hidden_states *= self.routed_scaling_factor
+
  current_stream.wait_stream(self.alt_stream)
+ with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+ final_hidden_states_out = torch.empty_like(final_hidden_states)

- if self.ep_size > 1:
- if (
- self.tp_size > 1
- and not should_allreduce_fusion
- and not use_reduce_scatter
- ):
- final_hidden_states = tensor_model_parallel_all_reduce(
- final_hidden_states
- )
- final_hidden_states += shared_output
- else:
- final_hidden_states += shared_output
- if (
- self.tp_size > 1
- and not should_allreduce_fusion
- and not use_reduce_scatter
- ):
- final_hidden_states = tensor_model_parallel_all_reduce(
- final_hidden_states
- )
+ torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+ final_hidden_states = final_hidden_states_out
+ sm.tag(final_hidden_states)
+ if (
+ self.tp_size > 1
+ and not should_allreduce_fusion
+ and not use_reduce_scatter
+ and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+ ):
+ final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states

  def forward_normal(
@@ -527,39 +502,69 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  hidden_states: torch.Tensor,
  should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
- gemm_output_zero_allocator: BumpAllocator = None,
  ) -> torch.Tensor:
- if hasattr(self, "shared_experts") and use_intel_amx_backend(
- self.shared_experts.gate_up_proj
- ):
- return self.forward_cpu(hidden_states, should_allreduce_fusion)
+ if hidden_states.shape[0] > 0:
+ shared_output = self._forward_shared_experts(hidden_states)
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(hidden_states)
+ topk_output = self.topk(hidden_states, router_logits)
+ else:
+ shared_output = None
+ topk_output = self.topk.empty_topk_output(hidden_states.device)

- shared_output = self._forward_shared_experts(hidden_states)
- # router_logits: (num_tokens, n_experts)
- router_logits = self.gate(hidden_states)
- topk_output = self.topk(hidden_states, router_logits)
  final_hidden_states = self.experts(hidden_states, topk_output)
  if not _is_cuda and not _use_aiter:
  # fused in biased_grouped_topk so we can skip here
  final_hidden_states *= self.routed_scaling_factor
- if self.ep_size > 1:
- if self.tp_size > 1 and not should_allreduce_fusion:
- final_hidden_states = tensor_model_parallel_all_reduce(
- final_hidden_states
- )
- if shared_output is not None:
- final_hidden_states += shared_output
+ if shared_output is not None:
+ with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+ final_hidden_states_out = torch.empty_like(final_hidden_states)
+ torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+ final_hidden_states = final_hidden_states_out
+ sm.tag(final_hidden_states)
+ if (
+ self.tp_size > 1
+ and not should_allreduce_fusion
+ and not use_reduce_scatter
+ and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+ ):
+ final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+ return final_hidden_states
+
+ def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+ shared_output = None
+ if hidden_states.shape[0] > 0:
+ # router_logits: (num_tokens, n_experts)
+ router_logits, _ = self.gate(hidden_states)
+ shared_output = self._forward_shared_experts(hidden_states)
+ topk_output = self.topk(
+ hidden_states,
+ router_logits,
+ num_token_non_padded=forward_batch.num_token_non_padded,
+ expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+ layer_id=self.layer_id,
+ ),
+ )
  else:
- if shared_output is not None:
- final_hidden_states += shared_output
- if self.tp_size > 1 and not should_allreduce_fusion:
- final_hidden_states = tensor_model_parallel_all_reduce(
- final_hidden_states
- )
+ topk_output = self.topk.empty_topk_output(hidden_states.device)
+ final_hidden_states = self.experts(
+ hidden_states=hidden_states,
+ topk_output=topk_output,
+ )
+
+ if shared_output is not None:
+ final_hidden_states.add_(shared_output)
+
  return final_hidden_states

+ def _forward_shared_experts(self, hidden_states: torch.Tensor):
+ shared_output = None
+ if hidden_states.shape[0] > 0:
+ shared_output = self.shared_experts(hidden_states)
+ return shared_output

- class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
+
+ class Glm4MoeDecoderLayer(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
@@ -582,6 +587,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
  rms_norm_eps = config.rms_norm_eps
  attention_bias = config.attention_bias
  self.layer_id = layer_id
+
  self.self_attn = Glm4MoeAttention(
  hidden_size=self.hidden_size,
  num_heads=config.num_attention_heads,
@@ -597,15 +603,15 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
  quant_config=quant_config,
  prefix=add_prefix("self_attn", prefix),
  use_qk_norm=config.use_qk_norm,
+ alt_stream=alt_stream,
  )

  self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
  is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)

- num_layers = 1 if is_nextn else config.num_hidden_layers
  self.layer_scatter_modes = LayerScatterModes.init_new(
  layer_id=layer_id,
- num_layers=num_layers,
+ num_layers=1 if is_nextn else config.num_hidden_layers,
  is_layer_sparse=self.is_layer_sparse,
  is_previous_layer_sparse=is_previous_layer_sparse,
  )
@@ -616,6 +622,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
  quant_config=quant_config,
  prefix=add_prefix("mlp", prefix),
  layer_id=self.layer_id,
+ alt_stream=alt_stream,
  )
  else:
  if enable_moe_dense_fully_dp():
@@ -641,7 +648,16 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
  layer_scatter_modes=self.layer_scatter_modes,
  input_layernorm=self.input_layernorm,
  post_attention_layernorm=self.post_attention_layernorm,
- allow_reduce_scatter=False,
+ allow_reduce_scatter=True,
+ is_last_layer=(
+ is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+ ),
+ )
+
+ def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
+ return is_nextn or (
+ self.config.n_routed_experts is not None
+ and layer_id >= self.config.first_k_dense_replace
  )

  def forward(
@@ -650,8 +666,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
- zero_allocator: BumpAllocator,
- gemm_output_zero_allocator: BumpAllocator = None,
  ) -> torch.Tensor:
  hidden_states, residual = self.layer_communicator.prepare_attn(
  hidden_states, residual, forward_batch
@@ -676,44 +690,119 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
  return hidden_states, residual


- class Glm4MoeModel(DeepseekV2Model):
+ class Glm4MoeModel(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
- ) -> None:
- nn.Module.__init__(self)
- self.padding_id = config.pad_token_id
+ ):
+ super().__init__()
+ self.pp_group = get_pp_group()
+ self.config = config
  self.vocab_size = config.vocab_size
- self.first_k_dense_replace = config.first_k_dense_replace
+ self.embed_dim = config.hidden_size
+ if self.pp_group.is_first_rank:
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ enable_tp=not is_dp_attention_enabled(),
+ )
+ else:
+ self.embed_tokens = PPMissingLayer()

- self.embed_tokens = VocabParallelEmbedding(
- config.vocab_size,
- config.hidden_size,
- enable_tp=not is_dp_attention_enabled(),
- )
  self.alt_stream = torch.cuda.Stream() if _is_cuda else None
- self.layers = nn.ModuleList(
- [
- Glm4MoeDecoderLayer(
- config,
- layer_id,
- quant_config=quant_config,
- prefix=add_prefix(f"layers.{layer_id}", prefix),
- alt_stream=self.alt_stream,
- )
- for layer_id in range(config.num_hidden_layers)
- ]
+ self.layers, self.start_layer, self.end_layer = make_layers(
+ config.num_hidden_layers,
+ lambda idx, prefix: Glm4MoeDecoderLayer(
+ layer_id=idx,
+ config=config,
+ quant_config=quant_config,
+ prefix=prefix,
+ alt_stream=self.alt_stream,
+ ),
+ pp_rank=self.pp_group.rank_in_group,
+ pp_size=self.pp_group.world_size,
+ prefix=add_prefix("layers", prefix),
  )
- self.pp_group = get_pp_group()
- self.start_layer = 0
- self.end_layer = config.num_hidden_layers
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ if self.pp_group.is_last_rank:
+ self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer(return_tuple=True)

+ def get_input_embeddings(self) -> torch.Tensor:
+ return self.embed_tokens

- class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ input_embeds: torch.Tensor = None,
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
+ ) -> Union[torch.Tensor, PPProxyTensors]:
+ if self.pp_group.is_first_rank:
+ if input_embeds is None:
+ hidden_states = self.embed_tokens(input_ids)
+ else:
+ hidden_states = input_embeds
+ residual = None
+ else:
+ assert pp_proxy_tensors is not None
+ hidden_states = pp_proxy_tensors["hidden_states"]
+ residual = pp_proxy_tensors["residual"]

+ normal_start_layer = self.start_layer
+ normal_end_layer = self.end_layer
+ if forward_batch.can_run_tbo:
+ if (
+ self.first_k_dense_replace > normal_start_layer
+ and self.first_k_dense_replace < normal_end_layer
+ ):
+ normal_end_layer = self.first_k_dense_replace
+ elif self.first_k_dense_replace < normal_start_layer:
+ normal_end_layer = normal_start_layer = 0
+
+ for i in range(normal_start_layer, normal_end_layer):
+ with get_global_expert_distribution_recorder().with_current_layer(i):
+ layer = self.layers[i]
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+ forward_batch,
+ residual,
+ )
+
+ if normal_end_layer != self.end_layer:
+ hidden_states, residual = model_forward_maybe_tbo(
+ layers=self.layers[normal_end_layer : self.end_layer],
+ enable_tbo=True,
+ positions=positions,
+ forward_batch=forward_batch,
+ hidden_states=hidden_states,
+ residual=residual,
+ input_data_scatter_mode=self.layers[
+ normal_end_layer - 1
+ ].layer_scatter_modes.layer_output_mode,
+ )
+
+ if not self.pp_group.is_last_rank:
+ return PPProxyTensors(
+ {
+ "hidden_states": hidden_states,
+ "residual": residual,
+ }
+ )
+ else:
+ if not forward_batch.forward_mode.is_idle():
+ if residual is None:
+ hidden_states = self.norm(hidden_states)
+ else:
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+
+ class Glm4MoeForCausalLM(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
@@ -721,12 +810,10 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
  prefix: str = "",
  ) -> None:
  nn.Module.__init__(self)
- config.moe_layer_freq = 1
  self.config = config
  self.tp_size = get_tensor_model_parallel_world_size()
  self.quant_config = quant_config
  self.pp_group = get_pp_group()
- self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
  self.model = Glm4MoeModel(
  config, quant_config, prefix=add_prefix("model", prefix)
  )
@@ -739,49 +826,41 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
  )
  self.logits_processor = LogitsProcessor(config)

- self._routed_experts_weights_of_layer = LazyValue(
- lambda: {
- layer_id: layer.mlp.get_moe_weights()
- for layer_id, layer in enumerate(self.model.layers)
- if isinstance(layer.mlp, DeepseekV2MoE)
- }
- )
+ # For EAGLE3 support
+ self.capture_aux_hidden_states = False

- def determine_num_fused_shared_experts(
- self, architecture: str = "Glm4MoeForCausalLM"
- ):
- self.num_fused_shared_experts = 0
- if get_global_server_args().disable_shared_experts_fusion:
- return
+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.model.embed_tokens

- # Only Deepseek V3/R1 can use shared experts fusion optimization now.
- disable_reason = None
- if (
- not _is_cuda
- or torch.cuda.get_device_capability("cuda") < (8, 0)
- or self.config.architectures[0] != architecture
- or self.config.n_shared_experts != 1
- ):
- disable_reason = "Only GLM-4.5 or GLM-4.6 on NV-platform with capability >= 80 can use shared experts fusion optimization."
- elif get_moe_expert_parallel_world_size() > 1:
- disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
-
- if disable_reason is not None:
- get_global_server_args().disable_shared_experts_fusion = True
- self.num_fused_shared_experts = 0
- log_info_on_rank0(
- logger,
- f"{disable_reason} Shared experts fusion optimization is disabled.",
+ @torch.no_grad()
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ input_embeds: torch.Tensor = None,
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.model(
+ input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
+ )
+
+ if self.pp_group.is_last_rank:
+ return self.logits_processor(
+ input_ids, hidden_states, self.lm_head, forward_batch
  )
- return
+ else:
+ return hidden_states

- self.num_fused_shared_experts = self.config.n_shared_experts
+ @property
+ def start_layer(self):
+ return self.model.start_layer

- def get_input_embeddings(self) -> nn.Embedding:
- return self.model.embed_tokens
+ @property
+ def end_layer(self):
+ return self.model.end_layer

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
-
  if is_nextn:
  if hasattr(self.config, "num_nextn_predict_layers"):
  num_nextn_layers = self.config.num_nextn_predict_layers
@@ -803,117 +882,14 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
  ("gate_up_proj", "gate_proj", 0),
  ("gate_up_proj", "up_proj", 1),
  ]
- if self.num_fused_shared_experts > 0:
- assert self.num_fused_shared_experts == 1
- weights_list = list(weights)
- weights_dict = dict(weights_list)
- if self.quant_config is not None:
- if self.quant_config.get_name() == "w8a8_int8":
- suffix_list = [
- "down_proj.weight",
- "down_proj.weight_scale",
- "gate_proj.weight",
- "gate_proj.weight_scale",
- "up_proj.weight",
- "up_proj.weight_scale",
- ]
- elif (
- self.quant_config.get_name() == "fp8"
- or self.quant_config.get_name() == "blockwise_int8"
- or self.quant_config.get_name() == "compressed_tensors"
- ):
- suffix_list = [
- "down_proj.weight",
- "down_proj.weight_scale",
- "gate_proj.weight",
- "gate_proj.weight_scale",
- "up_proj.weight",
- "up_proj.weight_scale",
- ]
- elif self.quant_config.get_name() == "awq":
- suffix_list = [
- "down_proj.qweight",
- "down_proj.qzeros",
- "down_proj.scales",
- "gate_proj.qweight",
- "gate_proj.qzeros",
- "gate_proj.scales",
- "up_proj.qweight",
- "up_proj.qzeros",
- "up_proj.scales",
- ]
- elif self.quant_config.get_name() == "modelopt_fp4":
- suffix_list = [
- "down_proj.weight",
- "down_proj.weight_scale",
- "down_proj.weight_scale_2",
- "down_proj.input_scale",
- "gate_proj.weight",
- "gate_proj.weight_scale",
- "gate_proj.weight_scale_2",
- "gate_proj.input_scale",
- "up_proj.weight",
- "up_proj.weight_scale",
- "up_proj.weight_scale_2",
- "up_proj.input_scale",
- ]
- else:
- raise ValueError(
- f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
- )
- else:
- suffix_list = [
- "down_proj.weight",
- "gate_proj.weight",
- "up_proj.weight",
- ]
- names_to_remove = []
-
- moe_layers = (
- range(
- self.config.first_k_dense_replace,
- self.config.num_hidden_layers,
- self.config.moe_layer_freq,
- )
- if not is_nextn
- else [nextn_layer_id]
- )
-
- for moe_layer in moe_layers:
- for suffix in suffix_list:
- shared_expert_weight_name = (
- f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
- )
- # online fp8 quantization does not load weight_scale
- if shared_expert_weight_name not in weights_dict:
- continue
- weights_list.append(
- (
- f"model.layers.{moe_layer}."
- f"mlp.experts."
- f"{self.config.n_routed_experts + 0}"
- f".{suffix}",
- weights_dict[shared_expert_weight_name],
- )
- )
- names_to_remove += [shared_expert_weight_name]
- weights = [w for w in weights_list if w[0] not in names_to_remove]

- # Params for weights, fp8 weight scales, fp8 activation scales
- # (param_name, weight_name, expert_id, shard_id)
  expert_params_mapping = FusedMoE.make_expert_params_mapping(
  ckpt_gate_proj_name="gate_proj",
  ckpt_down_proj_name="down_proj",
  ckpt_up_proj_name="up_proj",
- num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+ num_experts=self.config.n_routed_experts,
  )

- # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
- fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
- self.config.q_lora_rank is not None
- )
- cached_a_proj = {} if fuse_qkv_a_proj else None
-
  if is_nextn:
  nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
  nextn_spec_weight_names = [
@@ -969,22 +945,36 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
  # name will be updated to mlp.experts[0].gate_up_proj, which
  # will then be updated below in expert_params_mapping
  # for mlp.experts[0].gate_gate_up_proj, which breaks load.
- if ("mlp.experts." in name) and name not in params_dict:
+ if "mlp.experts" in name:
  continue
  name = name.replace(weight_name, param_name)
  # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
  continue
+ if name not in params_dict:
+ continue
+
  param = params_dict[name]
  weight_loader = param.weight_loader
  weight_loader(param, loaded_weight, shard_id)
  break
  else:
+ # Track if this is an expert weight to enable early skipping
+ is_expert_weight = False
+
  for mapping in expert_params_mapping:
  param_name, weight_name, expert_id, shard_id = mapping
  if weight_name not in name:
  continue
+
+ # Mark as expert weight regardless of whether we can process it
+ is_expert_weight = True
+
  name = name.replace(weight_name, param_name)
+ if name not in params_dict:
+ # Expert weight not on this rank, will be skipped below
+ continue
+
  param = params_dict[name]
  weight_loader = param.weight_loader
  weight_loader(
@@ -996,65 +986,43 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
  )
  break
  else:
+ if is_expert_weight:
+ # This is an expert weight but not mapped to this rank, skip all remaining processing
+ continue
+
  # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
  continue
- if fuse_qkv_a_proj and (
- "q_a_proj" in name or "kv_a_proj_with_mqa" in name
- ):
- cached_a_proj[name] = loaded_weight
- q_a_proj_name = (
- name
- if "q_a_proj" in name
- else name.replace("kv_a_proj_with_mqa", "q_a_proj")
- )
- kv_a_proj_name = (
- name
- if "kv_a_proj_with_mqa" in name
- else name.replace("q_a_proj", "kv_a_proj_with_mqa")
- )
+ if name not in params_dict:
+ continue

- # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
- if (
- q_a_proj_name in cached_a_proj
- and kv_a_proj_name in cached_a_proj
- ):
- q_a_proj_weight = cached_a_proj[q_a_proj_name]
- kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
- fused_weight = torch.cat(
- [q_a_proj_weight, kv_a_proj_weight], dim=0
- )
- param_name = (
- name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
- if "q_a_proj" in name
- else name.replace(
- "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
- )
- )
- param = params_dict[param_name]
-
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
- weight_loader(param, fused_weight)
- cached_a_proj.pop(q_a_proj_name)
- cached_a_proj.pop(kv_a_proj_name)
- else:
- if (
- "k_scale" in name or "v_scale" in name
- ) and name not in params_dict:
- # modelopt attn kv scale is named differently
- if any(scale in name for scale in ["k_scale", "v_scale"]):
- name = name.replace("_proj", "attn_mqa")
- else:
- logger.warning(
- f"Unknown scale found in checkpoint: {name}"
- )
+ if name in params_dict.keys():
  param = params_dict[name]
  weight_loader = getattr(
  param, "weight_loader", default_weight_loader
  )
  weight_loader(param, loaded_weight)
+ else:
+ logger.warning(f"Parameter {name} not found in params_dict")
+
+ def get_embed_and_head(self):
+ return self.model.embed_tokens.weight, self.lm_head.weight
+
+ def set_embed_and_head(self, embed, head):
+ del self.model.embed_tokens.weight
+ del self.lm_head.weight
+ self.model.embed_tokens.weight = embed
+ self.lm_head.weight = head
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ @classmethod
+ def get_model_config_for_expert_location(cls, config):
+ return ModelConfigForExpertLocation(
+ num_layers=config.num_hidden_layers,
+ num_logical_experts=config.n_routed_experts,
+ num_groups=config.n_group,
+ )


  EntryClass = [Glm4MoeForCausalLM]
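
For readers skimming the glm4_moe.py hunks above: the new Glm4MoeSparseMoeBlock.forward dispatches small batches (at most DUAL_STREAM_TOKEN_THRESHOLD = 1024 tokens) to forward_normal_dual_stream, which overlaps the shared-expert and routed-expert computations on two CUDA streams. The standalone sketch below only illustrates that overlap pattern; the dummy Linear modules and the choice of which branch runs on which stream are simplifications, not sglang code.

    # Illustrative sketch of the dual-stream overlap pattern used in the diff above.
    import torch
    import torch.nn as nn

    DUAL_STREAM_TOKEN_THRESHOLD = 1024  # same constant as in the diff

    def dual_stream_moe(hidden_states, experts, shared_experts, alt_stream=None):
        small_batch = 0 < hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
        if alt_stream is None or not small_batch or not torch.cuda.is_available():
            # Fallback: run both branches sequentially (analogous to forward_normal).
            return experts(hidden_states) + shared_experts(hidden_states)

        current_stream = torch.cuda.current_stream()
        alt_stream.wait_stream(current_stream)
        with torch.cuda.stream(alt_stream):
            # One branch runs on the alternate stream ...
            shared_out = shared_experts(hidden_states)
        # ... while the other runs on the default stream, overlapping the two GEMMs.
        routed_out = experts(hidden_states)
        current_stream.wait_stream(alt_stream)
        return routed_out + shared_out

    if __name__ == "__main__":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        x = torch.randn(8, 64, device=device)
        experts = nn.Linear(64, 64, device=device)
        shared = nn.Linear(64, 64, device=device)
        alt = torch.cuda.Stream() if device == "cuda" else None
        print(dual_stream_moe(x, experts, shared, alt).shape)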