sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +48 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +34 -0
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/nixl/conn.py +6 -6
  10. sglang/srt/disaggregation/prefill.py +2 -2
  11. sglang/srt/disaggregation/utils.py +1 -1
  12. sglang/srt/distributed/parallel_state.py +44 -17
  13. sglang/srt/entrypoints/EngineBase.py +8 -0
  14. sglang/srt/entrypoints/engine.py +40 -6
  15. sglang/srt/entrypoints/http_server.py +111 -24
  16. sglang/srt/entrypoints/openai/protocol.py +4 -2
  17. sglang/srt/eplb/__init__.py +0 -0
  18. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  19. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  20. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  21. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  22. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  24. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  25. sglang/srt/hf_transformers_utils.py +2 -1
  26. sglang/srt/layers/activation.py +2 -2
  27. sglang/srt/layers/amx_utils.py +86 -0
  28. sglang/srt/layers/attention/ascend_backend.py +219 -0
  29. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  30. sglang/srt/layers/attention/tbo_backend.py +37 -9
  31. sglang/srt/layers/communicator.py +18 -2
  32. sglang/srt/layers/dp_attention.py +9 -3
  33. sglang/srt/layers/elementwise.py +76 -12
  34. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  35. sglang/srt/layers/layernorm.py +26 -0
  36. sglang/srt/layers/linear.py +84 -14
  37. sglang/srt/layers/logits_processor.py +4 -4
  38. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +36 -13
  40. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  41. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
  42. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
  43. sglang/srt/layers/moe/router.py +60 -22
  44. sglang/srt/layers/moe/topk.py +10 -28
  45. sglang/srt/layers/parameter.py +67 -7
  46. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  47. sglang/srt/layers/quantization/fp8.py +44 -0
  48. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  49. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  50. sglang/srt/layers/quantization/gptq.py +5 -1
  51. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  52. sglang/srt/layers/quantization/quant_utils.py +166 -0
  53. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  54. sglang/srt/layers/rotary_embedding.py +2 -2
  55. sglang/srt/layers/vocab_parallel_embedding.py +11 -7
  56. sglang/srt/lora/lora.py +4 -5
  57. sglang/srt/lora/lora_manager.py +73 -20
  58. sglang/srt/managers/configure_logging.py +1 -1
  59. sglang/srt/managers/io_struct.py +50 -13
  60. sglang/srt/managers/mm_utils.py +73 -59
  61. sglang/srt/managers/multimodal_processor.py +2 -6
  62. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  63. sglang/srt/managers/schedule_batch.py +77 -84
  64. sglang/srt/managers/scheduler.py +113 -59
  65. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  66. sglang/srt/managers/session_controller.py +12 -3
  67. sglang/srt/managers/tokenizer_manager.py +314 -103
  68. sglang/srt/managers/tp_worker.py +13 -1
  69. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  70. sglang/srt/mem_cache/allocator.py +290 -0
  71. sglang/srt/mem_cache/chunk_cache.py +34 -2
  72. sglang/srt/mem_cache/memory_pool.py +289 -3
  73. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  74. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  75. sglang/srt/model_executor/forward_batch_info.py +17 -4
  76. sglang/srt/model_executor/model_runner.py +297 -56
  77. sglang/srt/model_loader/loader.py +41 -0
  78. sglang/srt/model_loader/weight_utils.py +72 -4
  79. sglang/srt/models/deepseek_nextn.py +1 -3
  80. sglang/srt/models/deepseek_v2.py +181 -45
  81. sglang/srt/models/deepseek_vl2.py +3 -5
  82. sglang/srt/models/gemma3_causal.py +1 -2
  83. sglang/srt/models/gemma3n_causal.py +4 -3
  84. sglang/srt/models/gemma3n_mm.py +4 -20
  85. sglang/srt/models/hunyuan.py +1 -1
  86. sglang/srt/models/kimi_vl.py +1 -2
  87. sglang/srt/models/llama.py +10 -4
  88. sglang/srt/models/llama4.py +32 -45
  89. sglang/srt/models/llama_eagle3.py +61 -11
  90. sglang/srt/models/llava.py +5 -5
  91. sglang/srt/models/minicpmo.py +2 -2
  92. sglang/srt/models/mistral.py +1 -1
  93. sglang/srt/models/mllama4.py +43 -11
  94. sglang/srt/models/phi4mm.py +1 -3
  95. sglang/srt/models/pixtral.py +3 -7
  96. sglang/srt/models/qwen2.py +31 -3
  97. sglang/srt/models/qwen2_5_vl.py +1 -3
  98. sglang/srt/models/qwen2_audio.py +200 -0
  99. sglang/srt/models/qwen2_moe.py +32 -6
  100. sglang/srt/models/qwen2_vl.py +1 -4
  101. sglang/srt/models/qwen3.py +94 -25
  102. sglang/srt/models/qwen3_moe.py +68 -21
  103. sglang/srt/models/vila.py +3 -8
  104. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  105. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  106. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  107. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  108. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  109. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  110. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  111. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  112. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  117. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  120. sglang/srt/operations_strategy.py +6 -2
  121. sglang/srt/reasoning_parser.py +26 -0
  122. sglang/srt/sampling/sampling_batch_info.py +39 -1
  123. sglang/srt/server_args.py +69 -22
  124. sglang/srt/speculative/build_eagle_tree.py +57 -18
  125. sglang/srt/speculative/eagle_worker.py +6 -4
  126. sglang/srt/two_batch_overlap.py +200 -27
  127. sglang/srt/utils.py +306 -146
  128. sglang/srt/warmup.py +12 -3
  129. sglang/test/runners.py +10 -1
  130. sglang/test/test_utils.py +15 -3
  131. sglang/version.py +1 -1
  132. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  133. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
  134. sglang/math_utils.py +0 -8
  135. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  136. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  137. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  138. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  139. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  140. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  141. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/layers/elementwise.py
@@ -8,6 +8,7 @@ from sglang.srt.utils import is_hip
 
 _is_hip = is_hip()
 
+
 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
         )
     else:
+        max_warps = 16 if _is_hip else 32
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(
-                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
-                ),
-                4,
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
             ),
         }
@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }
 
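For orientation, the renamed variable reflects that this value is an upper bound on the warp count, not a lower one: the heuristic rounds hidden_dim / 256 up to a power of two, caps it at 16 warps on ROCm (HIP) or 32 on CUDA, and floors it at 4. A minimal standalone sketch of that computation (pick_num_warps is a hypothetical helper for illustration, not part of sglang):

import triton

def pick_num_warps(hidden_dim: int, is_hip: bool) -> int:
    # Cap at 16 warps on ROCm (HIP), 32 on CUDA; never go below 4.
    max_warps = 16 if is_hip else 32
    return max(min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4)

# Example: hidden_dim=4096 gives 16 warps on both backends;
# hidden_dim=16384 gives 32 on CUDA but is capped at 16 on HIP.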
@@ -331,6 +325,75 @@ class FusedDualResidualRMSNorm:
         return self.rmsnorm2.forward_native(residual), residual
 
 
+@triton.jit
+def experts_combine_kernel(
+    out_hidden_states,
+    moe_hidden_states,
+    mlp_hidden_states,
+    combine_k: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    start_index_mlp = pid * hidden_dim
+    start_index_rmoe = pid * hidden_dim * combine_k
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+    combine_k_offsets = tl.arange(0, combine_k)
+
+    moe_x = tl.load(
+        moe_hidden_states
+        + start_index_rmoe
+        + combine_k_offsets[:, None] * hidden_dim
+        + offsets[None, :],
+        mask=mask[None, :],
+        other=0.0,
+    )
+    moe_x = tl.sum(moe_x, axis=0)
+    mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)
+    combined_x = (moe_x + mlp_x) / 1.4142135623730951
+
+    tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)
+
+
+def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
+    assert moe_hidden_states.is_contiguous()
+    assert mlp_hidden_states.is_contiguous()
+
+    if len(moe_hidden_states.shape) == 2:
+        combine_k = 1  # pre-combined
+    else:
+        combine_k = moe_hidden_states.shape[1]
+
+    if output_buffer is None:
+        out_hidden_states = torch.empty_like(mlp_hidden_states)
+    else:
+        flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
+        assert flat_output_buffer.numel() >= mlp_hidden_states.numel()
+        out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(
+            mlp_hidden_states.shape
+        )
+
+    bs, hidden_dim = mlp_hidden_states.shape
+
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4
+        ),
+    }
+
+    experts_combine_kernel[(bs,)](
+        out_hidden_states,
+        moe_hidden_states,
+        mlp_hidden_states,
+        combine_k,
+        hidden_dim,
+        **config,
+    )
+    return out_hidden_states
+
+
 # gelu on first half of vector
 @triton.jit
 def gelu_and_mul_kernel(
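The new experts_combine_triton above sums combine_k MoE branch outputs with the shared-MLP output and divides by sqrt(2) (the 1.4142... constant in the kernel). A usage sketch under assumed shapes (the tensor names and sizes are illustrative, not from the diff):

import torch

from sglang.srt.layers.elementwise import experts_combine_triton

bs, combine_k, hidden = 8, 2, 4096
moe_out = torch.randn(bs, combine_k, hidden, device="cuda", dtype=torch.bfloat16)
mlp_out = torch.randn(bs, hidden, device="cuda", dtype=torch.bfloat16)

out = experts_combine_triton(moe_out, mlp_out)  # -> shape (bs, hidden)

# Unfused reference for what the kernel computes (up to bf16 rounding):
ref = (moe_out.sum(dim=1) + mlp_out) / (2 ** 0.5)

A 2-D moe_hidden_states is treated as already combined (combine_k = 1).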
@@ -400,10 +463,11 @@ def gelu_and_mul_triton(
         out_scales = scales
         static_scale = True
 
+    max_warps = 16 if _is_hip else 32
     config = {
         # 8 ele per thread (not tuned)
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
         ),
     }
 
sglang/srt/layers/flashinfer_comm_fusion.py (new file)
@@ -0,0 +1,202 @@
+import logging
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.utils import is_flashinfer_available
+
+logger = logging.getLogger(__name__)
+
+_flashinfer_comm = None
+_workspace_manager = None
+
+if is_flashinfer_available():
+    try:
+        import flashinfer.comm as comm
+
+        _flashinfer_comm = comm
+    except ImportError:
+        logger.warning(
+            "flashinfer.comm is not available, falling back to standard "
+            "implementation"
+        )
+
+
+class FlashInferWorkspaceManager:
+    def __init__(self):
+        self.workspace_tensor = None
+        self.ipc_handles = None
+        self.world_size = None
+        self.rank = None
+        self.initialized = False
+
+    def initialize(
+        self,
+        world_size: int,
+        rank: int,
+        max_token_num: int,
+        hidden_dim: int,
+        group=None,
+        use_fp32_lamport: bool = False,
+    ):
+        """Initialize workspace"""
+        if self.initialized and self.world_size == world_size:
+            return
+
+        if _flashinfer_comm is None:
+            logger.warning(
+                "FlashInfer comm not available, skipping workspace " "initialization"
+            )
+            return
+
+        self.cleanup()
+
+        self.ipc_handles, self.workspace_tensor = (
+            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
+                rank,
+                world_size,
+                max_token_num,
+                hidden_dim,
+                group=group,
+                use_fp32_lamport=use_fp32_lamport,
+            )
+        )
+
+        self.world_size = world_size
+        self.rank = rank
+        self.initialized = True
+
+        logger.info(
+            f"FlashInfer workspace initialized for rank {rank}, "
+            f"world_size {world_size}"
+        )
+
+    def cleanup(self):
+        """Clean up workspace"""
+        if self.initialized and self.ipc_handles is not None:
+            try:
+                _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
+                    self.ipc_handles, group=dist.group.WORLD
+                )
+            except Exception as e:
+                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
+            finally:
+                self.workspace_tensor = None
+                self.ipc_handles = None
+                self.initialized = False
+
+
+_workspace_manager = FlashInferWorkspaceManager()
+
+
+def ensure_workspace_initialized(
+    max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+):
+    """Ensure workspace is initialized"""
+    if not is_flashinfer_available() or _flashinfer_comm is None:
+        return False
+
+    world_size = get_tensor_model_parallel_world_size()
+    if world_size <= 1:
+        return False
+
+    rank = dist.get_rank()
+
+    if (
+        not _workspace_manager.initialized
+        or _workspace_manager.world_size != world_size
+    ):
+        _workspace_manager.initialize(
+            world_size=world_size,
+            rank=rank,
+            max_token_num=max_token_num,
+            hidden_dim=hidden_dim,
+            use_fp32_lamport=use_fp32_lamport,
+        )
+
+    return _workspace_manager.initialized
+
+
+def flashinfer_allreduce_add_rmsnorm(
+    input_tensor: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    max_token_num: int = 1024,
+    use_oneshot: bool = True,
+    trigger_completion_at_end: bool = False,
+    fp32_acc: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Use FlashInfer's fused allreduce + residual + RMS norm operation
+
+    Args:
+        input_tensor: Input tensor that needs allreduce
+        residual: Residual tensor
+        weight: RMS norm weight
+        eps: RMS norm epsilon
+        max_token_num: Maximum token number
+        use_oneshot: Whether to use oneshot mode
+        trigger_completion_at_end: Whether to trigger completion at end
+        fp32_acc: Whether to use fp32 precision
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output)
+    """
+    if not is_flashinfer_available() or _flashinfer_comm is None:
+        logger.debug(
+            "FlashInfer not available, falling back to standard " "implementation"
+        )
+        return None, None
+
+    world_size = get_tensor_model_parallel_world_size()
+    if world_size <= 1:
+        logger.debug("Single GPU, no need for allreduce fusion")
+        return None, None
+
+    if not ensure_workspace_initialized(
+        max_token_num=max_token_num,
+        hidden_dim=input_tensor.shape[-1],
+        use_fp32_lamport=(input_tensor.dtype == torch.float32),
+    ):
+        logger.debug("FlashInfer workspace not available")
+        return None, None
+
+    token_num, hidden_dim = input_tensor.shape
+
+    residual_out = torch.empty_like(residual)
+    norm_out = torch.empty_like(input_tensor)
+
+    _flashinfer_comm.trtllm_allreduce_fusion(
+        allreduce_in=input_tensor,
+        world_size=world_size,
+        world_rank=dist.get_rank(),
+        token_num=token_num,
+        hidden_dim=hidden_dim,
+        workspace_ptrs=_workspace_manager.workspace_tensor,
+        launch_with_pdl=True,
+        use_oneshot=use_oneshot,
+        trigger_completion_at_end=trigger_completion_at_end,
+        fp32_acc=fp32_acc,
+        pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm),
+        allreduce_out=None,
+        residual_in=residual,
+        residual_out=residual_out,
+        norm_out=norm_out,
+        quant_out=None,
+        scale_out=None,
+        rms_gamma=weight,
+        rms_eps=eps,
+        scale_factor=None,
+        layout_code=None,
+    )
+
+    return norm_out, residual_out
+
+
+def cleanup_flashinfer_workspace():
+    global _workspace_manager
+    if _workspace_manager is not None:
+        _workspace_manager.cleanup()
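For callers, flashinfer_allreduce_add_rmsnorm is a best-effort fast path: it returns (None, None) whenever FlashInfer comm is unavailable, the workspace cannot be set up, or world_size <= 1, and the caller is expected to fall back to the unfused all-reduce plus RMSNorm. A hedged usage sketch (hidden, residual, and rmsnorm are placeholder names for a TP worker's per-rank tensors and norm module, not from the diff):

from sglang.srt.layers.flashinfer_comm_fusion import flashinfer_allreduce_add_rmsnorm

norm_out, residual_out = flashinfer_allreduce_add_rmsnorm(
    input_tensor=hidden,          # per-rank partial output, shape [tokens, hidden_dim]
    residual=residual,
    weight=rmsnorm.weight,        # RMSNorm gamma
    eps=rmsnorm.variance_epsilon,
)
if norm_out is None:
    # Fused path unavailable: fall back to all-reduce followed by the regular RMSNorm.
    ...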
sglang/srt/layers/layernorm.py
@@ -163,6 +163,32 @@ class RMSNorm(CustomOp):
         else:
             return self.forward_native(x, residual)
 
+    def forward_with_allreduce_fusion(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Forward method with allreduce fusion, prioritizing flashinfer fused operations
+        """
+        if residual is not None:
+            from sglang.srt.distributed import get_tensor_model_parallel_world_size
+            from sglang.srt.layers.flashinfer_comm_fusion import (
+                flashinfer_allreduce_add_rmsnorm,
+            )
+
+            if get_tensor_model_parallel_world_size() > 1:
+                fused_result = flashinfer_allreduce_add_rmsnorm(
+                    input_tensor=x,
+                    residual=residual,
+                    weight=self.weight,
+                    eps=self.variance_epsilon,
+                )
+                if fused_result[0] is not None:
+                    return fused_result
+
+        return self.forward(x, residual)
+
 
 class GemmaRMSNorm(CustomOp):
     def __init__(
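For reference, the fused call above is expected to match an all-reduce, a residual add, and a standard RMSNorm. A plain-PyTorch sketch of those reference semantics (assumes torch.distributed is already initialized; this is not the code path sglang actually runs):

import torch
import torch.distributed as dist

def reference_allreduce_add_rmsnorm(x, residual, weight, eps):
    # Unfused reference: all-reduce the per-rank partial sums, add the residual,
    # then apply RMSNorm with gamma `weight` and epsilon `eps`.
    dist.all_reduce(x)
    residual_out = x + residual
    variance = residual_out.float().pow(2).mean(dim=-1, keepdim=True)
    norm_out = (residual_out.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return norm_out, residual_out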
sglang/srt/layers/linear.py
@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -31,10 +32,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     is_cpu,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 logger = logging.getLogger(__name__)
@@ -175,7 +176,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if _is_cpu and _is_cpu_amx_available:
-            _process_weight_after_loading(layer, ["weight"])
+            _amx_process_weight_after_loading(layer, ["weight"])
 
     def apply(
         self,
@@ -184,7 +185,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 x, layer.weight, bias, True  # is_vnni
             )
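use_intel_amx_backend replaces the inline getattr check on the removed line; based on that replaced code, it presumably reduces to the following (a sketch, not the actual sglang.srt.utils source):

def use_intel_amx_backend(layer) -> bool:
    # True only when weight loading marked this layer as runnable through the
    # Intel AMX (sgl_kernel.weight_packed_linear) path.
    return getattr(layer, "use_intel_amx_backend", False)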
@@ -425,8 +426,26 @@ class ColumnParallelLinear(LinearBase):
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = self.tp_rank * shard_size
-            if not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not self.use_presharded_weights,
+                )
+            else:
+                if not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
@@ -643,10 +662,29 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             start_idx = self.tp_rank * shard_size
-            # bitsandbytes loads the weights of the specific portion
-            # no need to narrow here
-            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not use_bitsandbytes_4bit and not self.use_presharded_weights,
+                )
+            else:
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit and not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
+
         # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -1111,10 +1149,27 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
 
-            # bitsandbytes loads the weights of the specific portion
-            # no need to narrow here
-            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not use_bitsandbytes_4bit and not self.use_presharded_weights,
+                )
+            else:
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit and not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
 
         # Special case for for AQLM codebooks.
         elif is_metadata:
@@ -1256,7 +1311,22 @@ class RowParallelLinear(LinearBase):
         ):
             shard_size = param_data.shape[input_dim]
             start_idx = self.tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    input_dim,
+                    shard_size,
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
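All four weight loaders now route the CPU case through narrow_padded_param_and_loaded_weight. Its exact behavior lives in sglang/srt/model_loader/weight_utils.py; from the call sites above, the calling convention appears to be (param_data, loaded_weight, param_data_start, weight_start, dim, length, narrow_weight=True). A hypothetical sketch of that convention, for orientation only and not the real implementation:

def narrow_padded_param_and_loaded_weight(
    param_data, loaded_weight, param_data_start, weight_start, dim, length, narrow_weight=True
):
    # Sketch: narrow the checkpoint shard (when requested) and narrow the padded
    # CPU parameter to the same span, so both sides line up even when the
    # parameter was padded beyond what the checkpoint provides.
    if narrow_weight:
        length = min(length, loaded_weight.shape[dim] - weight_start)
        loaded_weight = loaded_weight.narrow(dim, weight_start, length)
    param_data = param_data.narrow(dim, param_data_start, length)
    return param_data, loaded_weight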
sglang/srt/layers/logits_processor.py
@@ -42,7 +42,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import dump_to_file
+from sglang.srt.utils import dump_to_file, use_intel_amx_backend
 
 logger = logging.getLogger(__name__)
 
@@ -436,13 +436,13 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata(hidden_states)
             hidden_states, local_hidden_states = (
-                logits_metadata.gathered_buffer,
-                hidden_states.clone(),
+                torch.empty_like(logits_metadata.gathered_buffer),
+                hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
-            if getattr(lm_head, "use_intel_amx_backend", False):
+            if use_intel_amx_backend(lm_head):
                 logits = torch.ops.sgl_kernel.weight_packed_linear(
                     hidden_states.to(lm_head.weight.dtype),
                     lm_head.weight,
sglang/srt/layers/moe/ep_moe/kernels.py
@@ -4,9 +4,8 @@ from typing import List, Optional
 import torch
 import triton
 
-from sglang.math_utils import ceil_div
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import dispose_tensor, is_cuda
+from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
 
 logger = logging.getLogger(__name__)
 
@@ -814,14 +813,17 @@ def _fwd_kernel_ep_scatter_2(
     offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
     mask = offset_in < HIDDEN_SIZE
 
-    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
-    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+    index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = index_in_s < SCALE_HIDDEN_SIZE
 
     for token_id_int32 in range(start_token_id, total_token_num, grid_num):
         token_id = token_id_int32.to(tl.int64)
         to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
         to_copy_s = tl.load(
-            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
+            recv_x_scale
+            + token_id * recv_x_scale_stride0
+            + index_in_s * recv_x_scale_stride1,
+            mask=mask_s,
         )
 
         for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
@@ -842,7 +844,11 @@ def _fwd_kernel_ep_scatter_2(
                 output_tensor_scale + dest_token_index * output_tensor_scale_stride0
             )
             tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
-            tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
+            tl.store(
+                output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
+                to_copy_s,
+                mask=mask_s,
+            )
 
 
 # copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
@@ -857,6 +863,7 @@ def ep_scatter(
     output_tensor_scale: torch.Tensor,
     m_indices: torch.Tensor,
     output_index: torch.Tensor,
+    scale_ue8m0: bool = False,
 ):
     BLOCK_E = 128  # token num of per expert is aligned to 128
     BLOCK_D = 128  # block size of quantization
@@ -866,7 +873,15 @@ def ep_scatter(
     # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
     grid = num_experts
 
+    scale_hidden_size = hidden_size // BLOCK_D
+    if scale_ue8m0:
+        # ue8m0 scales are packed here (4 scales per int32),
+        # hence the effective size of this dimension is divided by 4.
+        scale_hidden_size = ceil_div(scale_hidden_size, 4)
+
     assert m_indices.shape[0] % BLOCK_E == 0
+    assert recv_x_scale.dtype == output_tensor_scale.dtype
+    assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size
 
     _fwd_kernel_ep_scatter_1[(grid,)](
         num_recv_tokens_per_expert,
@@ -905,8 +920,8 @@ def ep_scatter(
         num_warps=num_warps,
         HIDDEN_SIZE=hidden_size,
         HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
-        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
-        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+        SCALE_HIDDEN_SIZE=scale_hidden_size,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
     )
     return
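The scale_ue8m0 flag only changes the expected width of the per-token, per-block quantization scale tensors: one FP8 scale per BLOCK_D (128) elements of the hidden dimension, packed four-to-an-int32 when ue8m0 is used. The arithmetic, as a small standalone sketch (packed_scale_width is an illustrative helper, not part of the diff):

def packed_scale_width(hidden_size: int, block_d: int = 128, scale_ue8m0: bool = False) -> int:
    scale_hidden_size = hidden_size // block_d          # one scale per quantization block
    if scale_ue8m0:
        scale_hidden_size = -(-scale_hidden_size // 4)  # ceil_div: 4 ue8m0 scales per int32
    return scale_hidden_size

# Example: hidden_size=7168 -> 56 scale columns unpacked, 14 int32 columns packed.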