sglang 0.4.0.post2__py3-none-any.whl → 0.4.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +0 -12
- sglang/bench_one_batch.py +0 -12
- sglang/bench_serving.py +11 -2
- sglang/lang/backend/openai.py +10 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -0
- sglang/srt/layers/attention/flashinfer_backend.py +49 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +124 -99
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
- sglang/srt/layers/moe/topk.py +205 -0
- sglang/srt/layers/quantization/__init__.py +3 -3
- sglang/srt/layers/quantization/fp8.py +169 -32
- sglang/srt/layers/quantization/fp8_kernel.py +292 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/torchao_utils.py +11 -15
- sglang/srt/managers/schedule_batch.py +16 -10
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +13 -16
- sglang/srt/managers/tokenizer_manager.py +130 -111
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_loader/loader.py +22 -11
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +19 -0
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +23 -0
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_params.py +9 -2
- sglang/srt/server.py +21 -37
- sglang/srt/utils.py +33 -44
- sglang/test/test_block_fp8.py +341 -0
- sglang/version.py +1 -1
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/METADATA +4 -4
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/RECORD +52 -48
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py
CHANGED
@@ -322,18 +322,6 @@ def throughput_test(
         )
         time.sleep(0.5)
 
-    try:
-        import os
-        import pwd
-
-        from gemlite.core import GemLiteLinearTriton
-
-        GemLiteLinearTriton.cache_config(
-            f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
-        )
-    except ImportError:
-        pass
-
     logging.info("\nBenchmark...")
     result = throughput_test_once(
         backend_name=bench_args.backend,
sglang/bench_one_batch.py
CHANGED
@@ -386,18 +386,6 @@ def latency_test(
         server_args.device,
     )
 
-    try:
-        import os
-        import pwd
-
-        from gemlite.core import GemLiteLinearTriton
-
-        GemLiteLinearTriton.cache_config(
-            f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
-        )
-    except ImportError:
-        pass
-
     rank_print("Benchmark ...")
 
     # Run the sweep
sglang/bench_serving.py
CHANGED
@@ -897,6 +897,7 @@ async def benchmark(
     else:
         raise ValueError(f"Unknown backend: {backend}")
 
+    # Limit concurrency
     # From https://github.com/vllm-project/vllm/pull/9390
     semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
 
@@ -906,6 +907,7 @@
         async with semaphore:
             return await request_func(request_func_input=request_func_input, pbar=pbar)
 
+    # Warmup
     print("Starting initial single prompt test run...")
     test_prompt, test_prompt_len, test_output_len = input_requests[0]
     test_input = RequestFuncInput(
@@ -926,8 +928,13 @@
     else:
         print("Initial test run completed. Starting main benchmark run...")
 
-
+    # Flush cache
+    if "sglang" in backend:
+        requests.post(base_url + "/flush_cache")
 
+    time.sleep(1.0)
+
+    # Start profiler
     if profile:
         print("Starting profiler...")
         profile_output = await async_request_profile(
@@ -938,6 +945,7 @@
 
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))
 
+    # Run all requests
     benchmark_start_time = time.perf_counter()
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
@@ -958,6 +966,7 @@
         )
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
 
+    # Stop profiler
     if profile:
         print("Stopping profiler...")
         profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
@@ -967,8 +976,8 @@
     if pbar is not None:
         pbar.close()
 
+    # Compute metrics and print results
     benchmark_duration = time.perf_counter() - benchmark_start_time
-
     metrics, output_lens = calculate_metrics(
         input_requests=input_requests,
         outputs=outputs,
sglang/lang/backend/openai.py
CHANGED
@@ -366,6 +366,11 @@ class OpenAI(BaseBackend):
 def openai_completion(
     client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
 ):
+    # if "ebnf" is in kwargs, warn and remove
+    if "ebnf" in kwargs:
+        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
+        del kwargs["ebnf"]
+
     for attempt in range(retries):
         try:
             if is_chat:
@@ -398,6 +403,11 @@ def openai_completion(
 def openai_completion_stream(
     client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
 ):
+    # if "ebnf" is in kwargs, warn and remove
+    if "ebnf" in kwargs:
+        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
+        del kwargs["ebnf"]
+
     for attempt in range(retries):
         try:
             if is_chat:
sglang/srt/aio_rwlock.py
ADDED
@@ -0,0 +1,100 @@
+import asyncio
+
+
+class RWLock:
+    def __init__(self):
+        # Protects internal state
+        self._lock = asyncio.Lock()
+
+        # Condition variable used to wait for state changes
+        self._cond = asyncio.Condition(self._lock)
+
+        # Number of readers currently holding the lock
+        self._readers = 0
+
+        # Whether a writer is currently holding the lock
+        self._writer_active = False
+
+        # How many writers are queued waiting for a turn
+        self._waiting_writers = 0
+
+    @property
+    def reader_lock(self):
+        """
+        A context manager for acquiring a shared (reader) lock.
+
+        Example:
+            async with rwlock.reader_lock:
+                # read-only access
+        """
+        return _ReaderLock(self)
+
+    @property
+    def writer_lock(self):
+        """
+        A context manager for acquiring an exclusive (writer) lock.
+
+        Example:
+            async with rwlock.writer_lock:
+                # exclusive access
+        """
+        return _WriterLock(self)
+
+    async def acquire_reader(self):
+        async with self._lock:
+            # Wait until there is no active writer or waiting writer
+            # to ensure fairness.
+            while self._writer_active or self._waiting_writers > 0:
+                await self._cond.wait()
+            self._readers += 1
+
+    async def release_reader(self):
+        async with self._lock:
+            self._readers -= 1
+            # If this was the last reader, wake up anyone waiting
+            # (potentially a writer or new readers).
+            if self._readers == 0:
+                self._cond.notify_all()
+
+    async def acquire_writer(self):
+        async with self._lock:
+            # Increment the count of writers waiting
+            self._waiting_writers += 1
+            try:
+                # Wait while either a writer is active or readers are present
+                while self._writer_active or self._readers > 0:
+                    await self._cond.wait()
+                self._writer_active = True
+            finally:
+                # Decrement waiting writers only after we've acquired the writer lock
+                self._waiting_writers -= 1
+
+    async def release_writer(self):
+        async with self._lock:
+            self._writer_active = False
+            # Wake up anyone waiting (readers or writers)
+            self._cond.notify_all()
+
+
+class _ReaderLock:
+    def __init__(self, rwlock: RWLock):
+        self._rwlock = rwlock
+
+    async def __aenter__(self):
+        await self._rwlock.acquire_reader()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self._rwlock.release_reader()
+
+
+class _WriterLock:
+    def __init__(self, rwlock: RWLock):
+        self._rwlock = rwlock
+
+    async def __aenter__(self):
+        await self._rwlock.acquire_writer()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self._rwlock.release_writer()
sglang/srt/configs/model_config.py
CHANGED
@@ -94,7 +94,10 @@ class ModelConfig:
         )
 
         # FIXME: temporary special judge for MLA architecture
-        if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
+        if (
+            "DeepseekV2ForCausalLM" in self.hf_config.architectures
+            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
+        ):
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_config.kv_lora_rank
@@ -124,8 +127,12 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
 
+        # Veirfy quantization
        self._verify_quantization()
 
+        # Multimodel attrs
+        self.image_token_id = getattr(self.hf_config, "image_token_id", None)
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
sglang/srt/constrained/xgrammar_backend.py
CHANGED
@@ -126,6 +126,12 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
                 )
                 return None
+        elif key_type == "ebnf":
+            try:
+                ctx = self.grammar_compiler.compile_grammar(key_string)
+            except RuntimeError as e:
+                logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
+                return None
         elif key_type == "regex":
             logger.warning(
                 "regex hasn't been supported by xgrammar yet. This is skipped."
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -18,11 +18,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import (
-    get_bool_env_var,
-    is_flashinfer_available,
-    should_use_tensor_core,
-)
+from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -731,3 +727,51 @@ def create_flashinfer_kv_indices_triton(
         mask=mask,
     )
     tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
+
+
+def should_use_tensor_core(
+    kv_cache_dtype: torch.dtype,
+    num_attention_heads: int,
+    num_kv_heads: int,
+) -> bool:
+    """
+    Determine whether to use tensor cores for attention computation.
+
+    Args:
+        kv_cache_dtype: Data type of the KV cache
+        num_attention_heads: Number of attention heads
+        num_kv_heads: Number of key/value heads
+
+    Returns:
+        bool: Whether to use tensor cores
+    """
+    # Try to use environment variable first
+    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+    if env_override is not None:
+        return env_override.lower() == "true"
+
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
+    # Calculate GQA group size
+    gqa_group_size = num_attention_heads // num_kv_heads
+
+    # Determine based on dtype and GQA group size
+    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return True
+    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+        return gqa_group_size > 4
+    else:
+        return False
sglang/srt/layers/attention/triton_ops/extend_attention.py
CHANGED
@@ -292,27 +292,33 @@ def extend_attention_fwd(
     BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
-        if Lq <= 256:
-            BLOCK_M, BLOCK_N = (128, 64)
-        else:
-            BLOCK_M, BLOCK_N = (32, 64)
-    elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
-        if Lq <= 128:
-            BLOCK_M, BLOCK_N = (128, 128)
-        elif Lq <= 256:
-            BLOCK_M, BLOCK_N = (64, 64)
-        else:
-            BLOCK_M, BLOCK_N = (32, 64)
+    if is_hip_:
+        BLOCK_M, BLOCK_N = (64, 64)
+        num_warps = 4
+
     else:
-        BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+            if Lq <= 256:
+                BLOCK_M, BLOCK_N = (128, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (128, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        else:
+            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+
+        num_warps = 4 if Lk <= 64 else 8
 
     sm_scale = sm_scale or 1.0 / (Lq**0.5)
     batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]
 
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
-    num_warps = 4 if Lk <= 64 else 8
     num_stages = 1
 
     extra_kargs = {}
sglang/srt/layers/linear.py
CHANGED
@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
@@ -628,8 +629,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         assert loaded_shard_id < len(self.output_sizes)
 
         tp_size = get_tensor_model_parallel_world_size()
-        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-        shard_size = self.output_sizes[loaded_shard_id] // tp_size
+
+        if isinstance(param, BlockQuantScaleParameter):
+            weight_block_size = self.quant_method.quant_config.weight_block_size
+            block_n, _ = weight_block_size[0], weight_block_size[1]
+            shard_offset = (
+                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
+            ) // tp_size
+            shard_size = (
+                (self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
+            )
+        else:
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // tp_size
 
         param.load_merged_column_weight(
             loaded_weight=loaded_weight,
@@ -795,6 +807,12 @@ class QKVParallelLinear(ColumnParallelLinear):
         shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
         shard_size = self._get_shard_size_mapping(loaded_shard_id)
 
+        if isinstance(param, BlockQuantScaleParameter):
+            weight_block_size = self.quant_method.quant_config.weight_block_size
+            block_n, _ = weight_block_size[0], weight_block_size[1]
+            shard_offset = (shard_offset + block_n - 1) // block_n
+            shard_size = (shard_size + block_n - 1) // block_n
+
         param.load_qkv_weight(
             loaded_weight=loaded_weight,
             num_heads=self.num_kv_head_replicas,
sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py
RENAMED
@@ -12,15 +12,15 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 
 from sglang.srt.layers.custom_op_util import register_custom_op
-from sglang.srt.layers.ep_moe.kernels import (
+from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
     silu_and_mul_triton_kernel,
 )
-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk
-from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -113,6 +113,7 @@ class EPMoE(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
+        correction_bias: Optional[torch.Tensor] = None,
     ):
         super().__init__()
 
@@ -138,6 +139,7 @@ class EPMoE(torch.nn.Module):
             assert num_expert_group is not None and topk_group is not None
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
+        self.correction_bias = correction_bias
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -170,13 +172,15 @@
             hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
         )
 
-        topk_weights, topk_ids = self.select_experts(
-            hidden_states,
-            router_logits,
-            self.top_k,
-            self.renormalize,
-            self.topk_group,
-            self.num_expert_group,
+        topk_weights, topk_ids = select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            correction_bias=self.correction_bias,
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -297,35 +301,6 @@
         )
         return output
 
-    def select_experts(
-        self,
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-    ):
-        if self.use_grouped_topk:
-            assert topk_group is not None
-            assert num_expert_group is not None
-            topk_weights, topk_ids = grouped_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-            )
-        else:
-            topk_weights, topk_ids = fused_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-        return topk_weights, topk_ids.to(torch.int32)
-
     @classmethod
     def make_expert_params_mapping(
         cls,
sglang/srt/layers/moe/fused_moe_native.py
ADDED
@@ -0,0 +1,46 @@
+"""
+Torch-native implementation for FusedMoE. This is used for torch.compile.
+It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
+"""
+
+from typing import Callable, Optional
+
+import torch
+from torch.nn import functional as F
+
+from sglang.srt.layers.moe.topk import select_experts
+
+
+def fused_moe_forward_native(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    topk_weights, topk_ids = select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+        correction_bias=correction_bias,
+        torch_native=True,
+    )
+
+    w13_weights = layer.w13_weight[topk_ids]
+    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
+    w2_weights = layer.w2_weight[topk_ids]
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
+    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
+    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
+    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py
RENAMED
@@ -1,14 +1,12 @@
 from contextlib import contextmanager
 from typing import Any, Dict, Optional
 
-import sglang.srt.layers.fused_moe_triton.fused_moe  # noqa
-from sglang.srt.layers.fused_moe_triton.fused_moe import (
+import sglang.srt.layers.moe.fused_moe_triton.fused_moe  # noqa
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     fused_experts,
-    fused_topk,
     get_config_file_name,
-    grouped_topk,
 )
-from sglang.srt.layers.fused_moe_triton.layer import (
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
     FusedMoE,
     FusedMoEMethodBase,
     FusedMoeWeightScaleSupported,
@@ -37,8 +35,6 @@ __all__ = [
     "override_config",
     "get_config",
     "fused_moe",
-    "fused_topk",
     "fused_experts",
     "get_config_file_name",
-    "grouped_topk",
 ]