sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +3 -6
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +6 -2
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +105 -6
- sglang/srt/disaggregation/mini_lb.py +74 -9
- sglang/srt/disaggregation/mooncake/conn.py +33 -63
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +137 -17
- sglang/srt/disaggregation/utils.py +32 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +883 -209
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +20 -5
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +17 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -88
- sglang/srt/layers/quantization/fp8_utils.py +189 -264
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +9 -8
- sglang/srt/layers/sampler.py +7 -12
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +15 -4
- sglang/srt/managers/scheduler.py +28 -77
- sglang/srt/managers/tokenizer_manager.py +116 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +41 -29
- sglang/srt/mem_cache/memory_pool.py +38 -15
- sglang/srt/model_executor/cuda_graph_runner.py +15 -10
- sglang/srt/model_executor/model_runner.py +39 -31
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +292 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +31 -203
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +86 -72
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +6 -14
- sglang/srt/utils.py +62 -6
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -100,8 +100,11 @@ class FlashInferAttnBackend(AttentionBackend):
         self.num_wrappers = 1
         self.dispatch_reason = None

-        # Qwen2 models require higher flashinfer workspace size
-        if
+        # Qwen2/Qwen3 models require higher flashinfer workspace size
+        if (
+            "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures
+            or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
+        ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

         # Allocate buffers
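The workspace bump keys off the `architectures` list of the Hugging Face config. A minimal sketch of the same check outside sglang, assuming a plain `transformers` config object; the 384 MiB default and the helper name are hypothetical, only the 512 MiB Qwen2/Qwen3 value comes from the hunk above:

from transformers import AutoConfig

_DEFAULT_WORKSPACE = 384 * 1024 * 1024  # hypothetical default, not taken from this diff

def pick_flashinfer_workspace_size(model_path: str) -> int:
    archs = getattr(AutoConfig.from_pretrained(model_path), "architectures", None) or []
    # Qwen2/Qwen3 decoder models need the larger flashinfer workspace.
    if "Qwen2ForCausalLM" in archs or "Qwen3ForCausalLM" in archs:
        return 512 * 1024 * 1024
    return _DEFAULT_WORKSPACE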
sglang/srt/layers/attention/torch_native_backend.py
CHANGED
@@ -6,6 +6,7 @@ import torch
 from torch.nn.functional import scaled_dot_product_attention

 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
@@ -202,6 +203,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
         o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self._run_sdpa_forward_extend(
             q_,
             o_,
@@ -214,7 +219,7 @@ class TorchNativeAttnBackend(AttentionBackend):
             forward_batch.extend_seq_lens,
             scaling=layer.scaling,
             enable_gqa=use_gqa,
-            causal=
+            causal=causal,
         )
         return o

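In the torch-native backend the flag ends up in `torch.nn.functional.scaled_dot_product_attention`. A minimal sketch of the decision, with plain booleans standing in for `layer.is_cross_attention` and `layer.attn_type == AttentionType.ENCODER_ONLY`:

import torch
from torch.nn.functional import scaled_dot_product_attention

def sdpa_extend(q, k, v, is_cross_attention: bool, is_encoder_only: bool, scaling: float):
    # Cross-attention and encoder-only (bidirectional) layers must not get a causal mask.
    causal = not (is_cross_attention or is_encoder_only)
    return scaled_dot_product_attention(q, k, v, is_causal=causal, scale=scaling)

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)
out = sdpa_extend(q, k, v, is_cross_attention=False, is_encoder_only=True, scaling=64**-0.5)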
sglang/srt/layers/attention/triton_backend.py
CHANGED
@@ -10,6 +10,7 @@ import triton.language as tl
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import get_bool_env_var, get_device_core_count

@@ -528,6 +529,10 @@ class TritonAttnBackend(AttentionBackend):
             layer, forward_batch.out_cache_loc, k, v
         )

+        causal = True
+        if layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -539,6 +544,7 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.kv_indptr,
             self.forward_metadata.kv_indices,
             self.forward_metadata.custom_mask,
+            causal,
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
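Both backends branch on the `AttentionType` now exported from `sglang.srt.layers.radix_attention` (+15 lines in the file list); its definition is not part of the hunks shown here. A plausible minimal shape, written only as an assumption so the branches above read clearly, and presumably what the new encoder-only `sglang/srt/models/bert.py` relies on:

from enum import Enum

class AttentionType(Enum):
    # Decoder-style self-attention: a token attends only to itself and earlier tokens.
    DECODER = "decoder"
    # Encoder-only attention (e.g. BERT): fully bidirectional, so no causal mask is applied.
    ENCODER_ONLY = "encoder_only"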
sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py
CHANGED
@@ -3,10 +3,10 @@ import triton
 import triton.language as tl

 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-
-if
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()
@@ -1037,12 +1037,12 @@ def extend_attention_fwd(
         num_warps = 4

     else:
-        if
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             if Lq <= 128:
                 BLOCK_M, BLOCK_N = (128, 128)
             elif Lq <= 256:
sglang/srt/layers/attention/triton_ops/extend_attention.py
CHANGED
@@ -23,10 +23,10 @@ import triton.language as tl
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-
-if
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()
@@ -74,6 +74,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     USE_CUSTOM_MASK: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
 ):
@@ -129,6 +130,7 @@ def _fwd_kernel(
    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_seq_len_prefix
+
        offs_kv_loc = tl.load(
            kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
        )
@@ -196,7 +198,11 @@

    # stage 2: compute the triangle part

-    cur_block_m_end =
+    cur_block_m_end = (
+        cur_seq_len_extend
+        if not IS_CAUSAL
+        else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    )
    for start_n in range(0, cur_block_m_end, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_block_m_end
@@ -243,12 +249,15 @@
            )
            custom_mask &= mask_m[:, None] & mask_n[None, :]
            qk = tl.where(custom_mask, qk, float("-inf"))
-
+        elif IS_CAUSAL:
            mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                start_n + offs_n[None, :]
            )
            mask_causual &= mask_m[:, None] & mask_n[None, :]
            qk = tl.where(mask_causual, qk, float("-inf"))
+        else:
+            mask_non_causal = mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_non_causal, qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
@@ -299,6 +308,7 @@ def extend_attention_fwd(
    kv_indptr,
    kv_indices,
    custom_mask,
+    is_causal,
    mask_indptr,
    max_len_extend,
    sm_scale=None,
@@ -335,12 +345,12 @@
        num_warps = 4

    else:
-        if
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
            if Lq <= 256:
                BLOCK_M, BLOCK_N = (128, 64)
            else:
                BLOCK_M, BLOCK_N = (32, 64)
-        elif
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
            # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
            if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                if Lq <= 128:
@@ -411,6 +421,7 @@ def extend_attention_fwd(
        Lq=Lq,
        Lv=Lv,
        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
+        IS_CAUSAL=is_causal,
        SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
        STORE_TRANSPOSE=_is_hip,
        num_warps=num_warps,
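The new IS_CAUSAL constexpr only changes which mask stage 2 applies to the extend ("triangle") block: the custom mask when one is supplied, a lower-triangular mask when causal, and a plain bounds mask otherwise. A plain-PyTorch reference of those three cases for a single square extend block, as a sketch of the intended semantics rather than of the Triton kernel:

from typing import Optional
import torch

def stage2_masked_scores(
    scores: torch.Tensor, is_causal: bool, custom_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # scores: (q_len, kv_len) attention logits for the extend part of one sequence.
    q_len, kv_len = scores.shape
    if custom_mask is not None:
        keep = custom_mask
    elif is_causal:
        # query position i may attend to key position j only when i >= j
        keep = torch.arange(q_len)[:, None] >= torch.arange(kv_len)[None, :]
    else:
        # non-causal (encoder-only): every in-bounds position is kept
        keep = torch.ones(q_len, kv_len, dtype=torch.bool)
    return scores.masked_fill(~keep, float("-inf"))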
sglang/srt/layers/attention/triton_ops/prefill_attention.py
CHANGED
@@ -22,8 +22,12 @@ import torch
 import triton
 import triton.language as tl

-
-
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()


@@ -172,7 +176,7 @@ def context_attention_fwd(
     b_seq_len: [b]
     out: [b * s, head, head_dim]
     """
-    if
+    if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
         BLOCK = 64
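Across these Triton files, torch.cuda.get_device_capability() is now only queried behind is_cuda() / is_hip() guards so that importing the modules no longer fails on machines without a GPU runtime. A sketch of the guarded pattern; the helper bodies below are stand-ins for sglang.srt.utils.is_cuda / is_hip, while the 128/64 block sizes follow the hunk above:

import torch

def _is_cuda() -> bool:
    return torch.cuda.is_available() and torch.version.cuda is not None

def _is_hip() -> bool:
    return torch.cuda.is_available() and torch.version.hip is not None

# Only touch the device when a CUDA or ROCm runtime is actually present.
CUDA_CAPABILITY = (
    torch.cuda.get_device_capability() if (_is_cuda() or _is_hip()) else None
)

def pick_prefill_block() -> int:
    # sm90+ (and comparable ROCm parts) get the larger tile; everything else uses 64.
    if CUDA_CAPABILITY is not None and CUDA_CAPABILITY[0] > 8:
        return 128
    return 64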
sglang/srt/layers/layernorm.py
CHANGED
@@ -19,9 +19,13 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn

-from sglang.srt.
+from sglang.srt.custom_op import CustomOp
+from sglang.srt.utils import is_cuda, is_hip
+
+logger = logging.getLogger(__name__)

-_is_cuda =
+_is_cuda = is_cuda()
+_is_hip = is_hip()

 if _is_cuda:
     from sgl_kernel import (
@@ -31,9 +35,20 @@ if _is_cuda:
         rmsnorm,
     )

-
+if _is_hip:

-
+    from aiter.ops.rmsnorm import rms_norm, rmsnorm2d_fwd_with_add
+
+    rmsnorm = rms_norm
+
+    def fused_add_rmsnorm(
+        x: torch.Tensor,
+        residual: torch.Tensor,
+        w: torch.Tensor,
+        eps: float,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps)
+        return x, residual


 class RMSNorm(CustomOp):
@@ -139,7 +154,7 @@ class Gemma3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"


-if not _is_cuda:
+if not (_is_cuda or _is_hip):
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
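On ROCm the aiter kernels are adapted to the names sglang expects (rmsnorm, fused_add_rmsnorm). The fused op adds the residual in, RMS-normalizes the sum, and returns both the normalized activation and the updated residual; a plain-PyTorch reference of what the wrapper above is expected to compute, for checking only:

from typing import Tuple
import torch

def fused_add_rmsnorm_ref(
    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    # The "add" half: the pre-norm residual stream absorbs x first.
    residual = residual + x
    # RMSNorm over the hidden dimension, then scale by the learned weight.
    variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (residual.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return normed, residual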
sglang/srt/layers/linear.py
CHANGED
@@ -1,5 +1,6 @@
 """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""

+import itertools
 import logging
 from abc import abstractmethod
 from typing import Dict, List, Optional, Tuple
@@ -61,12 +62,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):


 def adjust_bitsandbytes_4bit_shard(
-    param: Parameter,
+    param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
 ) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""

-    total, _ =
-    orig_offset, orig_size =
+    total, _ = shard_offsets["total"]
+    orig_offset, orig_size = shard_offsets[loaded_shard_id]

     quantized_total = param.data.shape[0]
     quantized_offset = orig_offset * quantized_total // total
@@ -573,6 +574,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             shard_offsets.append((i, current_shard_offset, output_size))
             current_shard_offset += output_size
         packed_dim = getattr(param, "packed_dim", None)
+
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
         for shard_id, shard_offset, shard_size in shard_offsets:
             # Special case for Quantization.
             # If quantized, we need to adjust the offset and size to account
@@ -585,6 +588,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     param, shard_size, shard_offset
                 )

+            if use_bitsandbytes_4bit:
+                index = list(itertools.accumulate([0] + self.output_sizes))
+                orig_offsets = {
+                    str(i): (index[i], size)
+                    for i, size in enumerate(self.output_sizes)
+                }
+                orig_offsets["total"] = (self.output_size, 0)
+                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                    param, orig_offsets, str(shard_id)
+                )
+
             loaded_weight_shard = loaded_weight.narrow(
                 output_dim, shard_offset, shard_size
             )
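The new MergedColumnParallelLinear branch builds per-shard (offset, size) entries in the original (unquantized) coordinate system and lets adjust_bitsandbytes_4bit_shard rescale them onto the packed 4-bit parameter. A small worked example of that rescaling with made-up sizes (two merged 4096-wide projections packed into 2048 rows); the real function reads the packed row count from param.data.shape[0]:

import itertools

output_sizes = [4096, 4096]   # hypothetical merged projection widths
quantized_total = 2048        # hypothetical packed 4-bit row count (param.data.shape[0])

index = list(itertools.accumulate([0] + output_sizes))            # [0, 4096, 8192]
orig_offsets = {str(i): (index[i], size) for i, size in enumerate(output_sizes)}
orig_offsets["total"] = (sum(output_sizes), 0)

def adjust(loaded_shard_id: str):
    total, _ = orig_offsets["total"]
    orig_offset, orig_size = orig_offsets[loaded_shard_id]
    # Scale both offset and size from unquantized rows into packed rows.
    return orig_size * quantized_total // total, orig_offset * quantized_total // total

print(adjust("0"))  # (1024, 0)    -> shard 0 occupies packed rows [0, 1024)
print(adjust("1"))  # (1024, 1024) -> shard 1 occupies packed rows [1024, 2048)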
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -2,6 +2,7 @@ import logging
 from typing import Callable, List, Optional, Tuple

 import torch
+from torch.nn import Module

 try:
     from deep_gemm import (
@@ -13,8 +14,6 @@ try:
 except ImportError:
     use_deep_gemm = False

-from torch.nn import Module
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -37,22 +36,17 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode,
+from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs

-
+_is_hip = is_hip()

-if
-    from
-else:
-    from vllm import _custom_ops as vllm_ops
+if _is_hip:
+    from vllm._custom_ops import scaled_fp8_quant

 logger = logging.getLogger(__name__)

-_is_hip = is_hip()
-
-_buffer = None
-

 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -142,6 +136,7 @@ class EPMoE(torch.nn.Module):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
     ):
         super().__init__()

@@ -170,6 +165,7 @@ class EPMoE(torch.nn.Module):
         self.correction_bias = correction_bias
         self.custom_routing_function = custom_routing_function
         self.activation = activation
+        self.routed_scaling_factor = routed_scaling_factor

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -221,6 +217,7 @@ class EPMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
+            routed_scaling_factor=self.routed_scaling_factor,
         )

         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -740,20 +737,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
            )

            for expert in range(layer.num_experts_per_partition):
-
-                w13_weight[expert, :, :]
-
-
-                w2_weight[expert, :, :]
-
-                )
-            else:
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            return
@@ -813,6 +802,7 @@ class DeepEPMoE(EPMoE):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
         deepep_mode: DeepEPMode = DeepEPMode.auto,
     ):
         super().__init__(
@@ -831,6 +821,7 @@ class DeepEPMoE(EPMoE):
             correction_bias,
             custom_routing_function,
             activation,
+            routed_scaling_factor,
         )
         self.deepep_mode = deepep_mode
         if self.deepep_mode.enable_low_latency():
@@ -986,9 +977,6 @@
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
-        assert (
-            hidden_states_fp8[0].size(0) % 4 == 0
-        ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"

         # GroupGemm-0
         num_groups, m, k = hidden_states_fp8[0].size()
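These call sites now import a single scaled_fp8_quant (sglang's own fp8_kernel implementation, or vllm._custom_ops on HIP) instead of branching per platform at every call. For the dynamic per-tensor case used in the expert-weight loop, the semantics can be sketched in plain PyTorch; this is a reference for what the call returns, not the fused kernel:

import torch

def scaled_fp8_quant_ref(x: torch.Tensor):
    # Dynamic per-tensor quantization: one scale derived from the tensor's absolute maximum.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = (x.abs().max().float() / fp8_max).clamp(min=1e-12)
    q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale

w = torch.randn(256, 512, dtype=torch.bfloat16)
w_fp8, w_scale = scaled_fp8_quant_ref(w)  # dequantize with w_fp8.float() * w_scale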
sglang/srt/layers/moe/fused_moe_native.py
CHANGED
@@ -26,6 +26,7 @@ def fused_moe_forward_native(
     apply_router_weight_on_input: bool = False,
     inplace: bool = True,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:

     if apply_router_weight_on_input:
@@ -41,6 +42,7 @@
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
+        routed_scaling_factor=routed_scaling_factor,
         torch_native=True,
     )

@@ -71,6 +73,7 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:

     from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
@@ -86,6 +89,7 @@
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
         torch_native=True,
+        routed_scaling_factor=routed_scaling_factor,
     )

     # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
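routed_scaling_factor is threaded from the model down into select_experts (see sglang/srt/layers/moe/topk.py in the file list; its body is not part of the hunks shown here). For DeepSeek-style routing its usual effect is a scalar rescaling of the selected experts' weights, which is all this sketch assumes:

from typing import Optional
import torch

def scale_topk_weights(
    topk_weights: torch.Tensor, routed_scaling_factor: Optional[float]
) -> torch.Tensor:
    # After top-k selection (and optional renormalization), the routed experts'
    # weights are multiplied by a model-specific constant when one is configured.
    if routed_scaling_factor is not None:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights

weights = torch.softmax(torch.randn(4, 8), dim=-1).topk(2, dim=-1).values
print(scale_topk_weights(weights, routed_scaling_factor=2.5).shape)  # torch.Size([4, 2])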
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -13,6 +13,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -22,28 +23,25 @@ from sglang.srt.utils import (
 )

 _is_hip = is_hip()
-
-
-logger = logging.getLogger(__name__)
-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
-
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)
-
 _is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
-
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
 else:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant

 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size


+logger = logging.getLogger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+enable_moe_align_block_size_triton = bool(
+    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
+)
+
+
 @triton.jit
 def write_zeros_to_output(
     c_ptr,
@@ -770,14 +768,9 @@ def invoke_fused_moe_kernel(
        # activation tensor-wise fp8 quantization, dynamic or static
        padded_size = padding_size
        # activations apply per-token quantization when weights apply per-channel quantization by default
-
-            A, A_scale =
-
-            )
-        else:
-            A, A_scale = vllm_ops.scaled_fp8_quant(
-                A, A_scale, use_per_token_if_dynamic=per_channel_quant
-            )
+        A, A_scale = scaled_fp8_quant(
+            A, A_scale, use_per_token_if_dynamic=per_channel_quant
+        )
    else:
        # activation block-wise fp8 quantization
        assert len(block_shape) == 2
@@ -1554,6 +1547,7 @@ def fused_moe(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1608,6 +1602,7 @@
         topk_group=topk_group,
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
+        routed_scaling_factor=routed_scaling_factor,
     )

     return fused_experts(
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -131,6 +131,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -147,6 +148,7 @@
             apply_router_weight_on_input=apply_router_weight_on_input,
             inplace=inplace,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )

     def forward_cuda(
@@ -165,6 +167,7 @@
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -176,6 +179,7 @@
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )

         if _is_hip and get_bool_env_var("CK_MOE"):
@@ -284,6 +288,7 @@ class FusedMoE(torch.nn.Module):
         use_presharded_weights: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ):
         super().__init__()

@@ -293,6 +298,7 @@
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
+        self.routed_scaling_factor = routed_scaling_factor
         self.top_k = top_k
         self.num_experts = num_experts
         assert intermediate_size % self.tp_size == 0
@@ -637,6 +643,7 @@
             correction_bias=self.correction_bias,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
+            routed_scaling_factor=self.routed_scaling_factor,
         )

         if self.reduce_results and self.tp_size > 1: