sglang 0.4.9__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +12 -1
- sglang/srt/conversation.py +35 -1
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/layers/communicator.py +3 -1
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
- sglang/srt/layers/moe/ep_moe/layer.py +140 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +8 -1
- sglang/srt/managers/mm_utils.py +4 -2
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/scheduler.py +17 -5
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +113 -63
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/models/deepseek_v2.py +16 -2
- sglang/srt/models/mllama4.py +360 -79
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +62 -60
- sglang/srt/server_args.py +15 -0
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +37 -17
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -217,11 +217,13 @@ class ServerArgs:
     hicache_ratio: float = 2.0
     hicache_size: int = 0
     hicache_write_policy: str = "write_through_selective"
+    hicache_io_backend: str = ""
     flashinfer_mla_disable_ragged: bool = False
     disable_shared_experts_fusion: bool = False
     disable_chunked_prefix_cache: bool = False
     disable_fast_image_processor: bool = False
     enable_return_hidden_states: bool = False
+    enable_triton_kernel_moe: bool = False
     warmups: Optional[str] = None

     # Debug tensor dumps
@@ -706,6 +708,7 @@ class ServerArgs:
                 "w8a8_fp8",
                 "moe_wna16",
                 "qoq",
+                "w4afp8",
             ],
             help="The quantization method.",
         )
@@ -1529,6 +1532,13 @@ class ServerArgs:
             default=ServerArgs.hicache_write_policy,
             help="The write policy of hierarchical cache.",
         )
+        parser.add_argument(
+            "--hicache-io-backend",
+            type=str,
+            choices=["direct", "kernel"],
+            default=ServerArgs.hicache_io_backend,
+            help="The IO backend for KV cache transfer between CPU and GPU",
+        )
         parser.add_argument(
             "--flashinfer-mla-disable-ragged",
             action="store_true",
@@ -1554,6 +1564,11 @@ class ServerArgs:
             action="store_true",
             help="Enable returning hidden states with responses.",
         )
+        parser.add_argument(
+            "--enable-triton-kernel-moe",
+            action="store_true",
+            help="Use triton moe grouped gemm kernel.",
+        )
         parser.add_argument(
             "--warmups",
             type=str,
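For reference, this release adds three user-facing knobs: a hierarchical-cache IO backend (--hicache-io-backend), a Triton grouped-GEMM MoE path (--enable-triton-kernel-moe), and a new "w4afp8" entry in the --quantization choices. Below is a minimal sketch of exercising them through the ServerArgs dataclass; the model path is a placeholder, and the field values simply mirror the diff above.

# Sketch only: constructs ServerArgs with the fields added in 0.4.9.post1.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="path/to/model",        # placeholder
    hicache_io_backend="kernel",       # new: "direct" or "kernel" CPU<->GPU KV IO
    enable_triton_kernel_moe=True,     # new: Triton MoE grouped-GEMM kernels
    quantization="w4afp8",             # new quantization choice
)

# Roughly equivalent CLI invocation:
#   python -m sglang.launch_server --model-path path/to/model \
#       --hicache-io-backend kernel --enable-triton-kernel-moe --quantization w4afp8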
sglang/srt/two_batch_overlap.py
CHANGED
@@ -490,6 +490,7 @@ class TboForwardBatchPreparer:
         output_dict["spec_info"] = output_spec_info
         for key in [
             "forward_mode",
+            "is_extend_in_batch",
             "return_logprob",
             "req_to_token_pool",
             "token_to_kv_pool",
@@ -550,6 +551,8 @@ class TboForwardBatchPreparer:
                 top_p_normalized_logprobs=False,
                 top_p=None,
                 mm_inputs=None,
+                top_logprobs_nums=None,
+                token_ids_logprobs=None,
            )
        )

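The first hunk matters because the preparer copies only a fixed whitelist of attributes from the parent forward batch into each overlapped micro-batch, so any field left off the list is silently missing on the children; 0.4.9.post1 adds is_extend_in_batch to that list. An illustrative sketch of the pattern (simplified, not the actual sglang class):

from types import SimpleNamespace

def copy_whitelisted(parent, output_dict: dict) -> None:
    # Copy only whitelisted attributes into the micro-batch dict;
    # "is_extend_in_batch" is the key this release adds to the list.
    for key in [
        "forward_mode",
        "is_extend_in_batch",
        "return_logprob",
        "req_to_token_pool",
        "token_to_kv_pool",
    ]:
        output_dict[key] = getattr(parent, key)

parent = SimpleNamespace(
    forward_mode="extend",
    is_extend_in_batch=True,
    return_logprob=False,
    req_to_token_pool=None,
    token_to_kv_pool=None,
)
child: dict = {}
copy_whitelisted(parent, child)
assert child["is_extend_in_batch"] is True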
sglang/srt/utils.py
CHANGED
@@ -15,7 +15,6 @@

 from __future__ import annotations

-import base64
 import builtins
 import ctypes
 import dataclasses
@@ -68,6 +67,7 @@ from typing import (

 import numpy as np
 import psutil
+import pybase64
 import requests
 import torch
 import torch.distributed
@@ -83,12 +83,7 @@ from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
-from triton.runtime.cache import (
-    FileCacheManager,
-    default_cache_dir,
-    default_dump_dir,
-    default_override_dir,
-)
+from triton.runtime.cache import FileCacheManager

 logger = logging.getLogger(__name__)

@@ -621,7 +616,7 @@ def decode_video_base64(video_base64):
     from PIL import Image

     # Decode the base64 string
-    video_bytes = base64.b64decode(video_base64)
+    video_bytes = pybase64.b64decode(video_base64, validate=True)

     # Placeholder for the start indices of each PNG image
     img_starts = []
@@ -707,7 +702,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray:
         audio, original_sr = sf.read(BytesIO(audio_file))
     elif audio_file.startswith("data:"):
         audio_file = audio_file.split(",")[1]
-        audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
+        audio, original_sr = sf.read(
+            BytesIO(pybase64.b64decode(audio_file, validate=True))
+        )
     elif audio_file.startswith("http://") or audio_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
         response = requests.get(audio_file, stream=True, timeout=timeout)
@@ -776,12 +773,12 @@ def load_image(
         image = Image.open(image_file)
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
-        image = Image.open(BytesIO(base64.b64decode(image_file)))
+        image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     elif image_file.startswith("video:"):
         image_file = image_file.replace("video:", "")
         image, image_size = decode_video_base64(image_file)
     elif isinstance(image_file, str):
-        image = Image.open(BytesIO(base64.b64decode(image_file)))
+        image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     else:
         raise ValueError(f"Invalid image: {image}")

@@ -923,18 +920,41 @@ class CustomCacheManager(FileCacheManager):

         self.key = key
         self.lock_path = None
+
+        try:
+            module_path = "triton.runtime.cache"
+            cache_module = importlib.import_module(module_path)
+
+            default_cache_dir = getattr(cache_module, "default_cache_dir", None)
+            default_dump_dir = getattr(cache_module, "default_dump_dir", None)
+            default_override_dir = getattr(cache_module, "default_override_dir", None)
+        except (ModuleNotFoundError, AttributeError) as e:
+            default_cache_dir = None
+            default_dump_dir = None
+            default_override_dir = None
+
         if dump:
-            self.cache_dir = default_dump_dir()
+            self.cache_dir = (
+                default_dump_dir()
+                if default_dump_dir is not None
+                else os.path.join(Path.home(), ".triton", "dump")
+            )
             self.cache_dir = os.path.join(self.cache_dir, self.key)
             self.lock_path = os.path.join(self.cache_dir, "lock")
             os.makedirs(self.cache_dir, exist_ok=True)
         elif override:
-            self.cache_dir = default_override_dir()
+            self.cache_dir = (
+                default_override_dir()
+                if default_override_dir is not None
+                else os.path.join(Path.home(), ".triton", "override")
+            )
             self.cache_dir = os.path.join(self.cache_dir, self.key)
         else:
             # create cache directory if it doesn't exist
-            self.cache_dir = (
-                os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
+            self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or (
+                default_cache_dir()
+                if default_cache_dir is not None
+                else os.path.join(Path.home(), ".triton", "cache")
             )
             if self.cache_dir:
                 try:
@@ -1848,7 +1868,7 @@ class MultiprocessingSerializer:

         if output_str:
             # Convert bytes to base64-encoded string
-            output = base64.b64encode(output).decode("utf-8")
+            output = pybase64.b64encode(output).decode("utf-8")

         return output

@@ -1865,7 +1885,7 @@ class MultiprocessingSerializer:
         """
         if isinstance(data, str):
             # Decode base64 string to bytes
-            data = base64.b64decode(data)
+            data = pybase64.b64decode(data, validate=True)

         return ForkingPickler.loads(data)

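Every base64 call site above moves to pybase64, an API-compatible, SIMD-accelerated replacement for the stdlib module; passing validate=True makes decoding reject malformed input rather than silently skipping non-alphabet characters. A minimal sketch of the equivalence:

import base64

import pybase64

payload = b"hello sglang"
encoded = pybase64.b64encode(payload).decode("utf-8")

# Same output as the stdlib, just faster on large inputs.
assert encoded == base64.b64encode(payload).decode("utf-8")
# validate=True raises binascii.Error on malformed input instead of ignoring it.
assert pybase64.b64decode(encoded, validate=True) == payload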
sglang/test/test_cutlass_w4a8_moe.py
ADDED
@@ -0,0 +1,281 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import pytest
+import torch
+
+from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+from sglang.srt.layers.moe.topk import select_experts
+
+
+def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
+    if int4_values_interleaved.shape[-1] % 2 != 0:
+        raise ValueError(
+            "the last dim size of int4_values_interleaved tensor must be even."
+        )
+
+    input_tensor_int8 = int4_values_interleaved.to(torch.int8)
+
+    low_nibbles = input_tensor_int8[..., 0::2]
+    high_nibbles = input_tensor_int8[..., 1::2]
+
+    packed_tensor = (high_nibbles << 4) | (low_nibbles & 0x0F)
+
+    return packed_tensor.to(torch.int8)
+
+
+def pack_interleave(num_experts, ref_weight, ref_scale):
+    n, k = ref_weight.shape[1], ref_weight.shape[2]
+
+    weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
+    w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
+    w_q = w_q.contiguous()
+
+    scale_interleaved = ref_scale.reshape(
+        ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
+    )  # [E, N, K/4, 4]
+    scale_interleaved = scale_interleaved.permute(0, 2, 1, 3)  # [E, K/4, N, 4]
+    scale_interleaved = scale_interleaved.reshape(
+        ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
+    )  # [E, K/4, N*4]
+    w_scale = scale_interleaved.contiguous()
+
+    return w_q, w_scale
+
+
+@pytest.mark.parametrize("M", [1, 2, 4, 8, 16])
+@pytest.mark.parametrize("N", [2048])
+@pytest.mark.parametrize("K", [7168])
+@pytest.mark.parametrize("E", [256])
+@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.parametrize("topk", [8])
+@pytest.mark.parametrize("group_size", [128])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
+    local_e = E // ep_size
+
+    debug = False
+    if debug:
+        a = torch.ones((M, K), dtype=dtype, device="cuda") * 0.001
+        ref_weight_1 = torch.ones((local_e, N * 2, K), dtype=torch.int8, device="cuda")
+        ref_weight_2 = torch.ones((local_e, K, N), dtype=torch.int8, device="cuda")
+        a1_scale = torch.ones(1, dtype=torch.float32, device="cuda")
+        a2_scale = torch.ones(1, dtype=torch.float32, device="cuda")
+        scale_1 = torch.ones(
+            (local_e, N * 2, K // group_size), dtype=dtype, device="cuda"
+        )
+        scale_2 = torch.ones((local_e, K, N // group_size), dtype=dtype, device="cuda")
+    else:
+        a = torch.randn(M, K, dtype=dtype, device="cuda")
+        ref_weight_1 = torch.randint(
+            -8, 8, (local_e, N * 2, K), dtype=torch.int8, device="cuda"
+        )
+        ref_weight_2 = torch.randint(
+            -8, 8, (local_e, K, N), dtype=torch.int8, device="cuda"
+        )
+        affine_coeff = 0.005
+        a1_scale = torch.randn(1, dtype=torch.float32, device="cuda")
+        a2_scale = torch.randn(1, dtype=torch.float32, device="cuda")
+        scale_1 = (
+            torch.randn(local_e, N * 2, K // group_size, dtype=dtype, device="cuda")
+            * affine_coeff
+        )
+        scale_2 = (
+            torch.randn(local_e, K, N // group_size, dtype=dtype, device="cuda")
+            * affine_coeff
+        )
+
+    w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1)
+    w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)
+
+    device = "cuda"
+    a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
+    c_strides1 = torch.full((local_e, 3), 2 * N, device=device, dtype=torch.int64)
+    a_strides2 = torch.full((local_e, 3), N, device=device, dtype=torch.int64)
+    c_strides2 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
+    b_strides1 = a_strides1
+    s_strides13 = c_strides1
+    b_strides2 = a_strides2
+    s_strides2 = c_strides2
+
+    score = torch.randn((M, E), dtype=dtype, device=device)
+    topk_weights, topk_ids = select_experts(
+        hidden_states=a,
+        router_logits=score,
+        top_k=topk,
+        use_grouped_topk=False,
+        renormalize=False,
+    )
+    expert_map = torch.arange(E, dtype=torch.int32, device=device)
+    expert_map[local_e:] = E
+
+    output = cutlass_moe(
+        a,
+        w1_q,
+        w2_q,
+        w1_scale,
+        w2_scale,
+        topk_weights,
+        topk_ids,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides13,
+        s_strides2,
+        0,
+        local_e - 1,
+        E,
+        a1_scale,
+        a2_scale,
+        expert_map,
+    )
+
+    ref_output = ref(
+        a,
+        local_e,
+        topk_weights,
+        topk_ids,
+        ref_weight_1,
+        ref_weight_2,
+        scale_1,
+        scale_2,
+        has_pre_quant=True,
+        has_alpha=True,
+        pre_quant_scale_1=a1_scale,
+        pre_quant_scale_2=a2_scale,
+        alpha_1=a1_scale,
+        alpha_2=a2_scale,
+    )
+
+    # compare
+    torch.cuda.synchronize()
+
+    # compare final output
+    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
+    print("SUCCESS: Final output tensors are close.")
+
+
+def cutlass_moe(
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    start_expert_id: int,
+    end_expert_id: int,
+    E: int,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    expert_map: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False,
+):
+    local_topk_ids = topk_ids_
+    local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E)
+    device = a.device
+
+    local_num_experts = end_expert_id - start_expert_id + 1
+    expert_offsets = torch.empty(
+        (local_num_experts + 1), dtype=torch.int32, device=device
+    )
+    problem_sizes1 = torch.empty(
+        (local_num_experts, 3), dtype=torch.int32, device=device
+    )
+    problem_sizes2 = torch.empty(
+        (local_num_experts, 3), dtype=torch.int32, device=device
+    )
+    return cutlass_w4a8_moe(
+        start_expert_id,
+        end_expert_id,
+        E,
+        a,
+        w1_q,
+        w2_q,
+        w1_scale,
+        w2_scale,
+        topk_weights,
+        topk_ids_,
+        local_topk_ids,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides13,
+        s_strides2,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        a1_scale,
+        a2_scale,
+        apply_router_weight_on_input,
+    )
+
+
+def ref(
+    x: torch.Tensor,
+    num_experts: int,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    ref_weight_1: torch.Tensor,
+    ref_weight_2: torch.Tensor,
+    ref_weight_scale_1: torch.Tensor,
+    ref_weight_scale_2: torch.Tensor,
+    has_pre_quant: bool = False,
+    has_alpha: bool = False,
+    pre_quant_scale_1: Optional[torch.Tensor] = None,
+    pre_quant_scale_2: Optional[torch.Tensor] = None,
+    alpha_1: Optional[torch.Tensor] = None,
+    alpha_2: Optional[torch.Tensor] = None,
+):
+    results = torch.zeros_like(x)
+    dtype = x.dtype
+    for e_idx in range(num_experts):
+        mask = topk_ids == e_idx
+        activated_tokens = mask.sum(1).bool()
+        act = x[activated_tokens, :]
+        if act.shape[0] == 0:
+            continue
+        final_scale = (topk_weights * mask).sum(1)[activated_tokens].unsqueeze(1)
+
+        act = (
+            torch.clamp((act / pre_quant_scale_1.float()), -448.0, 448.0)
+            .to(torch.float8_e4m3fn)
+            .to(dtype)
+        )
+        w3_w1 = ref_weight_1[e_idx]
+        ref_w_scale_repeat = (
+            ref_weight_scale_1[e_idx].repeat_interleave(128, dim=1).to(float)
+        )
+        w3_w1 = (w3_w1.to(float) * ref_w_scale_repeat).to(dtype)
+        fc1 = ((torch.matmul(act, w3_w1.T)) * alpha_1).to(torch.float16)
+
+        gate, fc1 = fc1.chunk(2, dim=-1)
+        fc1 = fc1 * torch.nn.functional.silu(gate)
+        act = (fc1 / pre_quant_scale_2.float()).to(torch.float8_e4m3fn)
+        act = act.to(dtype)
+
+        w2 = ref_weight_2[e_idx]
+        ref_w_scale_repeat = (
+            ref_weight_scale_2[e_idx].repeat_interleave(128, dim=1).to(float)
+        )
+        w2 = (w2.to(float) * ref_w_scale_repeat).to(dtype)
+        fc2 = (torch.matmul(act, w2.T) * alpha_2).to(torch.float16)
+
+        results[activated_tokens, :] += (fc2 * final_scale).to(results.dtype)
+
+    return results
sglang/utils.py
CHANGED
@@ -1,6 +1,5 @@
 """Common utilities"""

-import base64
 import importlib
 import json
 import logging
@@ -20,6 +19,7 @@ from json import dumps
 from typing import Any, Callable, List, Optional, Tuple, Type, Union

 import numpy as np
+import pybase64
 import requests
 from IPython.display import HTML, display
 from pydantic import BaseModel
@@ -148,15 +148,15 @@ def encode_image_base64(image_path: Union[str, bytes]):
     if isinstance(image_path, str):
         with open(image_path, "rb") as image_file:
             data = image_file.read()
-        return base64.b64encode(data).decode("utf-8")
+        return pybase64.b64encode(data).decode("utf-8")
     elif isinstance(image_path, bytes):
-        return base64.b64encode(image_path).decode("utf-8")
+        return pybase64.b64encode(image_path).decode("utf-8")
     else:
         # image_path is PIL.WebPImagePlugin.WebPImageFile
         image = image_path
         buffered = BytesIO()
         image.save(buffered, format="PNG")
-        return base64.b64encode(buffered.getvalue()).decode("utf-8")
+        return pybase64.b64encode(buffered.getvalue()).decode("utf-8")


 def encode_frame(frame):
@@ -223,7 +223,7 @@ def encode_video_base64(video_path: str, num_frames: int = 16):
     video_bytes = b"".join(encoded_frames)

     # Encode the concatenated bytes to base64
-    video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")
+    video_base64 = "video:" + pybase64.b64encode(video_bytes).decode("utf-8")

     return video_base64

sglang/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.9"
+__version__ = "0.4.9.post1"
{sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sglang
-Version: 0.4.9
+Version: 0.4.9.post1
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
                         Version 2.0, January 2004
@@ -239,6 +239,7 @@ Requires-Dist: prometheus-client>=0.20.0; extra == "runtime-common"
 Requires-Dist: psutil; extra == "runtime-common"
 Requires-Dist: pydantic; extra == "runtime-common"
 Requires-Dist: pynvml; extra == "runtime-common"
+Requires-Dist: pybase64; extra == "runtime-common"
 Requires-Dist: python-multipart; extra == "runtime-common"
 Requires-Dist: pyzmq>=25.1.2; extra == "runtime-common"
 Requires-Dist: soundfile==0.13.1; extra == "runtime-common"
@@ -248,7 +249,7 @@ Requires-Dist: transformers==4.53.0; extra == "runtime-common"
 Requires-Dist: timm==1.0.16; extra == "runtime-common"
 Requires-Dist: uvicorn; extra == "runtime-common"
 Requires-Dist: uvloop; extra == "runtime-common"
-Requires-Dist: xgrammar==0.1.
+Requires-Dist: xgrammar==0.1.20; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
 Requires-Dist: sgl-kernel==0.2.4; extra == "srt"
@@ -419,7 +420,7 @@ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s
 [Development Roadmap (2025 H1)](https://github.com/sgl-project/sglang/issues/4042)

 ## Adoption and Sponsorship
-SGLang has been deployed at large scale, generating trillions of tokens in production
+SGLang has been deployed at large scale, generating trillions of tokens in production each day. It is trusted and adopted by a wide range of leading enterprises and institutions, including xAI, AMD, NVIDIA, Intel, LinkedIn, Cursor, Oracle Cloud, Google Cloud, Microsoft Azure, AWS, Atlas Cloud, Voltage Park, Nebius, DataCrunch, Novita, InnoMatrix, MIT, UCLA, the University of Washington, Stanford, UC Berkeley, Tsinghua University, Jam & Tea Studios, Baseten, and other major technology organizations across North America and Asia. As an open-source LLM inference engine, SGLang has become the de facto industry standard, with deployments running on over 1,000,000 GPUs worldwide.

 <img src="https://raw.githubusercontent.com/sgl-project/sgl-learning-materials/refs/heads/main/slides/adoption.png" alt="logo" width="800" margin="10px"></img>