sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/layers/elementwise.py

@@ -8,6 +8,7 @@ from sglang.srt.utils import is_hip
 
 _is_hip = is_hip()
 
+
 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
         )
     else:
+        max_warps = 16 if _is_hip else 32
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(
-                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
-                ),
-                4,
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
             ),
         }
 
@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }
 
@@ -331,6 +325,75 @@ class FusedDualResidualRMSNorm:
         return self.rmsnorm2.forward_native(residual), residual
 
 
+@triton.jit
+def experts_combine_kernel(
+    out_hidden_states,
+    moe_hidden_states,
+    mlp_hidden_states,
+    combine_k: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    start_index_mlp = pid * hidden_dim
+    start_index_rmoe = pid * hidden_dim * combine_k
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+    combine_k_offsets = tl.arange(0, combine_k)
+
+    moe_x = tl.load(
+        moe_hidden_states
+        + start_index_rmoe
+        + combine_k_offsets[:, None] * hidden_dim
+        + offsets[None, :],
+        mask=mask[None, :],
+        other=0.0,
+    )
+    moe_x = tl.sum(moe_x, axis=0)
+    mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)
+    combined_x = (moe_x + mlp_x) / 1.4142135623730951
+
+    tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)
+
+
+def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
+    assert moe_hidden_states.is_contiguous()
+    assert mlp_hidden_states.is_contiguous()
+
+    if len(moe_hidden_states.shape) == 2:
+        combine_k = 1  # pre-combined
+    else:
+        combine_k = moe_hidden_states.shape[1]
+
+    if output_buffer is None:
+        out_hidden_states = torch.empty_like(mlp_hidden_states)
+    else:
+        flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
+        assert flat_output_buffer.numel() >= mlp_hidden_states.numel()
+        out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(
+            mlp_hidden_states.shape
+        )
+
+    bs, hidden_dim = mlp_hidden_states.shape
+
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4
+        ),
+    }
+
+    experts_combine_kernel[(bs,)](
+        out_hidden_states,
+        moe_hidden_states,
+        mlp_hidden_states,
+        combine_k,
+        hidden_dim,
+        **config,
+    )
+    return out_hidden_states
+
+
 # gelu on first half of vector
 @triton.jit
 def gelu_and_mul_kernel(
@@ -400,10 +463,11 @@ def gelu_and_mul_triton(
         out_scales = scales
         static_scale = True
 
+    max_warps = 16 if _is_hip else 32
     config = {
         # 8 ele per thread (not tuned)
        "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
         ),
     }
 
sglang/srt/layers/flashinfer_comm_fusion.py (new file)

@@ -0,0 +1,202 @@
+import logging
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.utils import is_flashinfer_available
+
+logger = logging.getLogger(__name__)
+
+_flashinfer_comm = None
+_workspace_manager = None
+
+if is_flashinfer_available():
+    try:
+        import flashinfer.comm as comm
+
+        _flashinfer_comm = comm
+    except ImportError:
+        logger.warning(
+            "flashinfer.comm is not available, falling back to standard "
+            "implementation"
+        )
+
+
+class FlashInferWorkspaceManager:
+    def __init__(self):
+        self.workspace_tensor = None
+        self.ipc_handles = None
+        self.world_size = None
+        self.rank = None
+        self.initialized = False
+
+    def initialize(
+        self,
+        world_size: int,
+        rank: int,
+        max_token_num: int,
+        hidden_dim: int,
+        group=None,
+        use_fp32_lamport: bool = False,
+    ):
+        """Initialize workspace"""
+        if self.initialized and self.world_size == world_size:
+            return
+
+        if _flashinfer_comm is None:
+            logger.warning(
+                "FlashInfer comm not available, skipping workspace " "initialization"
+            )
+            return
+
+        self.cleanup()
+
+        self.ipc_handles, self.workspace_tensor = (
+            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
+                rank,
+                world_size,
+                max_token_num,
+                hidden_dim,
+                group=group,
+                use_fp32_lamport=use_fp32_lamport,
+            )
+        )
+
+        self.world_size = world_size
+        self.rank = rank
+        self.initialized = True
+
+        logger.info(
+            f"FlashInfer workspace initialized for rank {rank}, "
+            f"world_size {world_size}"
+        )
+
+    def cleanup(self):
+        """Clean up workspace"""
+        if self.initialized and self.ipc_handles is not None:
+            try:
+                _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
+                    self.ipc_handles, group=dist.group.WORLD
+                )
+            except Exception as e:
+                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
+            finally:
+                self.workspace_tensor = None
+                self.ipc_handles = None
+                self.initialized = False
+
+
+_workspace_manager = FlashInferWorkspaceManager()
+
+
+def ensure_workspace_initialized(
+    max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+):
+    """Ensure workspace is initialized"""
+    if not is_flashinfer_available() or _flashinfer_comm is None:
+        return False
+
+    world_size = get_tensor_model_parallel_world_size()
+    if world_size <= 1:
+        return False
+
+    rank = dist.get_rank()
+
+    if (
+        not _workspace_manager.initialized
+        or _workspace_manager.world_size != world_size
+    ):
+        _workspace_manager.initialize(
+            world_size=world_size,
+            rank=rank,
+            max_token_num=max_token_num,
+            hidden_dim=hidden_dim,
+            use_fp32_lamport=use_fp32_lamport,
+        )
+
+    return _workspace_manager.initialized
+
+
+def flashinfer_allreduce_add_rmsnorm(
+    input_tensor: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    max_token_num: int = 1024,
+    use_oneshot: bool = True,
+    trigger_completion_at_end: bool = False,
+    fp32_acc: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Use FlashInfer's fused allreduce + residual + RMS norm operation
+
+    Args:
+        input_tensor: Input tensor that needs allreduce
+        residual: Residual tensor
+        weight: RMS norm weight
+        eps: RMS norm epsilon
+        max_token_num: Maximum token number
+        use_oneshot: Whether to use oneshot mode
+        trigger_completion_at_end: Whether to trigger completion at end
+        fp32_acc: Whether to use fp32 precision
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output)
+    """
+    if not is_flashinfer_available() or _flashinfer_comm is None:
+        logger.debug(
+            "FlashInfer not available, falling back to standard " "implementation"
+        )
+        return None, None
+
+    world_size = get_tensor_model_parallel_world_size()
+    if world_size <= 1:
+        logger.debug("Single GPU, no need for allreduce fusion")
+        return None, None
+
+    if not ensure_workspace_initialized(
+        max_token_num=max_token_num,
+        hidden_dim=input_tensor.shape[-1],
+        use_fp32_lamport=(input_tensor.dtype == torch.float32),
+    ):
+        logger.debug("FlashInfer workspace not available")
+        return None, None
+
+    token_num, hidden_dim = input_tensor.shape
+
+    residual_out = torch.empty_like(residual)
+    norm_out = torch.empty_like(input_tensor)
+
+    _flashinfer_comm.trtllm_allreduce_fusion(
+        allreduce_in=input_tensor,
+        world_size=world_size,
+        world_rank=dist.get_rank(),
+        token_num=token_num,
+        hidden_dim=hidden_dim,
+        workspace_ptrs=_workspace_manager.workspace_tensor,
+        launch_with_pdl=True,
+        use_oneshot=use_oneshot,
+        trigger_completion_at_end=trigger_completion_at_end,
+        fp32_acc=fp32_acc,
+        pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm),
+        allreduce_out=None,
+        residual_in=residual,
+        residual_out=residual_out,
+        norm_out=norm_out,
+        quant_out=None,
+        scale_out=None,
+        rms_gamma=weight,
+        rms_eps=eps,
+        scale_factor=None,
+        layout_code=None,
+    )
+
+    return norm_out, residual_out
+
+
+def cleanup_flashinfer_workspace():
+    global _workspace_manager
+    if _workspace_manager is not None:
+        _workspace_manager.cleanup()
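A hedged sketch of the intended calling convention (the caller below is hypothetical; the real call site is RMSNorm.forward_with_allreduce_fusion in the layernorm.py hunks that follow): the helper returns (None, None) whenever the fused path cannot be used, and the caller falls back to its existing unfused code.

    # hypothetical caller, not part of the diff
    norm_out, residual_out = flashinfer_allreduce_add_rmsnorm(
        input_tensor=hidden_states,   # tensor that still needs the TP all-reduce
        residual=residual,
        weight=norm_weight,
        eps=1e-6,
    )
    if norm_out is None:
        # fall back to the unfused all-reduce + add + RMSNorm path
        ...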
sglang/srt/layers/layernorm.py

@@ -52,6 +52,9 @@ elif _is_hip:
 
 logger = logging.getLogger(__name__)
 
+if is_npu():
+    import torch_npu
+
 
 class RMSNorm(CustomOp):
     def __init__(
@@ -76,6 +79,18 @@ class RMSNorm(CustomOp):
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            return out, residual_out
+        return torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+
     def forward_aiter(
         self,
         x: torch.Tensor,
@@ -148,6 +163,32 @@ class RMSNorm(CustomOp):
         else:
             return self.forward_native(x, residual)
 
+    def forward_with_allreduce_fusion(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Forward method with allreduce fusion, prioritizing flashinfer fused operations
+        """
+        if residual is not None:
+            from sglang.srt.distributed import get_tensor_model_parallel_world_size
+            from sglang.srt.layers.flashinfer_comm_fusion import (
+                flashinfer_allreduce_add_rmsnorm,
+            )
+
+            if get_tensor_model_parallel_world_size() > 1:
+                fused_result = flashinfer_allreduce_add_rmsnorm(
+                    input_tensor=x,
+                    residual=residual,
+                    weight=self.weight,
+                    eps=self.variance_epsilon,
+                )
+                if fused_result[0] is not None:
+                    return fused_result
+
+        return self.forward(x, residual)
+
 
 class GemmaRMSNorm(CustomOp):
     def __init__(
sglang/srt/layers/linear.py

@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -30,7 +31,12 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+    use_intel_amx_backend,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -52,6 +58,9 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "IPEXAWQLinearMethod",
 ]
 
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+
 
 def adjust_marlin_shard(param, shard_size, shard_offset):
     marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -165,6 +174,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu and _is_cpu_amx_available:
+            _amx_process_weight_after_loading(layer, ["weight"])
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -172,6 +185,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                x, layer.weight, bias, True  # is_vnni
+            )
+
         return F.linear(x, layer.weight, bias)
 
 
@@ -408,8 +426,26 @@ class ColumnParallelLinear(LinearBase):
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = self.tp_rank * shard_size
-            if not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not self.use_presharded_weights,
+                )
+            else:
+                if not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
@@ -626,10 +662,29 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             start_idx = self.tp_rank * shard_size
-            # bitsandbytes loads the weights of the specific portion
-            # no need to narrow here
-            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not use_bitsandbytes_4bit and not self.use_presharded_weights,
+                )
+            else:
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit and not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
+
         # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -1094,10 +1149,27 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
 
-            # bitsandbytes loads the weights of the specific portion
-            # no need to narrow here
-            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not use_bitsandbytes_4bit and not self.use_presharded_weights,
+                )
+            else:
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit and not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
 
         # Special case for for AQLM codebooks.
         elif is_metadata:
@@ -1239,7 +1311,22 @@ class RowParallelLinear(LinearBase):
         ):
             shard_size = param_data.shape[input_dim]
             start_idx = self.tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    input_dim,
+                    shard_size,
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
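For orientation, a sketch of what the pre-existing GPU branch selects (the sizes below are assumed, not from the diff): each tensor-parallel rank keeps one contiguous slice of the sharded dimension, and the new _is_cpu branch routes the same slicing through narrow_padded_param_and_loaded_weight, presumably so that padded CPU parameters and checkpoint weights stay aligned.

    # illustrative shard arithmetic for ColumnParallelLinear (assumed sizes)
    tp_rank, tp_size, out_features = 1, 4, 8192
    shard_size = out_features // tp_size   # 2048 output rows per rank
    start_idx = tp_rank * shard_size       # rank 1 loads rows 2048..4095
    # loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)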
sglang/srt/layers/logits_processor.py

@@ -42,7 +42,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import dump_to_file
+from sglang.srt.utils import dump_to_file, use_intel_amx_backend
 
 logger = logging.getLogger(__name__)
 
@@ -436,17 +436,26 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata(hidden_states)
             hidden_states, local_hidden_states = (
-                logits_metadata.gathered_buffer,
-                hidden_states.clone(),
+                torch.empty_like(logits_metadata.gathered_buffer),
+                hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
-            logits = torch.matmul(
-                hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
-            )
+            if use_intel_amx_backend(lm_head):
+                logits = torch.ops.sgl_kernel.weight_packed_linear(
+                    hidden_states.to(lm_head.weight.dtype),
+                    lm_head.weight,
+                    None,  # bias
+                    True,  # is_vnni
+                )
+            else:
+                logits = torch.matmul(
+                    hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
+                )
         else:
             # GGUF models
+            # TODO: use weight_packed_linear for GGUF models
             logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
 
         if self.logit_scale is not None:
sglang/srt/layers/moe/ep_moe/kernels.py

@@ -4,9 +4,8 @@ from typing import List, Optional
 import torch
 import triton
 
-from sglang.math_utils import ceil_div
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import dispose_tensor, is_cuda
+from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
 
 logger = logging.getLogger(__name__)
 
@@ -814,14 +813,17 @@ def _fwd_kernel_ep_scatter_2(
     offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
     mask = offset_in < HIDDEN_SIZE
 
-    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
-    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+    index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = index_in_s < SCALE_HIDDEN_SIZE
 
     for token_id_int32 in range(start_token_id, total_token_num, grid_num):
         token_id = token_id_int32.to(tl.int64)
         to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
         to_copy_s = tl.load(
-            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
+            recv_x_scale
+            + token_id * recv_x_scale_stride0
+            + index_in_s * recv_x_scale_stride1,
+            mask=mask_s,
         )
 
         for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
@@ -842,7 +844,11 @@ def _fwd_kernel_ep_scatter_2(
                 output_tensor_scale + dest_token_index * output_tensor_scale_stride0
             )
             tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
-            tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
+            tl.store(
+                output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
+                to_copy_s,
+                mask=mask_s,
+            )
 
 
 # copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
@@ -857,6 +863,7 @@ def ep_scatter(
     output_tensor_scale: torch.Tensor,
     m_indices: torch.Tensor,
     output_index: torch.Tensor,
+    scale_ue8m0: bool = False,
 ):
     BLOCK_E = 128  # token num of per expert is aligned to 128
     BLOCK_D = 128  # block size of quantization
@@ -866,7 +873,15 @@ def ep_scatter(
     # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
     grid = num_experts
 
+    scale_hidden_size = hidden_size // BLOCK_D
+    if scale_ue8m0:
+        # ue8m0 scales are packed here (4 scales per int32),
+        # hence the effective size of this dimension is divided by 4.
+        scale_hidden_size = ceil_div(scale_hidden_size, 4)
+
     assert m_indices.shape[0] % BLOCK_E == 0
+    assert recv_x_scale.dtype == output_tensor_scale.dtype
+    assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size
 
     _fwd_kernel_ep_scatter_1[(grid,)](
         num_recv_tokens_per_expert,
@@ -905,8 +920,8 @@ def ep_scatter(
         num_warps=num_warps,
         HIDDEN_SIZE=hidden_size,
         HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
-        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
-        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+        SCALE_HIDDEN_SIZE=scale_hidden_size,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
     )
     return
 
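The packed-scale sizing introduced in ep_scatter, worked through with an assumed hidden size (7168 is illustrative, not from the diff):

    def ceil_div(a: int, b: int) -> int:   # same semantics as the ceil_div imported above
        return -(-a // b)

    hidden_size, BLOCK_D = 7168, 128
    scale_hidden_size = hidden_size // BLOCK_D   # 56 per-group scales per token
    packed = ceil_div(scale_hidden_size, 4)      # 14 int32 words when scale_ue8m0=True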