sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +0 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  12. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  13. sglang/srt/constrained/xgrammar_backend.py +26 -4
  14. sglang/srt/custom_op.py +0 -62
  15. sglang/srt/disaggregation/decode.py +62 -6
  16. sglang/srt/disaggregation/mini_lb.py +5 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +32 -62
  18. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  19. sglang/srt/disaggregation/prefill.py +40 -4
  20. sglang/srt/disaggregation/utils.py +15 -0
  21. sglang/srt/entrypoints/verl_engine.py +7 -5
  22. sglang/srt/layers/activation.py +6 -8
  23. sglang/srt/layers/attention/flashattention_backend.py +114 -71
  24. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  25. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  26. sglang/srt/layers/attention/triton_backend.py +6 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  28. sglang/srt/layers/layernorm.py +1 -1
  29. sglang/srt/layers/linear.py +17 -3
  30. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  31. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  34. sglang/srt/layers/moe/topk.py +27 -30
  35. sglang/srt/layers/parameter.py +0 -2
  36. sglang/srt/layers/quantization/__init__.py +1 -0
  37. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  38. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
  39. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  40. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  41. sglang/srt/layers/quantization/fp8.py +115 -132
  42. sglang/srt/layers/quantization/fp8_kernel.py +213 -57
  43. sglang/srt/layers/quantization/fp8_utils.py +187 -262
  44. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  45. sglang/srt/layers/quantization/utils.py +5 -11
  46. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  47. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  48. sglang/srt/layers/radix_attention.py +15 -0
  49. sglang/srt/layers/rotary_embedding.py +3 -2
  50. sglang/srt/layers/sampler.py +5 -10
  51. sglang/srt/lora/backend/base_backend.py +18 -2
  52. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  53. sglang/srt/lora/backend/triton_backend.py +1 -1
  54. sglang/srt/lora/layers.py +1 -1
  55. sglang/srt/lora/lora.py +1 -1
  56. sglang/srt/lora/lora_manager.py +1 -1
  57. sglang/srt/managers/detokenizer_manager.py +0 -1
  58. sglang/srt/managers/io_struct.py +1 -0
  59. sglang/srt/managers/mm_utils.py +4 -3
  60. sglang/srt/managers/multimodal_processor.py +0 -2
  61. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  62. sglang/srt/managers/schedule_batch.py +2 -4
  63. sglang/srt/managers/scheduler.py +12 -71
  64. sglang/srt/managers/tokenizer_manager.py +1 -0
  65. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  66. sglang/srt/mem_cache/memory_pool.py +7 -2
  67. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  68. sglang/srt/model_executor/model_runner.py +20 -27
  69. sglang/srt/models/bert.py +398 -0
  70. sglang/srt/models/deepseek.py +1 -1
  71. sglang/srt/models/deepseek_nextn.py +74 -70
  72. sglang/srt/models/deepseek_v2.py +289 -348
  73. sglang/srt/models/llama.py +5 -5
  74. sglang/srt/models/minicpm3.py +29 -201
  75. sglang/srt/models/qwen2.py +4 -1
  76. sglang/srt/models/qwen2_moe.py +14 -13
  77. sglang/srt/models/qwen3.py +335 -0
  78. sglang/srt/models/qwen3_moe.py +423 -0
  79. sglang/srt/reasoning_parser.py +0 -1
  80. sglang/srt/sampling/sampling_batch_info.py +2 -3
  81. sglang/srt/server_args.py +34 -32
  82. sglang/srt/speculative/eagle_worker.py +4 -7
  83. sglang/srt/utils.py +16 -1
  84. sglang/test/runners.py +5 -1
  85. sglang/test/test_block_fp8.py +167 -0
  86. sglang/test/test_custom_ops.py +1 -1
  87. sglang/version.py +1 -1
  88. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
  89. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
  90. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  91. sglang/lang/__init__.py +0 -0
  92. sglang/srt/lora/backend/__init__.py +0 -25
  93. sglang/srt/server.py +0 -18
  94. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  95. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py

@@ -142,6 +142,16 @@ def make_local_attention_virtual_batches(
         seqlens_k_local: Key sequence lengths for local attention
         block_table_local: Block table for local attention
     """
+    # Adjust attention_chunk_size based on the actual sequence length
+    # to avoid index out of bounds errors
+    max_seq_len = seq_lens_np.max()
+    effective_chunk_size = min(attn_chunk_size, max_seq_len)
+    # Make sure effective_chunk_size is divisible by page_size
+    effective_chunk_size = (effective_chunk_size // page_size) * page_size
+    if effective_chunk_size < page_size:
+        effective_chunk_size = page_size
+    attn_chunk_size = effective_chunk_size
+
     q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
     actual_batch_size = seq_lens_np.shape[0]
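The clamping that this hunk moves into make_local_attention_virtual_batches is easy to misread, so here is a standalone sketch of the same arithmetic (illustrative only; the helper name below is ours, not part of sglang). With attn_chunk_size=8192, a longest sequence of 300 tokens, and page_size=64, the effective chunk becomes 256.

```python
def clamp_chunk_size(attn_chunk_size: int, max_seq_len: int, page_size: int) -> int:
    # Cap the chunk at the longest sequence, round down to a multiple of
    # page_size, but never go below a single page.
    effective = min(attn_chunk_size, max_seq_len)
    effective = (effective // page_size) * page_size
    return max(effective, page_size)

assert clamp_chunk_size(8192, 300, 64) == 256  # capped by the sequence, page-aligned
assert clamp_chunk_size(8192, 20, 64) == 64    # shorter than one page -> one page
assert clamp_chunk_size(128, 4096, 64) == 128  # already small and page-aligned
```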
 
@@ -299,9 +309,7 @@ class FlashAttentionBackend(AttentionBackend):
         self.kv_cache_dtype = model_runner.kv_cache_dtype
         self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
         self.page_size = model_runner.page_size
-        self.use_mla = (
-            model_runner.model_config.attention_arch == AttentionArch.MLA
-        ) and (not global_server_args_dict["disable_mla"])
+        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill

         self.topk = topk
@@ -346,6 +354,8 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
+
+                self._init_local_attn_metadata(metadata, device)
             else:
                 # Normal Decode
                 metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
@@ -359,6 +369,8 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
+
+                self._init_local_attn_metadata(metadata, device)
         elif forward_batch.forward_mode.is_target_verify():
             metadata.cache_seqlens_int32 = (
                 forward_batch.seq_lens + self.speculative_num_draft_tokens
@@ -407,49 +419,8 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cu_seqlens_q = metadata.cu_seqlens_k

         # Setup local attention if enabled
-        if (
-            self.attention_chunk_size is not None
-            and forward_batch.forward_mode == ForwardMode.EXTEND
-        ):
-            # Convert tensors to numpy for local attention processing
-            cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
-            seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
-
-            # Adjust attention_chunk_size based on the actual sequence length
-            # to avoid index out of bounds errors
-            max_seq_len = seq_lens_np.max()
-            effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
-            # Make sure effective_chunk_size is divisible by page_size
-            effective_chunk_size = (
-                effective_chunk_size // self.page_size
-            ) * self.page_size
-            if effective_chunk_size < self.page_size:
-                effective_chunk_size = self.page_size
-
-            # Create local attention metadata
-            (
-                seqlens_q_local_np,
-                cu_seqlens_q_local_np,
-                seqlens_k_local_np,
-                block_table_local,
-            ) = make_local_attention_virtual_batches(
-                effective_chunk_size,
-                cu_seqlens_q_np,
-                seq_lens_np,
-                metadata.page_table,
-                self.page_size,
-            )
-
-            local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
-                    device
-                ),
-                local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
-                local_block_table=block_table_local,
-                local_max_query_len=seqlens_q_local_np.max(),
-                local_max_seq_len=seqlens_k_local_np.max(),
-            )
-            metadata.local_attn_metadata = local_metadata
+        if forward_batch.forward_mode == ForwardMode.EXTEND:
+            self._init_local_attn_metadata(metadata, device)

         # Encoder metadata for cross attention
         if forward_batch.encoder_lens is not None:
@@ -706,6 +677,10 @@ class FlashAttentionBackend(AttentionBackend):

         # Use precomputed metadata across all layers
         metadata = self.forward_metadata
+        local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
+        use_local_attention = (
+            self.attention_chunk_size is not None and local_attn_metadata is not None
+        )

         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
@@ -740,33 +715,60 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )

-            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
             if layer.is_cross_attention:
-                page_table = metadata.encoder_page_table
-                cache_seqlens = metadata.encoder_lens_int32
-                cu_seqlens_k = metadata.encoder_cu_seqlens_k
-                window_size = (-1, -1)
+                # Always use non-chunked logic for cross-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=metadata.encoder_page_table,
+                    cache_seqlens=metadata.encoder_lens_int32,
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
+                    max_seqlen_q=1,
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    window_size=(-1, -1),
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
+            elif use_local_attention:
+                # Use chunked (local) attention batching for self-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=local_attn_metadata.local_block_table,
+                    cache_seqlens=local_attn_metadata.local_seqused_k,
+                    cu_seqlens_q=local_attn_metadata.local_query_start_loc,
+                    cu_seqlens_k_new=metadata.cu_seqlens_k,
+                    max_seqlen_q=local_attn_metadata.local_max_query_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=(-1, -1),
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
             else:
-                page_table = metadata.page_table
-                cache_seqlens = metadata.cache_seqlens_int32
-                cu_seqlens_k = metadata.cu_seqlens_k
-
-            o = flash_attn_with_kvcache(
-                q=q_reshaped,
-                k_cache=key_cache,
-                v_cache=value_cache,
-                page_table=page_table,
-                cache_seqlens=cache_seqlens,
-                cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=cu_seqlens_k,
-                max_seqlen_q=1,
-                softmax_scale=layer.scaling,
-                causal=causal,
-                window_size=window_size,
-                softcap=layer.logit_cap,
-                k_descale=k_descale,
-                v_descale=v_descale,
-            )
+                # Default: single-token self-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=metadata.page_table,
+                    cache_seqlens=metadata.cache_seqlens_int32,
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k_new=metadata.cu_seqlens_k,
+                    max_seqlen_q=1,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
         else:
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -988,6 +990,8 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
+        device = seq_lens.device
+
         if forward_mode.is_decode_or_idle():
             metadata = self.decode_cuda_graph_metadata[bs]

@@ -1014,6 +1018,8 @@ class FlashAttentionBackend(AttentionBackend):
                 ]

                 metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
+                self._init_local_attn_metadata(metadata, device)
             else:
                 # Normal Decode
                 max_len = seq_lens_cpu.max().item()
@@ -1037,6 +1043,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                 metadata.page_table[:, max_seq_pages:].fill_(0)

+                self._init_local_attn_metadata(metadata, device)
         elif forward_mode.is_target_verify():
             metadata = self.target_verify_metadata[bs]
             metadata.cache_seqlens_int32.copy_(
@@ -1087,6 +1094,42 @@ class FlashAttentionBackend(AttentionBackend):
         """Get the fill value for sequence length in CUDA graph."""
         return 0

+    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
+        """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
+        if self.attention_chunk_size is None:
+            metadata.local_attn_metadata = None
+            return
+
+        cu_seqlens_q = metadata.cu_seqlens_q
+        cache_seqlens_int32 = metadata.cache_seqlens_int32
+        page_table = metadata.page_table
+        if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
+            metadata.local_attn_metadata = None
+            return
+
+        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+        seq_lens_np = cache_seqlens_int32.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seq_lens_np,
+            page_table,
+            self.page_size,
+        )
+        local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+            local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
+            local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
+            local_block_table=block_table_local.to(device),
+            local_max_query_len=int(seqlens_q_local_np.max()),
+            local_max_seq_len=int(seqlens_k_local_np.max()),
+        )
+        metadata.local_attn_metadata = local_metadata
+

 class FlashAttentionMultiStepBackend:
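What the new _init_local_attn_metadata path represents, conceptually: each request's key range is cut into attention_chunk_size pieces, and every piece becomes its own "virtual batch" whose queries only attend inside that chunk. The sketch below is a toy illustration of that splitting only; it is not make_local_attention_virtual_batches, which additionally builds the local query offsets and a page-aligned block table.

```python
import numpy as np

def toy_local_chunks(seq_lens: np.ndarray, chunk_size: int) -> list:
    """Per request, split the key length into chunks of at most chunk_size;
    each chunk corresponds to one virtual batch in chunked local attention."""
    out = []
    for seq_len in seq_lens.tolist():
        n_chunks = -(-seq_len // chunk_size)  # ceil division
        lens = np.full(n_chunks, chunk_size, dtype=np.int32)
        lens[-1] = seq_len - chunk_size * (n_chunks - 1)
        out.append(lens)
    return out

# Requests of 300 and 100 tokens with a 128-token chunk:
# [128, 128, 44] and [100] -> five virtual batches in total.
print(toy_local_chunks(np.array([300, 100]), 128))
```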
 
sglang/srt/layers/attention/flashinfer_backend.py

@@ -100,8 +100,11 @@ class FlashInferAttnBackend(AttentionBackend):
         self.num_wrappers = 1
         self.dispatch_reason = None

-        # Qwen2 models require higher flashinfer workspace size
-        if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
+        # Qwen2/Qwen3 models require higher flashinfer workspace size
+        if (
+            "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures
+            or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
+        ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

         # Allocate buffers
sglang/srt/layers/attention/torch_native_backend.py

@@ -6,6 +6,7 @@ import torch
 from torch.nn.functional import scaled_dot_product_attention

 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
@@ -202,6 +203,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
         o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self._run_sdpa_forward_extend(
             q_,
             o_,
@@ -214,7 +219,7 @@ class TorchNativeAttnBackend(AttentionBackend):
             forward_batch.extend_seq_lens,
             scaling=layer.scaling,
             enable_gqa=use_gqa,
-            causal=not layer.is_cross_attention,
+            causal=causal,
         )
         return o
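The causal flag now distinguishes decoder self-attention (causal) from cross-attention and encoder-only layers (both non-causal), presumably in support of encoder-only models such as the newly added bert.py. A minimal sketch of the underlying torch primitive and what the flag changes:

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

q = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

causal_out = scaled_dot_product_attention(q, k, v, is_causal=True)           # decoder-style
bidirectional_out = scaled_dot_product_attention(q, k, v, is_causal=False)   # encoder-style

# Every query position except the last sees fewer keys in the causal case,
# so the two outputs differ in general.
print(torch.allclose(causal_out, bidirectional_out))  # False in general
```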
 
sglang/srt/layers/attention/triton_backend.py

@@ -10,6 +10,7 @@ import triton.language as tl
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import get_bool_env_var, get_device_core_count

@@ -528,6 +529,10 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

+        causal = True
+        if layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -539,6 +544,7 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.kv_indptr,
             self.forward_metadata.kv_indices,
             self.forward_metadata.custom_mask,
+            causal,
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -74,6 +74,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     USE_CUSTOM_MASK: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
 ):
@@ -129,6 +130,7 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
+
         offs_kv_loc = tl.load(
             kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
@@ -196,7 +198,11 @@ def _fwd_kernel(

     # stage 2: compute the triangle part

-    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    cur_block_m_end = (
+        cur_seq_len_extend
+        if not IS_CAUSAL
+        else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    )
     for start_n in range(0, cur_block_m_end, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_block_m_end
@@ -243,12 +249,15 @@ def _fwd_kernel(
             )
             custom_mask &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(custom_mask, qk, float("-inf"))
-        else:
+        elif IS_CAUSAL:
             mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                 start_n + offs_n[None, :]
             )
             mask_causual &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(mask_causual, qk, float("-inf"))
+        else:
+            mask_non_causal = mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_non_causal, qk, float("-inf"))

         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
@@ -299,6 +308,7 @@ def extend_attention_fwd(
     kv_indptr,
     kv_indices,
     custom_mask,
+    is_causal,
     mask_indptr,
     max_len_extend,
     sm_scale=None,
@@ -411,6 +421,7 @@ def extend_attention_fwd(
         Lq=Lq,
         Lv=Lv,
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
+        IS_CAUSAL=is_causal,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
         STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
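Passing IS_CAUSAL as a tl.constexpr lets Triton specialize the kernel at compile time; the flag widens the stage-2 loop bound to the whole extend length and switches the score mask from lower-triangular to bounds-only. A small numpy sketch of the two masks over one (BLOCK_M, BLOCK_N) tile (illustrative, not the kernel itself):

```python
import numpy as np

BLOCK_M, BLOCK_N = 4, 4
cur_block_m = 1              # second query block -> query positions 4..7
start_n = 4                  # key block starting at key position 4
offs_m = np.arange(BLOCK_M)
offs_n = np.arange(BLOCK_N)

# IS_CAUSAL: a query position may only attend to keys at or before itself.
causal_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (start_n + offs_n[None, :])
# Not IS_CAUSAL (ENCODER_ONLY): every in-bounds query/key pair is allowed.
non_causal_mask = np.ones((BLOCK_M, BLOCK_N), dtype=bool)

print(causal_mask.astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]
print(non_causal_mask.astype(int))  # all ones
```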
sglang/srt/layers/layernorm.py

@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn

+from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available

 _is_cuda = is_cuda_available()
@@ -31,7 +32,6 @@ if _is_cuda:
         rmsnorm,
     )

-    from sglang.srt.custom_op import CustomOp

 logger = logging.getLogger(__name__)

sglang/srt/layers/linear.py

@@ -1,5 +1,6 @@
 """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""

+import itertools
 import logging
 from abc import abstractmethod
 from typing import Dict, List, Optional, Tuple
@@ -61,12 +62,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):


 def adjust_bitsandbytes_4bit_shard(
-    param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
+    param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
 ) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""

-    total, _ = qkv_offsets["total"]
-    orig_offset, orig_size = qkv_offsets[loaded_shard_id]
+    total, _ = shard_offsets["total"]
+    orig_offset, orig_size = shard_offsets[loaded_shard_id]

     quantized_total = param.data.shape[0]
     quantized_offset = orig_offset * quantized_total // total
@@ -573,6 +574,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offsets.append((i, current_shard_offset, output_size))
                 current_shard_offset += output_size
             packed_dim = getattr(param, "packed_dim", None)
+
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantization.
                 # If quantized, we need to adjust the offset and size to account
@@ -585,6 +588,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                         param, shard_size, shard_offset
                     )

+                if use_bitsandbytes_4bit:
+                    index = list(itertools.accumulate([0] + self.output_sizes))
+                    orig_offsets = {
+                        str(i): (index[i], size)
+                        for i, size in enumerate(self.output_sizes)
+                    }
+                    orig_offsets["total"] = (self.output_size, 0)
+                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                        param, orig_offsets, str(shard_id)
+                    )
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size
                 )
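The rename from qkv_offsets to shard_offsets reflects that the helper is now reused for merged gate/up projections, not just fused QKV. The new block builds an offsets dict keyed by shard id plus a "total" entry, and the helper rescales those offsets into the packed 4-bit parameter. A worked example of that arithmetic (sizes are made up; the hunk only shows the offset rescaling line, so the size line below is an assumption about the rest of the helper):

```python
import itertools

# Two merged projections (e.g. gate_proj and up_proj), 11008 rows each.
output_sizes = [11008, 11008]
index = list(itertools.accumulate([0] + output_sizes))        # [0, 11008, 22016]
orig_offsets = {str(i): (index[i], size) for i, size in enumerate(output_sizes)}
orig_offsets["total"] = (sum(output_sizes), 0)

# Assume the packed 4-bit parameter stores 11008 rows for the 22016 logical rows.
quantized_total = 11008
total, _ = orig_offsets["total"]
for shard_id in ("0", "1"):
    orig_offset, orig_size = orig_offsets[shard_id]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total
    print(shard_id, quantized_offset, quantized_size)  # "0": 0, 5504; "1": 5504, 5504
```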
sglang/srt/layers/moe/ep_moe/layer.py

@@ -2,6 +2,7 @@ import logging
 from typing import Callable, List, Optional, Tuple

 import torch
+from torch.nn import Module

 try:
     from deep_gemm import (
@@ -13,8 +14,6 @@ try:
 except ImportError:
     use_deep_gemm = False

-from torch.nn import Module
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -37,22 +36,17 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs

-_is_cuda = is_cuda()
+_is_hip = is_hip()

-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+if _is_hip:
+    from vllm._custom_ops import scaled_fp8_quant

 logger = logging.getLogger(__name__)

-_is_hip = is_hip()
-
-_buffer = None
-

 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -142,6 +136,7 @@ class EPMoE(torch.nn.Module):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
     ):
         super().__init__()

@@ -170,6 +165,7 @@ class EPMoE(torch.nn.Module):
         self.correction_bias = correction_bias
         self.custom_routing_function = custom_routing_function
         self.activation = activation
+        self.routed_scaling_factor = routed_scaling_factor

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -221,6 +217,7 @@ class EPMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
+            routed_scaling_factor=self.routed_scaling_factor,
         )

         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -740,20 +737,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
             )

             for expert in range(layer.num_experts_per_partition):
-                if _is_cuda:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
-                else:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
             return
@@ -986,9 +975,6 @@ class DeepEPMoE(EPMoE):
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
-        assert (
-            hidden_states_fp8[0].size(0) % 4 == 0
-        ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"

         # GroupGemm-0
         num_groups, m, k = hidden_states_fp8[0].size()
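Several call sites that previously branched between the sgl-kernel op and vLLM's fallback now import a single scaled_fp8_quant from sglang.srt.layers.quantization.fp8_kernel. As a reference for what the dynamic per-tensor form of such an op returns, here is a pure-torch illustration of the semantics (not the fused kernel; the real op also supports static scales and use_per_token_if_dynamic, as seen at the call sites):

```python
import torch

def ref_scaled_fp8_quant(x: torch.Tensor):
    """Quantize x to float8_e4m3fn with one dynamic per-tensor scale so that
    x is approximately q.float() * scale."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = x.abs().max().clamp(min=1e-12) / finfo.max
    q = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale.float()

w = torch.randn(256, 512)
q, s = ref_scaled_fp8_quant(w)
print(q.dtype, s.item())                       # torch.float8_e4m3fn, per-tensor scale
print((q.float() * s - w).abs().max().item())  # small quantization error
```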
sglang/srt/layers/moe/fused_moe_native.py

@@ -26,6 +26,7 @@ def fused_moe_forward_native(
     apply_router_weight_on_input: bool = False,
     inplace: bool = True,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:

     if apply_router_weight_on_input:
@@ -41,6 +42,7 @@ def fused_moe_forward_native(
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
+        routed_scaling_factor=routed_scaling_factor,
         torch_native=True,
     )

@@ -71,6 +73,7 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:

     from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
@@ -86,6 +89,7 @@ def moe_forward_native(
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
         torch_native=True,
+        routed_scaling_factor=routed_scaling_factor,
     )

     # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
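routed_scaling_factor is now threaded from the MoE entry points down into select_experts; this diff does not show where the factor is applied (that lives in sglang/srt/layers/moe/topk.py). As an assumption, it is the DeepSeek-style rescaling of the selected router weights, roughly:

```python
from typing import Optional

import torch

def apply_routed_scaling(
    topk_weights: torch.Tensor,
    routed_scaling_factor: Optional[float],
    renormalize: bool = True,
) -> torch.Tensor:
    # Illustrative only: renormalize the selected expert weights per token,
    # then scale them by a model-wide routed_scaling_factor.
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    if routed_scaling_factor is not None:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights

w = torch.tensor([[0.2, 0.1], [0.3, 0.3]])
print(apply_routed_scaling(w, routed_scaling_factor=2.5))
# tensor([[1.6667, 0.8333],
#         [1.2500, 1.2500]])
```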
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -13,6 +13,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -22,28 +23,25 @@ from sglang.srt.utils import (
 )

 _is_hip = is_hip()
-
-
-logger = logging.getLogger(__name__)
-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
-
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)
-
 _is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
-
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
 else:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant

 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size


+logger = logging.getLogger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+enable_moe_align_block_size_triton = bool(
+    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
+)
+
+
 @triton.jit
 def write_zeros_to_output(
     c_ptr,
@@ -770,14 +768,9 @@ def invoke_fused_moe_kernel(
             # activation tensor-wise fp8 quantization, dynamic or static
             padded_size = padding_size
             # activations apply per-token quantization when weights apply per-channel quantization by default
-            if _is_cuda:
-                A, A_scale = sgl_scaled_fp8_quant(
-                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
-                )
-            else:
-                A, A_scale = vllm_ops.scaled_fp8_quant(
-                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
-                )
+            A, A_scale = scaled_fp8_quant(
+                A, A_scale, use_per_token_if_dynamic=per_channel_quant
+            )
         else:
             # activation block-wise fp8 quantization
             assert len(block_shape) == 2
@@ -1554,6 +1547,7 @@ def fused_moe(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1608,6 +1602,7 @@ def fused_moe(
         topk_group=topk_group,
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
+        routed_scaling_factor=routed_scaling_factor,
     )

     return fused_experts(