sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/flashinfer_comm_fusion.py (new file)
@@ -0,0 +1,202 @@
+import logging
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.utils import is_flashinfer_available
+
+logger = logging.getLogger(__name__)
+
+_flashinfer_comm = None
+_workspace_manager = None
+
+if is_flashinfer_available():
+    try:
+        import flashinfer.comm as comm
+
+        _flashinfer_comm = comm
+    except ImportError:
+        logger.warning(
+            "flashinfer.comm is not available, falling back to standard "
+            "implementation"
+        )
+
+
+class FlashInferWorkspaceManager:
+    def __init__(self):
+        self.workspace_tensor = None
+        self.ipc_handles = None
+        self.world_size = None
+        self.rank = None
+        self.initialized = False
+
+    def initialize(
+        self,
+        world_size: int,
+        rank: int,
+        max_token_num: int,
+        hidden_dim: int,
+        group=None,
+        use_fp32_lamport: bool = False,
+    ):
+        """Initialize workspace"""
+        if self.initialized and self.world_size == world_size:
+            return
+
+        if _flashinfer_comm is None:
+            logger.warning(
+                "FlashInfer comm not available, skipping workspace " "initialization"
+            )
+            return
+
+        self.cleanup()
+
+        self.ipc_handles, self.workspace_tensor = (
+            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
+                rank,
+                world_size,
+                max_token_num,
+                hidden_dim,
+                group=group,
+                use_fp32_lamport=use_fp32_lamport,
+            )
+        )
+
+        self.world_size = world_size
+        self.rank = rank
+        self.initialized = True
+
+        logger.info(
+            f"FlashInfer workspace initialized for rank {rank}, "
+            f"world_size {world_size}"
+        )
+
+    def cleanup(self):
+        """Clean up workspace"""
+        if self.initialized and self.ipc_handles is not None:
+            try:
+                _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
+                    self.ipc_handles, group=dist.group.WORLD
+                )
+            except Exception as e:
+                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
+            finally:
+                self.workspace_tensor = None
+                self.ipc_handles = None
+                self.initialized = False
+
+
+_workspace_manager = FlashInferWorkspaceManager()
+
+
+def ensure_workspace_initialized(
+    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+):
+    """Ensure workspace is initialized"""
+    if not is_flashinfer_available() or _flashinfer_comm is None:
+        return False
+
+    world_size = get_tensor_model_parallel_world_size()
+    if world_size <= 1:
+        return False
+
+    rank = dist.get_rank()
+
+    if (
+        not _workspace_manager.initialized
+        or _workspace_manager.world_size != world_size
+    ):
+        _workspace_manager.initialize(
+            world_size=world_size,
+            rank=rank,
+            max_token_num=max_token_num,
+            hidden_dim=hidden_dim,
+            use_fp32_lamport=use_fp32_lamport,
+        )
+
+    return _workspace_manager.initialized
+
+
+def flashinfer_allreduce_residual_rmsnorm(
+    input_tensor: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    max_token_num: int = 128,
+    use_oneshot: bool = True,
+    trigger_completion_at_end: bool = False,
+    fp32_acc: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Use FlashInfer's fused allreduce + residual + RMS norm operation
+
+    Args:
+        input_tensor: Input tensor that needs allreduce
+        residual: Residual tensor
+        weight: RMS norm weight
+        eps: RMS norm epsilon
+        max_token_num: Maximum token number
+        use_oneshot: Whether to use oneshot mode
+        trigger_completion_at_end: Whether to trigger completion at end
+        fp32_acc: Whether to use fp32 precision
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output)
+    """
+    if not is_flashinfer_available() or _flashinfer_comm is None:
+        logger.debug(
+            "FlashInfer not available, falling back to standard " "implementation"
+        )
+        return None, None
+
+    world_size = get_tensor_model_parallel_world_size()
+    if world_size <= 1:
+        logger.debug("Single GPU, no need for allreduce fusion")
+        return None, None
+
+    if not ensure_workspace_initialized(
+        max_token_num=max_token_num,
+        hidden_dim=input_tensor.shape[-1],
+        use_fp32_lamport=(input_tensor.dtype == torch.float32),
+    ):
+        logger.debug("FlashInfer workspace not available")
+        return None, None
+
+    token_num, hidden_dim = input_tensor.shape
+
+    residual_out = torch.empty_like(residual)
+    norm_out = torch.empty_like(input_tensor)
+
+    _flashinfer_comm.trtllm_allreduce_fusion(
+        allreduce_in=input_tensor,
+        world_size=world_size,
+        world_rank=dist.get_rank(),
+        token_num=token_num,
+        hidden_dim=hidden_dim,
+        workspace_ptrs=_workspace_manager.workspace_tensor,
+        launch_with_pdl=True,
+        use_oneshot=use_oneshot,
+        trigger_completion_at_end=trigger_completion_at_end,
+        fp32_acc=fp32_acc,
+        pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm),
+        allreduce_out=None,
+        residual_in=residual,
+        residual_out=residual_out,
+        norm_out=norm_out,
+        quant_out=None,
+        scale_out=None,
+        rms_gamma=weight,
+        rms_eps=eps,
+        scale_factor=None,
+        layout_code=None,
+    )
+
+    return norm_out, residual_out
+
+
+def cleanup_flashinfer_workspace():
+    global _workspace_manager
+    if _workspace_manager is not None:
+        _workspace_manager.cleanup()
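
For orientation, the sketch below shows how a caller might use the new fused path. It is not part of the diff: the helper name allreduce_add_rmsnorm and the unfused fallback math are illustrative assumptions; only flashinfer_allreduce_residual_rmsnorm, its (None, None) fallback signal, and tensor_model_parallel_all_reduce come from the code above, and an initialized tensor-parallel runtime is assumed.

import torch

from sglang.srt.distributed import tensor_model_parallel_all_reduce
from sglang.srt.layers.flashinfer_comm_fusion import (
    flashinfer_allreduce_residual_rmsnorm,
)


def allreduce_add_rmsnorm(x, residual, weight, eps=1e-6):
    # Hypothetical caller: try the fused FlashInfer kernel first; it returns
    # (None, None) when flashinfer.comm or the IPC workspace is unavailable.
    norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm(
        input_tensor=x, residual=residual, weight=weight, eps=eps
    )
    if norm_out is not None:
        return norm_out, residual_out

    # Unfused fallback: all-reduce the partial sums, add the residual,
    # then apply a plain RMSNorm.
    x = tensor_model_parallel_all_reduce(x)
    residual_out = x + residual
    variance = residual_out.float().pow(2).mean(-1, keepdim=True)
    norm_out = (
        residual_out.float() * torch.rsqrt(variance + eps) * weight.float()
    ).to(x.dtype)
    return norm_out, residual_out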

sglang/srt/layers/layernorm.py
@@ -163,6 +163,32 @@ class RMSNorm(CustomOp):
         else:
             return self.forward_native(x, residual)
 
+    def forward_with_allreduce_fusion(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Forward method with allreduce fusion, prioritizing flashinfer fused operations
+        """
+        if residual is not None:
+            from sglang.srt.distributed import get_tensor_model_parallel_world_size
+            from sglang.srt.layers.flashinfer_comm_fusion import (
+                flashinfer_allreduce_residual_rmsnorm,
+            )
+
+            if get_tensor_model_parallel_world_size() > 1:
+                fused_result = flashinfer_allreduce_residual_rmsnorm(
+                    input_tensor=x,
+                    residual=residual,
+                    weight=self.weight,
+                    eps=self.variance_epsilon,
+                )
+                if fused_result[0] is not None:
+                    return fused_result
+
+        return self.forward(x, residual)
+
 
 class GemmaRMSNorm(CustomOp):
     def __init__(
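
A minimal call-site sketch for the new method follows. It is illustrative only: the variable names and shapes are assumptions, and it presumes sglang's distributed and tensor-parallel state is already initialized; when the FlashInfer workspace is unavailable the method simply falls back to RMSNorm.forward.

import torch

from sglang.srt.layers.layernorm import RMSNorm


def post_attention_norm(
    norm: RMSNorm, hidden_states: torch.Tensor, residual: torch.Tensor
):
    # Fused allreduce + residual add + RMSNorm when the TP world size is > 1
    # and the FlashInfer workspace is ready; otherwise equivalent to
    # norm.forward(hidden_states, residual).
    return norm.forward_with_allreduce_fusion(hidden_states, residual)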

sglang/srt/layers/linear.py
@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -31,10 +32,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     is_cpu,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 logger = logging.getLogger(__name__)
@@ -175,7 +176,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if _is_cpu and _is_cpu_amx_available:
-            _process_weight_after_loading(layer, ["weight"])
+            _amx_process_weight_after_loading(layer, ["weight"])
 
     def apply(
         self,
@@ -184,7 +185,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 x, layer.weight, bias, True  # is_vnni
             )
@@ -425,8 +426,26 @@ class ColumnParallelLinear(LinearBase):
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = self.tp_rank * shard_size
-            if not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not self.use_presharded_weights,
+                )
+            else:
+                if not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
@@ -643,10 +662,29 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             start_idx = self.tp_rank * shard_size
-            # bitsandbytes loads the weights of the specific portion
-            # no need to narrow here
-            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not use_bitsandbytes_4bit and not self.use_presharded_weights,
+                )
+            else:
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit and not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
+
         # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -1111,10 +1149,27 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
 
-            # bitsandbytes loads the weights of the specific portion
-            # no need to narrow here
-            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not use_bitsandbytes_4bit and not self.use_presharded_weights,
+                )
+            else:
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit and not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
 
         # Special case for for AQLM codebooks.
         elif is_metadata:
@@ -1256,7 +1311,22 @@ class RowParallelLinear(LinearBase):
         ):
             shard_size = param_data.shape[input_dim]
             start_idx = self.tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    input_dim,
+                    shard_size,
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
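
To make the sharding arithmetic in the unchanged (non-CPU) branches concrete, here is a small standalone example of torch.Tensor.narrow; the values are illustrative and not tied to any real model. The new _is_cpu branches route the same indices through narrow_padded_param_and_loaded_weight (defined in sglang/srt/model_loader/weight_utils.py, not shown in this excerpt), presumably so that parameters padded for the CPU backend are sliced consistently with the loaded checkpoint weight.

import torch

# Toy illustration of the non-CPU branch: each tensor-parallel rank keeps only
# its shard of the full weight along the partitioned dimension.
full_weight = torch.arange(24).reshape(6, 4)      # [output_dim=6, input_dim=4]
tp_rank, tp_size = 1, 3
shard_size = full_weight.shape[0] // tp_size      # 2 rows per rank
start_idx = tp_rank * shard_size                  # rank 1 starts at row 2
shard = full_weight.narrow(0, start_idx, shard_size)
assert shard.shape == (2, 4) and shard[0, 0].item() == 8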

sglang/srt/layers/logits_processor.py
@@ -42,7 +42,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import dump_to_file
+from sglang.srt.utils import dump_to_file, use_intel_amx_backend
 
 logger = logging.getLogger(__name__)
 
@@ -436,13 +436,13 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata(hidden_states)
             hidden_states, local_hidden_states = (
-                logits_metadata.gathered_buffer,
-                hidden_states.clone(),
+                torch.empty_like(logits_metadata.gathered_buffer),
+                hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
-            if getattr(lm_head, "use_intel_amx_backend", False):
+            if use_intel_amx_backend(lm_head):
                 logits = torch.ops.sgl_kernel.weight_packed_linear(
                     hidden_states.to(lm_head.weight.dtype),
                     lm_head.weight,
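
The call sites above (and the ones in linear.py) replace the inline getattr check with a use_intel_amx_backend helper from sglang.srt.utils. That file's hunk is not shown in this excerpt, so the following is only a plausible sketch inferred from the code it replaces, not the actual definition.

def use_intel_amx_backend(layer) -> bool:
    # Inferred sketch: the pre-change call sites read
    # getattr(layer, "use_intel_amx_backend", False), so the helper is
    # presumably a thin, shared wrapper around that attribute check.
    return getattr(layer, "use_intel_amx_backend", False)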

sglang/srt/layers/moe/cutlass_w4a8_moe.py (new file)
@@ -0,0 +1,215 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Cutlass W4A8 MoE kernel."""
+from typing import Optional
+
+import torch
+from sgl_kernel import (
+    cutlass_w4a8_moe_mm,
+    get_cutlass_w4a8_moe_mm_data,
+    sgl_per_tensor_quant_fp8,
+    silu_and_mul,
+)
+
+from sglang.srt.layers.moe.ep_moe.kernels import (
+    post_reorder_triton_kernel,
+    pre_reorder_triton_kernel_for_cutlass_moe,
+    run_cutlass_moe_ep_preproess,
+)
+
+
+def cutlass_w4a8_moe(
+    start_expert_id: int,
+    end_expert_id: int,
+    total_num_experts: int,
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    local_topk_ids: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False,
+) -> torch.Tensor:
+    """
+    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
+    using two sets of quantized weights, w1_q and w2_q, and top-k gating
+    mechanism. The matrix multiplications are implemented with CUTLASS
+    grouped gemm.
+
+    Parameters:
+    - a (torch.Tensor): The input tensor to the MoE layer.
+        Shape: [M, K]
+    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
+        Shape: [num_experts, N * 2, K // 2]
+        (the weights are passed transposed and int4-packed)
+    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
+        Shape: [num_experts, K, N // 2]
+        (the weights are passed transposed and int4-packed)
+    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
+        Shape: [num_experts, K // 512, N * 8]
+    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
+        Shape: [num_experts, N // 512, K * 4]
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
+    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
+    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
+    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
+    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
+    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
+    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
+    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
+    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
+        Shape: scalar or [1, K]
+    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
+        quantize the intermediate result between the gemms.
+        Shape: scalar or [1, N]
+    - apply_router_weight_on_input (bool): When true, the topk weights are
+        applied directly on the inputs. This is only applicable when topk is 1.
+
+    Returns:
+    - torch.Tensor: The fp8 output tensor after applying the MoE layer.
+    """
+    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
+    assert w1_q.dtype == torch.int8
+    assert w2_q.dtype == torch.int8
+    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
+    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+    assert (
+        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
+        and w1_scale.shape[2] == w1_q.shape[1] * 4
+    ), "W1 scale shape mismatch"
+    assert (
+        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
+        and w2_scale.shape[2] == w2_q.shape[1] * 4
+    ), "W2 scale shape mismatch"
+
+    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
+    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
+    num_experts = w1_q.size(0)
+    m = a.size(0)
+    k = w1_q.size(2) * 2  # w1_q is transposed and packed
+    n = w2_q.size(2) * 2  # w2_q is transposed and packed
+    topk = topk_ids_.size(1)
+
+    if apply_router_weight_on_input:
+        assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
+
+    device = a.device
+
+    _, src2dst, _ = run_cutlass_moe_ep_preproess(
+        local_topk_ids,
+        num_experts,
+    )
+
+    gateup_input = torch.empty(
+        (m * topk, k),
+        device=device,
+        dtype=torch.float8_e4m3fn,
+    )
+
+    pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
+        a,
+        gateup_input,
+        src2dst,
+        local_topk_ids,
+        a1_scale,
+        total_num_experts,
+        topk,
+        k,
+        BLOCK_SIZE=512,
+    )
+
+    # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
+    # they are kept to allow for a quick switch of the permutation logic
+    # from the current triton kernel implementation to the cutlass-based one if needed.
+    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    get_cutlass_w4a8_moe_mm_data(
+        local_topk_ids,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        a_map,
+        c_map,
+        num_experts,
+        n,
+        k,
+    )
+
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
+    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
+
+    cutlass_w4a8_moe_mm(
+        c1,
+        gateup_input,
+        w1_q,
+        a1_scale.float(),
+        w1_scale,
+        expert_offsets[:-1],
+        problem_sizes1,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        s_strides13,
+        128,
+        topk,
+    )
+
+    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
+    silu_and_mul(c1, intermediate)
+
+    intermediate_q = torch.empty(
+        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
+    )
+    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)
+
+    cutlass_w4a8_moe_mm(
+        c2,
+        intermediate_q,
+        w2_q,
+        a2_scale.float(),
+        w2_scale,
+        expert_offsets[:-1],
+        problem_sizes2,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides2,
+        128,
+        topk,
+    )
+
+    output = torch.empty_like(a)
+    post_reorder_triton_kernel[(m,)](
+        c2,
+        output,
+        src2dst,
+        topk_ids_,
+        topk_weights,
+        start_expert_id,
+        end_expert_id,
+        topk,
+        k,
+        0,
+        BLOCK_SIZE=512,
+    )
+    return output
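
As a quick sanity check of the packing conventions documented in the docstring above, the following standalone snippet re-derives the expected tensor shapes from hypothetical sizes. All numbers are illustrative and not taken from any model; the relations simply mirror the asserts in cutlass_w4a8_moe.

# Hypothetical sizes: tokens M, hidden size K, intermediate size N, experts E.
M, K, N, E, topk = 64, 4096, 1024, 8, 2

# int4 weights are passed transposed and packed two nibbles per int8 byte.
w1_q_shape = (E, N * 2, K // 2)       # fused gate+up projection
w2_q_shape = (E, K, N // 2)           # down projection

# Scale tensors; the docstring specifies the K // 512 and N * 8 packing factors.
w1_scale_shape = (E, K // 512, N * 8)
w2_scale_shape = (E, N // 512, K * 4)

# These mirror the asserts in cutlass_w4a8_moe above.
assert K // 2 == w1_q_shape[2]                       # "Hidden size mismatch w1"
assert w1_q_shape[2] * 2 == w2_q_shape[1]            # "Hidden size mismatch w2"
assert w1_scale_shape[1] == w1_q_shape[2] * 2 // 512  # "W1 scale shape mismatch"
assert w1_scale_shape[2] == w1_q_shape[1] * 4
assert w2_scale_shape[1] == w2_q_shape[2] * 2 // 512  # "W2 scale shape mismatch"
assert w2_scale_shape[2] == w2_q_shape[1] * 4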