sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -82,6 +82,8 @@ class FlashInferAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype

         assert not (
             model_runner.sliding_window_size is not None
@@ -98,8 +100,11 @@ class FlashInferAttnBackend(AttentionBackend):
         self.num_wrappers = 1
         self.dispatch_reason = None

-        # Qwen2 models require higher flashinfer workspace size
-        if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
+        # Qwen2/Qwen3 models require higher flashinfer workspace size
+        if (
+            "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures
+            or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
+        ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

         # Allocate buffers
@@ -391,6 +396,8 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -407,7 +414,7 @@ class FlashInferAttnBackend(AttentionBackend):
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v,
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )

             o = prefill_wrapper_paged.forward(
@@ -417,8 +424,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=
-                v_scale=
+                k_scale=k_scale,
+                v_scale=v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -445,7 +452,7 @@ class FlashInferAttnBackend(AttentionBackend):

         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, cache_loc, k, v,
+                layer, cache_loc, k, v, k_scale, v_scale
             )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -459,6 +466,8 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -472,7 +481,7 @@ class FlashInferAttnBackend(AttentionBackend):
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v,
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )

         o = decode_wrapper.forward(
@@ -480,8 +489,8 @@ class FlashInferAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=
-            v_scale=
+            k_scale=k_scale,
+            v_scale=v_scale,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
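The flashinfer hunks above thread per-layer K/V scales through to the kernels only when the server is configured with a quantized KV cache (`kv_cache_dtype` other than `auto`). A minimal standalone sketch of that gating rule; the helper name and the example dtype strings are illustrative, not sglang API:

```python
# Sketch of the scale-gating pattern from the diff above: per-layer K/V scales
# are forwarded only when the KV cache is quantized; otherwise None is passed
# and the attention kernels skip dequantization.
from typing import Optional, Tuple


def resolve_kv_scales(
    kv_cache_dtype_str: str,
    k_scale_float: Optional[float],
    v_scale_float: Optional[float],
) -> Tuple[Optional[float], Optional[float]]:
    if kv_cache_dtype_str == "auto":
        return None, None  # unquantized cache: no scaling
    return k_scale_float, v_scale_float


print(resolve_kv_scales("fp8_e4m3", 0.023, 0.017))  # (0.023, 0.017)
print(resolve_kv_scales("auto", 0.023, 0.017))      # (None, None)
```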
sglang/srt/layers/attention/torch_native_backend.py
CHANGED
@@ -6,6 +6,7 @@ import torch
 from torch.nn.functional import scaled_dot_product_attention

 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
@@ -202,6 +203,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
         o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self._run_sdpa_forward_extend(
             q_,
             o_,
@@ -214,7 +219,7 @@ class TorchNativeAttnBackend(AttentionBackend):
             forward_batch.extend_seq_lens,
             scaling=layer.scaling,
             enable_gqa=use_gqa,
-            causal=
+            causal=causal,
         )
         return o
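The torch-native backend change above turns off the causal mask for cross-attention and encoder-only layers before dispatching to PyTorch SDPA. A small self-contained illustration of the same rule with toy tensors; this is not the backend's actual code path:

```python
# Decoder self-attention keeps a causal mask; encoder-only (e.g. BERT-style)
# and cross-attention run bidirectionally with is_causal=False.
import torch
from torch.nn.functional import scaled_dot_product_attention


def attend(q, k, v, is_cross_attention=False, encoder_only=False):
    causal = not (is_cross_attention or encoder_only)
    return scaled_dot_product_attention(q, k, v, is_causal=causal)


q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
decoder_out = attend(q, k, v)                     # causal
encoder_out = attend(q, k, v, encoder_only=True)  # non-causal
```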
sglang/srt/layers/attention/triton_backend.py
CHANGED
@@ -10,6 +10,7 @@ import triton.language as tl
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import get_bool_env_var, get_device_core_count

@@ -528,6 +529,10 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

+        causal = True
+        if layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -539,6 +544,7 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.kv_indptr,
             self.forward_metadata.kv_indices,
             self.forward_metadata.custom_mask,
+            causal,
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
sglang/srt/layers/attention/triton_ops/extend_attention.py
CHANGED
@@ -74,6 +74,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     USE_CUSTOM_MASK: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
 ):
@@ -129,6 +130,7 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
+
         offs_kv_loc = tl.load(
             kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
@@ -196,7 +198,11 @@ def _fwd_kernel(

     # stage 2: compute the triangle part

-    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    cur_block_m_end = (
+        cur_seq_len_extend
+        if not IS_CAUSAL
+        else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    )
     for start_n in range(0, cur_block_m_end, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_block_m_end
@@ -243,12 +249,15 @@ def _fwd_kernel(
             )
             custom_mask &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(custom_mask, qk, float("-inf"))
-        else:
+        elif IS_CAUSAL:
             mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                 start_n + offs_n[None, :]
             )
             mask_causual &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(mask_causual, qk, float("-inf"))
+        else:
+            mask_non_causal = mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_non_causal, qk, float("-inf"))

         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
@@ -299,6 +308,7 @@ def extend_attention_fwd(
     kv_indptr,
     kv_indices,
     custom_mask,
+    is_causal,
     mask_indptr,
     max_len_extend,
     sm_scale=None,
@@ -411,6 +421,7 @@ def extend_attention_fwd(
         Lq=Lq,
         Lv=Lv,
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
+        IS_CAUSAL=is_causal,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
         STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
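For reference, the masking behavior selected by the new `IS_CAUSAL` constexpr in the extend kernel can be expressed in plain PyTorch as below. Block sizes and the row/column validity masks are toy stand-ins; this is a readability aid, not the Triton kernel:

```python
# Priority of masks in the triangle stage: custom mask (not shown here) >
# causal mask > plain padding mask when IS_CAUSAL is False.
import torch

BLOCK_M, BLOCK_N = 4, 4
cur_block_m, start_n = 1, 0
offs_m = torch.arange(BLOCK_M)[:, None]
offs_n = torch.arange(BLOCK_N)[None, :]
mask_m = torch.ones(BLOCK_M, 1, dtype=torch.bool)  # stand-in row validity
mask_n = torch.ones(1, BLOCK_N, dtype=torch.bool)  # stand-in column validity

is_causal = True
if is_causal:
    mask = ((cur_block_m * BLOCK_M + offs_m) >= (start_n + offs_n)) & mask_m & mask_n
else:
    mask = mask_m & mask_n  # non-causal: only invalid rows/columns are masked

qk = torch.zeros(BLOCK_M, BLOCK_N)
qk = torch.where(mask, qk, torch.tensor(float("-inf")))
print(qk)
```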
sglang/srt/layers/dp_attention.py
CHANGED
@@ -192,8 +192,7 @@ def _dp_gather(

     if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
         assert (
-            global_tokens.untyped_storage().data_ptr()
-            != local_tokens.untyped_storage().data_ptr()
+            local_tokens.untyped_storage() is not global_tokens.untyped_storage()
         ), "aliasing between global_tokens and local_tokens not allowed"
         memcpy_triton(
             global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
@@ -243,8 +242,7 @@ def dp_scatter(
     assert global_tokens.is_contiguous()
     if local_tokens.shape[0] > 0:
         assert (
-            local_tokens.untyped_storage().data_ptr()
-            != global_tokens.untyped_storage().data_ptr()
+            local_tokens.untyped_storage() is not global_tokens.untyped_storage()
         ), "aliasing between local_tokens and global_tokens not allowed"
         memcpy_triton(
             local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
sglang/srt/layers/elementwise.py
CHANGED
@@ -4,6 +4,10 @@ import torch
 import triton
 import triton.language as tl

+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
+
 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -185,6 +189,9 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
+
+    min_num_warps = 16 if _is_hip else 32
+
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
@@ -193,7 +200,10 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+                min(
+                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
+                ),
+                4,
             ),
         }

@@ -250,10 +260,13 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
        ),
     }
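The elementwise changes above make the warp-count cap platform dependent: at most 16 warps on HIP and 32 otherwise, with a floor of 4. A pure-Python sketch of the resulting heuristic, with `next_power_of_2` reimplemented locally so the snippet does not need Triton installed:

```python
def next_power_of_2(n: int) -> int:
    # smallest power of two >= n (matches triton.next_power_of_2 for n >= 1)
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def pick_num_warps(hidden_dim: int, is_hip: bool) -> int:
    min_num_warps = 16 if is_hip else 32  # platform-dependent cap
    block_warps = next_power_of_2(-(-hidden_dim // 256))  # ceil(hidden_dim / 256)
    return max(min(block_warps, min_num_warps), 4)


print(pick_num_warps(4096, is_hip=False))   # 16
print(pick_num_warps(8192, is_hip=True))    # 16 (capped on HIP)
print(pick_num_warps(16384, is_hip=False))  # 32 (capped on CUDA)
```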
sglang/srt/layers/layernorm.py
CHANGED
@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn

+from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available

 _is_cuda = is_cuda_available()
@@ -31,7 +32,6 @@ if _is_cuda:
         rmsnorm,
     )

-    from sglang.srt.custom_op import CustomOp

 logger = logging.getLogger(__name__)
sglang/srt/layers/linear.py
CHANGED
@@ -1,5 +1,6 @@
 """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""

+import itertools
 import logging
 from abc import abstractmethod
 from typing import Dict, List, Optional, Tuple
@@ -47,6 +48,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "GPTQLinearMethod",
     "FBGEMMFp8LinearMethod",
     "ModelOptFp8LinearMethod",
+    "ModelOptFp4LinearMethod",
     "IPEXAWQLinearMethod",
 ]

@@ -60,12 +62,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):


 def adjust_bitsandbytes_4bit_shard(
-    param: Parameter,
+    param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
 ) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""

-    total, _ =
-    orig_offset, orig_size =
+    total, _ = shard_offsets["total"]
+    orig_offset, orig_size = shard_offsets[loaded_shard_id]

     quantized_total = param.data.shape[0]
     quantized_offset = orig_offset * quantized_total // total
@@ -572,6 +574,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             shard_offsets.append((i, current_shard_offset, output_size))
             current_shard_offset += output_size
         packed_dim = getattr(param, "packed_dim", None)
+
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
         for shard_id, shard_offset, shard_size in shard_offsets:
             # Special case for Quantization.
             # If quantized, we need to adjust the offset and size to account
@@ -584,6 +588,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     param, shard_size, shard_offset
                 )

+            if use_bitsandbytes_4bit:
+                index = list(itertools.accumulate([0] + self.output_sizes))
+                orig_offsets = {
+                    str(i): (index[i], size)
+                    for i, size in enumerate(self.output_sizes)
+                }
+                orig_offsets["total"] = (self.output_size, 0)
+                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                    param, orig_offsets, str(shard_id)
+                )
+
             loaded_weight_shard = loaded_weight.narrow(
                 output_dim, shard_offset, shard_size
             )
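The `MergedColumnParallelLinear` change above routes BitsAndBytes 4-bit checkpoints through `adjust_bitsandbytes_4bit_shard`, which rescales shard offsets given in unquantized output units into the packed parameter's row space. A worked toy example of that arithmetic; the helper below is a standalone re-derivation with `quantized_total` standing in for `param.data.shape[0]`, not the library function itself:

```python
from typing import Dict, Tuple


def adjust_4bit_shard(
    quantized_total: int,  # rows of the packed 4-bit parameter
    shard_offsets: Dict[str, Tuple[int, int]],
    loaded_shard_id: str,
) -> Tuple[int, int]:
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]
    # Rescale from unquantized output rows to packed rows.
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total
    return quantized_size, quantized_offset


# Two merged projections of 11008 rows each, packed two 4-bit values per byte,
# so the quantized parameter has 11008 rows in total.
offsets = {"0": (0, 11008), "1": (11008, 11008), "total": (22016, 0)}
print(adjust_4bit_shard(11008, offsets, "1"))  # (5504, 5504)
```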
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -2,6 +2,7 @@ import logging
 from typing import Callable, List, Optional, Tuple

 import torch
+from torch.nn import Module

 try:
     from deep_gemm import (
@@ -13,8 +14,6 @@ try:
 except ImportError:
     use_deep_gemm = False

-from torch.nn import Module
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -37,22 +36,17 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode,
+from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs

-
+_is_hip = is_hip()

-if
-    from
-else:
-    from vllm import _custom_ops as vllm_ops
+if _is_hip:
+    from vllm._custom_ops import scaled_fp8_quant

 logger = logging.getLogger(__name__)

-_is_hip = is_hip()
-
-_buffer = None
-

 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -142,6 +136,7 @@ class EPMoE(torch.nn.Module):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
     ):
         super().__init__()

@@ -170,6 +165,7 @@ class EPMoE(torch.nn.Module):
         self.correction_bias = correction_bias
         self.custom_routing_function = custom_routing_function
         self.activation = activation
+        self.routed_scaling_factor = routed_scaling_factor

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -221,6 +217,7 @@ class EPMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
+            routed_scaling_factor=self.routed_scaling_factor,
         )

         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -740,20 +737,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         )

         for expert in range(layer.num_experts_per_partition):
-
-                w13_weight[expert, :, :]
-
-
-                w2_weight[expert, :, :]
-
-                )
-            else:
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
+            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+            )
+            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+            )
         layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
         layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
         return
@@ -986,9 +975,6 @@ class DeepEPMoE(EPMoE):
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
-        assert (
-            hidden_states_fp8[0].size(0) % 4 == 0
-        ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"

         # GroupGemm-0
         num_groups, m, k = hidden_states_fp8[0].size()