sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/flashinfer_backend.py

@@ -82,6 +82,8 @@ class FlashInferAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
 
         assert not (
             model_runner.sliding_window_size is not None
@@ -98,8 +100,11 @@ class FlashInferAttnBackend(AttentionBackend):
         self.num_wrappers = 1
         self.dispatch_reason = None
 
-        # Qwen2 models require higher flashinfer workspace size
-        if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
+        # Qwen2/Qwen3 models require higher flashinfer workspace size
+        if (
+            "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures
+            or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
+        ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
         # Allocate buffers
@@ -391,6 +396,8 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -407,7 +414,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 assert v is not None
                 if save_kv_cache:
                     forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                        layer, cache_loc, k, v, k_scale, v_scale
                     )
 
             o = prefill_wrapper_paged.forward(
@@ -417,8 +424,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -445,7 +452,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -459,6 +466,8 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -472,7 +481,7 @@ class FlashInferAttnBackend(AttentionBackend):
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )
 
         o = decode_wrapper.forward(
@@ -480,8 +489,8 @@ class FlashInferAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
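
The recurring change in this file is a scale-selection pattern: per-layer float scales are passed through only when a non-"auto" kv_cache_dtype is configured, i.e. when the KV cache is actually quantized. A minimal, self-contained sketch of that pattern follows; the Layer dataclass is a hypothetical stand-in, not sglang's RadixAttention object.

    from dataclasses import dataclass
    from typing import Optional, Tuple

    @dataclass
    class Layer:
        # Hypothetical stand-in for the attention layer object.
        k_scale_float: float = 1.0
        v_scale_float: float = 1.0

    def select_kv_scales(layer: Layer, kv_cache_dtype_str: str) -> Tuple[Optional[float], Optional[float]]:
        # Pass explicit dequantization scales only when the KV cache is quantized.
        if kv_cache_dtype_str != "auto":
            return layer.k_scale_float, layer.v_scale_float
        return None, None

    print(select_kv_scales(Layer(0.5, 0.25), "fp8_e4m3"))  # (0.5, 0.25)
    print(select_kv_scales(Layer(0.5, 0.25), "auto"))      # (None, None)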

sglang/srt/layers/attention/torch_native_backend.py

@@ -6,6 +6,7 @@ import torch
 from torch.nn.functional import scaled_dot_product_attention
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 if TYPE_CHECKING:
@@ -202,6 +203,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
         o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
 
+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self._run_sdpa_forward_extend(
             q_,
             o_,
@@ -214,7 +219,7 @@ class TorchNativeAttnBackend(AttentionBackend):
             forward_batch.extend_seq_lens,
             scaling=layer.scaling,
             enable_gqa=use_gqa,
-            causal=not layer.is_cross_attention,
+            causal=causal,
         )
         return o
 

sglang/srt/layers/attention/triton_backend.py

@@ -10,6 +10,7 @@ import triton.language as tl
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import get_bool_env_var, get_device_core_count
 
@@ -528,6 +529,10 @@ class TritonAttnBackend(AttentionBackend):
             layer, forward_batch.out_cache_loc, k, v
         )
 
+        causal = True
+        if layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -539,6 +544,7 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.kv_indptr,
             self.forward_metadata.kv_indices,
             self.forward_metadata.custom_mask,
+            causal,
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
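
Both backends above derive a causal flag from the layer type before running attention: encoder-only layers (and cross-attention in the torch-native backend) attend bidirectionally. A small illustration using torch's built-in SDPA; the enum and tensors here are illustrative stand-ins, not sglang's actual objects.

    from enum import Enum, auto

    import torch
    from torch.nn.functional import scaled_dot_product_attention

    class AttentionType(Enum):
        # Illustrative enum mirroring the imported AttentionType.
        DECODER = auto()
        ENCODER_ONLY = auto()

    def run_sdpa(q, k, v, attn_type, is_cross_attention=False):
        # Encoder-only (BERT-style) and cross-attention layers attend bidirectionally.
        causal = not (is_cross_attention or attn_type == AttentionType.ENCODER_ONLY)
        return scaled_dot_product_attention(q, k, v, is_causal=causal)

    q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
    print(run_sdpa(q, k, v, AttentionType.ENCODER_ONLY).shape)  # full bidirectional attention
    print(run_sdpa(q, k, v, AttentionType.DECODER).shape)       # causal mask applied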

sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -74,6 +74,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     USE_CUSTOM_MASK: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
 ):
@@ -129,6 +130,7 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
+
         offs_kv_loc = tl.load(
             kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
@@ -196,7 +198,11 @@ def _fwd_kernel(
 
     # stage 2: compute the triangle part
 
-    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    cur_block_m_end = (
+        cur_seq_len_extend
+        if not IS_CAUSAL
+        else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    )
     for start_n in range(0, cur_block_m_end, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_block_m_end
@@ -243,12 +249,15 @@ def _fwd_kernel(
             )
             custom_mask &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(custom_mask, qk, float("-inf"))
-        else:
+        elif IS_CAUSAL:
             mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                 start_n + offs_n[None, :]
            )
             mask_causual &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(mask_causual, qk, float("-inf"))
+        else:
+            mask_non_causal = mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_non_causal, qk, float("-inf"))
 
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
@@ -299,6 +308,7 @@ def extend_attention_fwd(
     kv_indptr,
     kv_indices,
     custom_mask,
+    is_causal,
     mask_indptr,
     max_len_extend,
     sm_scale=None,
@@ -411,6 +421,7 @@ def extend_attention_fwd(
         Lq=Lq,
         Lv=Lv,
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
+        IS_CAUSAL=is_causal,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
         STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
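
For intuition, the stage-2 masking that the Triton kernel now branches on can be stated as a tiny PyTorch reference: under IS_CAUSAL a query row attends only to key positions up to its own index inside the extend segment, while the non-causal path keeps every in-bounds position. Names below are illustrative, and this sketches the mask only, not the kernel itself.

    import torch

    def stage2_mask(seq_len_extend: int, block_m: int, cur_block_m: int, is_causal: bool) -> torch.Tensor:
        # Query rows handled by this block vs. all key positions in the extend segment.
        offs_m = cur_block_m * block_m + torch.arange(block_m)
        offs_n = torch.arange(seq_len_extend)
        in_bounds = (offs_m[:, None] < seq_len_extend) & (offs_n[None, :] < seq_len_extend)
        if is_causal:
            # Each query may attend only to keys at or before its own position.
            return in_bounds & (offs_m[:, None] >= offs_n[None, :])
        return in_bounds  # encoder-only: attend to the whole extend segment

    print(stage2_mask(4, 2, 1, is_causal=True).int())
    print(stage2_mask(4, 2, 1, is_causal=False).int())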

sglang/srt/layers/attention/vision.py

@@ -94,7 +94,7 @@ class VisionAttention(nn.Module):
             input_size=embed_dim,
             output_size=embed_dim,
             quant_config=quant_config,
-            prefix=add_prefix("out_proj", prefix),
+            prefix=add_prefix("proj", prefix),
         )
 
     def forward(

sglang/srt/layers/dp_attention.py

@@ -192,8 +192,7 @@ def _dp_gather(
 
     if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
         assert (
-            global_tokens.untyped_storage().data_ptr()
-            != local_tokens.untyped_storage().data_ptr()
+            local_tokens.untyped_storage() is not global_tokens.untyped_storage()
         ), "aliasing between global_tokens and local_tokens not allowed"
         memcpy_triton(
             global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
@@ -243,8 +242,7 @@ def dp_scatter(
     assert global_tokens.is_contiguous()
     if local_tokens.shape[0] > 0:
         assert (
-            local_tokens.untyped_storage().data_ptr()
-            != global_tokens.untyped_storage().data_ptr()
+            local_tokens.untyped_storage() is not global_tokens.untyped_storage()
         ), "aliasing between local_tokens and global_tokens not allowed"
         memcpy_triton(
             local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
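
For context on what these asserts guard against: a view shares its base tensor's storage, so a memcpy-style kernel copying between the two could read and write overlapping memory. A tiny, illustrative check (variable names made up):

    import torch

    global_tokens = torch.zeros(8, 4)
    alias = global_tokens[2:6]   # a view: shares storage with global_tokens
    fresh = torch.zeros(4, 4)    # an independent allocation

    same = alias.untyped_storage().data_ptr() == global_tokens.untyped_storage().data_ptr()
    print(same)   # True: copying between alias and global_tokens would overlap
    print(fresh.untyped_storage().data_ptr() == global_tokens.untyped_storage().data_ptr())  # False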

sglang/srt/layers/elementwise.py

@@ -4,6 +4,10 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
+
 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -185,6 +189,9 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
+
+    min_num_warps = 16 if _is_hip else 32
+
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
@@ -193,7 +200,10 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+                min(
+                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
+                ),
+                4,
             ),
         }
 
@@ -250,10 +260,13 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
         ),
     }
 
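The warp-count heuristic above, restated in plain Python so the clamp is easier to read: roughly one warp per 256 elements, floored at 4 and now capped at 16 on HIP (32 on CUDA), presumably because HIP wavefronts are wider. The helpers below stand in for triton.cdiv and triton.next_power_of_2.

    def cdiv(a: int, b: int) -> int:
        return -(-a // b)

    def next_power_of_2(n: int) -> int:
        return 1 if n <= 1 else 1 << (n - 1).bit_length()

    def pick_num_warps(hidden_dim: int, is_hip: bool) -> int:
        min_num_warps = 16 if is_hip else 32  # acts as the upper clamp, despite the name
        return max(min(next_power_of_2(cdiv(hidden_dim, 256)), min_num_warps), 4)

    for dim in (1024, 4096, 16384):
        print(dim, pick_num_warps(dim, is_hip=False), pick_num_warps(dim, is_hip=True))
    # 1024 4 4
    # 4096 16 16
    # 16384 32 16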

sglang/srt/layers/layernorm.py

@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 
 _is_cuda = is_cuda_available()
@@ -31,7 +32,6 @@ if _is_cuda:
         rmsnorm,
     )
 
-    from sglang.srt.custom_op import CustomOp
 
 logger = logging.getLogger(__name__)
 

sglang/srt/layers/linear.py

@@ -1,5 +1,6 @@
 """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
 
+import itertools
 import logging
 from abc import abstractmethod
 from typing import Dict, List, Optional, Tuple
@@ -47,6 +48,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "GPTQLinearMethod",
     "FBGEMMFp8LinearMethod",
     "ModelOptFp8LinearMethod",
+    "ModelOptFp4LinearMethod",
     "IPEXAWQLinearMethod",
 ]
 
@@ -60,12 +62,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
 
 
 def adjust_bitsandbytes_4bit_shard(
-    param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
+    param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
 ) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
 
-    total, _ = qkv_offsets["total"]
-    orig_offset, orig_size = qkv_offsets[loaded_shard_id]
+    total, _ = shard_offsets["total"]
+    orig_offset, orig_size = shard_offsets[loaded_shard_id]
 
     quantized_total = param.data.shape[0]
     quantized_offset = orig_offset * quantized_total // total
@@ -572,6 +574,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offsets.append((i, current_shard_offset, output_size))
                 current_shard_offset += output_size
             packed_dim = getattr(param, "packed_dim", None)
+
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantization.
                 # If quantized, we need to adjust the offset and size to account
@@ -584,6 +588,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                         param, shard_size, shard_offset
                     )
 
+                if use_bitsandbytes_4bit:
+                    index = list(itertools.accumulate([0] + self.output_sizes))
+                    orig_offsets = {
+                        str(i): (index[i], size)
+                        for i, size in enumerate(self.output_sizes)
+                    }
+                    orig_offsets["total"] = (self.output_size, 0)
+                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                        param, orig_offsets, str(shard_id)
+                    )
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size
                 )
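
A quick sketch of how the new bitsandbytes branch derives per-shard offsets from the merged output sizes: itertools.accumulate yields the start of each slice, and the extra "total" entry (self.output_size in the real code) carries the merged width so adjust_bitsandbytes_4bit_shard can rescale offsets to the quantized layout. The sizes below are made up.

    import itertools

    output_sizes = [1024, 1024, 4096]  # made-up slice widths merged into one weight
    index = list(itertools.accumulate([0] + output_sizes))
    orig_offsets = {str(i): (index[i], size) for i, size in enumerate(output_sizes)}
    orig_offsets["total"] = (sum(output_sizes), 0)
    print(orig_offsets)
    # {'0': (0, 1024), '1': (1024, 1024), '2': (2048, 4096), 'total': (6144, 0)}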

sglang/srt/layers/moe/ep_moe/layer.py

@@ -2,6 +2,7 @@ import logging
 from typing import Callable, List, Optional, Tuple
 
 import torch
+from torch.nn import Module
 
 try:
     from deep_gemm import (
@@ -13,8 +14,6 @@ try:
 except ImportError:
     use_deep_gemm = False
 
-from torch.nn import Module
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -37,22 +36,17 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
 
-_is_cuda = is_cuda()
+_is_hip = is_hip()
 
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+if _is_hip:
+    from vllm._custom_ops import scaled_fp8_quant
 
 logger = logging.getLogger(__name__)
 
-_is_hip = is_hip()
-
-_buffer = None
-
 
 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -142,6 +136,7 @@ class EPMoE(torch.nn.Module):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
     ):
         super().__init__()
 
@@ -170,6 +165,7 @@ class EPMoE(torch.nn.Module):
         self.correction_bias = correction_bias
         self.custom_routing_function = custom_routing_function
         self.activation = activation
+        self.routed_scaling_factor = routed_scaling_factor
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -221,6 +217,7 @@ class EPMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -740,20 +737,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
             )
 
             for expert in range(layer.num_experts_per_partition):
-                if _is_cuda:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
-                else:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
             return
@@ -986,9 +975,6 @@ class DeepEPMoE(EPMoE):
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
-        assert (
-            hidden_states_fp8[0].size(0) % 4 == 0
-        ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"
 
         # GroupGemm-0
         num_groups, m, k = hidden_states_fp8[0].size()
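
For reference, a plain-PyTorch approximation of what a per-tensor scaled FP8 quantization such as the unified scaled_fp8_quant call conceptually computes (absmax scaling into float8_e4m3fn); this is an illustrative sketch, not sglang's or vLLM's kernel.

    import torch

    def scaled_fp8_quant_reference(w: torch.Tensor):
        # Per-tensor absmax scaling: w ≈ q.float() * scale after dequantization.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        scale = w.abs().max().float() / fp8_max
        q = (w / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        return q, scale

    w = torch.randn(128, 256)
    q, s = scaled_fp8_quant_reference(w)
    print(q.dtype, s.item(), (q.float() * s - w).abs().max().item())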