sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +23 -3
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +5 -16
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +218 -79
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/topk.py +30 -3
- sglang/srt/layers/quantization/__init__.py +134 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +12 -0
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/schedule_batch.py +1 -0
- sglang/srt/managers/scheduler.py +25 -19
- sglang/srt/managers/tokenizer_manager.py +0 -1
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -8
- sglang/srt/model_executor/model_runner.py +9 -6
- sglang/srt/model_loader/loader.py +11 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +151 -26
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +6 -0
- sglang/srt/openai_api/adapter.py +88 -87
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/server_args.py +21 -11
- sglang/srt/speculative/eagle_worker.py +1 -1
- sglang/srt/utils.py +33 -0
- sglang/test/runners.py +27 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +8 -4
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +57 -53
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
@@ -12,8 +12,9 @@ def _sgemm_lora_a_kernel(
     weights,
     output,
     # Matrix dimensions
-    N,  # r
+    N,  # stack_num * r
     K,  # input_dim
+    stack_num,
     # Strides
     x_stride_0,
     x_stride_1,
@@ -22,10 +23,11 @@ def _sgemm_lora_a_kernel(
     w_stride_2,
     output_stride_0,
     output_stride_1,
-    # Information on sequence lengths and weight id
+    # Information on sequence lengths,ranks and weight id
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -43,6 +45,9 @@ def _sgemm_lora_a_kernel(
     seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+    # Adjust N (stack_num * max_rank) according to the specific LoRA adapter
+    N = tl.minimum(N, rank * stack_num)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -91,11 +96,15 @@ def _sgemm_lora_a_kernel(
 
 
 def sgemm_lora_a_fwd(
-    x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
+    x: torch.Tensor,
+    weights: torch.Tensor,
+    batch_info: LoRABatchInfo,
+    stack_num: int = 1,
 ) -> torch.Tensor:
     # x: (s, input_dim)
-    # weights: (num_lora, r, input_dim)
-    # output: (s, r)
+    # weights: (num_lora, stack_num * r, input_dim)
+    # output: (s, stack_num * r)
+    # stack_num: run_qkv_lora: 3, run_gate_up_lora: 2
     # when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
     # input_dim is much larger than r
 
@@ -126,6 +135,7 @@ def sgemm_lora_a_fwd(
         output,
         R,
         K,
+        stack_num,
         x.stride(0),
         x.stride(1),
         weights.stride(0),
@@ -136,6 +146,7 @@ def sgemm_lora_a_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_R,
         BLOCK_K,
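Taken together, the sgemm_lora_a changes let one batch mix adapters of different ranks: the kernel looks up each sequence's adapter rank in lora_ranks and clamps the output width to rank * stack_num, so columns beyond an adapter's true rank are never computed. A minimal PyTorch sketch of the resulting semantics (the helper and the Python loop are illustrative, not the Triton kernel):

import torch

def lora_a_reference(x, weights, seg_indptr, weight_indices, lora_ranks, stack_num=1):
    # x: (s, input_dim); weights: (num_lora, stack_num * max_r, input_dim)
    output = torch.zeros(x.shape[0], weights.shape[1], dtype=x.dtype)
    for i in range(len(weight_indices)):
        start, end = int(seg_indptr[i]), int(seg_indptr[i + 1])
        w = int(weight_indices[i])
        n = int(lora_ranks[w]) * stack_num  # the clamp the kernel applies with tl.minimum
        output[start:end, :n] = x[start:end] @ weights[w, :n, :].T
    return output

# Two sequences (lengths 4 and 2) served by adapters of rank 16 and 8:
x = torch.randn(6, 64)
weights = torch.randn(2, 16, 64)  # max_r = 16, stack_num = 1
out = lora_a_reference(
    x, weights,
    seg_indptr=torch.tensor([0, 4, 6]),
    weight_indices=torch.tensor([0, 1]),
    lora_ranks=torch.tensor([16, 8]),
)
print(out[4:, 8:].abs().max())  # tensor(0.): the rank-8 adapter's extra columns stay zero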
sglang/srt/lora/triton_ops/sgemm_lora_b.py
CHANGED
@@ -26,13 +26,14 @@ def _sgemm_lora_b_kernel(
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-    scaling,
+    scalings,
 ):
     # x: (s, K), s is the sum of sequence lengths
     # weights: (num_lora, N, K)
@@ -45,6 +46,10 @@ def _sgemm_lora_b_kernel(
     seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -100,12 +105,11 @@ def sgemm_lora_b_fwd(
     weights: torch.Tensor,
     batch_info: LoRABatchInfo,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
-    # x: (s, r)
-    # weights: (num_lora, output_dim, r)
+    # x: (s, max_r)
+    # weights: (num_lora, output_dim, max_r)
     # output: (s, output_dim)
-    # output_dim is much larger than r
+    # output_dim is much larger than max_r
 
     assert x.is_contiguous()
     assert weights.is_contiguous()
@@ -150,10 +154,11 @@ def sgemm_lora_b_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_N,
         BLOCK_R,
         fuse_scaling_add,
-        scaling,
+        batch_info.scalings,
     )
     return output
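The sgemm_lora_b side pairs with this: the single scaling: float = 1.0 argument is gone, and the kernel instead clamps the reduction dimension K to the adapter's rank and multiplies by that adapter's own factor from batch_info.scalings. A matching reference sketch (again illustrative, not the kernel):

import torch

def lora_b_reference(x, weights, seg_indptr, weight_indices, lora_ranks,
                     scalings, base_output=None):
    # x: (s, max_r); weights: (num_lora, output_dim, max_r)
    out = (torch.zeros(x.shape[0], weights.shape[1], dtype=x.dtype)
           if base_output is None else base_output)
    for i in range(len(weight_indices)):
        start, end = int(seg_indptr[i]), int(seg_indptr[i + 1])
        w = int(weight_indices[i])
        k = int(lora_ranks[w])  # clamp K to this adapter's rank
        out[start:end] += float(scalings[w]) * (x[start:end, :k] @ weights[w, :, :k].T)
    return out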
sglang/srt/lora/utils.py
CHANGED
@@ -25,6 +25,12 @@ class LoRABatchInfo:
     # The index of lora adapter used by each sequence, in shape (bs,)
     weight_indices: torch.Tensor
 
+    # ranks of each lora adapter, in shape (lora_num,)
+    lora_ranks: torch.Tensor
+
+    # scaling of each lora adapter, in shape (lora_num,)
+    scalings: torch.Tensor
+
 
 class LoRAType(Enum):
     LORA_A = 0
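Both new fields are indexed by adapter id rather than by sequence, which is what lets the kernels above fetch per-adapter metadata with a single tl.load(... + w_index). With hypothetical values for two loaded adapters:

import torch

lora_ranks = torch.tensor([16, 8])        # shape (lora_num,)
scalings = torch.tensor([2.0, 0.5])       # shape (lora_num,)
weight_indices = torch.tensor([0, 1, 1])  # shape (bs,): adapter used by each sequence

print(lora_ranks[weight_indices])  # tensor([16,  8,  8])
print(scalings[weight_indices])    # tensor([2.0000, 0.5000, 0.5000])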
sglang/srt/managers/io_struct.py
CHANGED
@@ -20,7 +20,7 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -650,7 +650,7 @@ class ProfileReqInput:
     # If it is set, profiling is automatically stopped after this step, and
     # the caller doesn't need to run stop_profile.
     num_steps: Optional[int] = None
-    activities: Optional[List[str]] = None
+    activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None
 
 
 class ProfileReqType(Enum):
@@ -675,6 +675,8 @@ class ProfileReq:
     output_dir: Optional[str] = None
     num_steps: Optional[int] = None
     activities: Optional[List[str]] = None
+    with_stack: Optional[bool] = None
+    record_shapes: Optional[bool] = None
 
 
 @dataclass
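ProfileReqInput is the request body for the profiling endpoint, so the Literal annotation now spells out the accepted activity names, including the new "CUDA_PROFILER". Assuming the /start_profile route from earlier releases, a client call could look like this sketch:

import requests

# Hypothetical local server; the JSON fields mirror ProfileReqInput.
resp = requests.post(
    "http://127.0.0.1:30000/start_profile",
    json={
        "output_dir": "/tmp/sglang_profile",
        "num_steps": 10,  # auto-stops after 10 steps; no /stop_profile needed
        "activities": ["CPU", "GPU", "MEM"],
    },
)
print(resp.status_code, resp.text)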
sglang/srt/managers/multimodal_processors/clip.py
ADDED
@@ -0,0 +1,63 @@
+import asyncio
+from typing import List, Union
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    get_global_processor,
+)
+from sglang.srt.models.clip import CLIPModel
+from sglang.srt.utils import load_image
+
+
+class ClipImageProcessor(BaseMultimodalProcessor):
+    models = [CLIPModel]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+    @staticmethod
+    def _process_single_image_task(images, input_text):
+        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
+        return get_global_processor()(
+            images=images, text=input_text, return_tensors="pt"
+        )
+
+    async def _process_single_image(self, images, input_text):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            image_inputs = await loop.run_in_executor(
+                self.executor,
+                ClipImageProcessor._process_single_image_task,
+                images,
+                input_text,
+            )
+        else:
+            image_inputs = self._processor(
+                images=images, text=[input_text], return_tensors="pt"
+            )
+
+        return image_inputs
+
+    async def process_mm_data_async(
+        self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
+    ):
+        if not image_data:
+            return None
+
+        if isinstance(input_text, list):
+            assert len(input_text) and isinstance(input_text[0], int)
+            input_text = self._processor.tokenizer.decode(input_text)
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        if len(image_data) > 0:
+            images = [load_image(image)[0] for image in image_data]
+        else:
+            images = load_image(image_data[0])[0]
+
+        image_inputs = await self._process_single_image(images, input_text)
+        image_inputs["data_hashes"] = [hash(str(image_data))]
+        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+
+        return image_inputs
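The wrapped call is a plain Hugging Face processor invocation; outside sglang, the equivalent (using the standard transformers CLIP processor and a hypothetical local image file) is:

from PIL import Image
from transformers import CLIPProcessor

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
image = Image.open("example.jpg")  # hypothetical local file
inputs = processor(images=image, text="a photo of a cat", return_tensors="pt")
print(inputs.keys())  # input_ids, attention_mask, pixel_values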
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -1376,6 +1376,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if (
             global_server_args_dict["enable_flashinfer_mla"]
             or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "fa3"
         ):
             decode_seq_lens = self.seq_lens.cpu()
         else:
sglang/srt/managers/scheduler.py
CHANGED
@@ -379,7 +379,7 @@ class Scheduler(
         # Init profiler
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
-        self.torch_profiler_activities: Optional[List[str]] = None
+        self.profiler_activities: Optional[List[str]] = None
         self.profiler_target_forward_ct: Optional[int] = None
 
         # Init metrics stats
@@ -1186,7 +1186,7 @@ class Scheduler(
         ret = None
 
         # Handle DP attention
-        if self.server_args.enable_dp_attention:
+        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
             ret, _ = self.prepare_dp_attn_batch(ret)
 
         return ret
@@ -1703,18 +1703,12 @@ class Scheduler(
     def save_remote_model(self, params):
         url = params["url"]
 
-        if isinstance(self.tp_worker, TpModelWorkerClient):
-            worker = self.tp_worker.worker
-        else:
-            worker = self.tp_worker
+        worker = self.tp_worker.worker
 
         worker.model_runner.save_remote_model(url)
 
     def save_sharded_model(self, params):
-        if isinstance(self.tp_worker, TpModelWorkerClient):
-            worker = self.tp_worker.worker
-        else:
-            worker = self.tp_worker
+        worker = self.tp_worker.worker
 
         worker.model_runner.save_sharded_model(
             path=params["path"],
@@ -1813,7 +1807,11 @@ class Scheduler(
     def profile(self, recv_req: ProfileReq):
         if recv_req.type == ProfileReqType.START_PROFILE:
             return self.start_profile(
-                recv_req.output_dir, recv_req.num_steps, recv_req.activities
+                recv_req.output_dir,
+                recv_req.num_steps,
+                recv_req.activities,
+                recv_req.with_stack,
+                recv_req.record_shapes,
             )
         else:
             return self.stop_profile()
@@ -1823,8 +1821,10 @@ class Scheduler(
         output_dir: Optional[str],
         num_steps: Optional[int],
         activities: Optional[List[str]],
+        with_stack: Optional[bool],
+        record_shapes: Optional[bool],
     ) -> None:
-        if self.torch_profiler_activities:
+        if self.profiler_activities:
             return ProfileReqOutput(
                 success=False,
                 message="Profiling is already in progress. Call /stop_profile first.",
@@ -1836,7 +1836,7 @@ class Scheduler(
             activities = ["CPU", "GPU"]
 
         self.torch_profiler_output_dir = output_dir
-        self.torch_profiler_activities = activities
+        self.profiler_activities = activities
         logger.info(
             "Profiling starts. Traces will be saved to: %s",
             self.torch_profiler_output_dir,
@@ -1853,13 +1853,17 @@ class Scheduler(
         if torchprof_activities:
             self.torch_profiler = torch.profiler.profile(
                 activities=torchprof_activities,
-                with_stack=True,
+                with_stack=with_stack if with_stack is not None else True,
+                record_shapes=record_shapes if record_shapes is not None else False,
             )
             self.torch_profiler.start()
 
         if "MEM" in activities:
             torch.cuda.memory._record_memory_history(max_entries=100000)
 
+        if "CUDA_PROFILER" in activities:
+            torch.cuda.cudart().cudaProfilerStart()
+
         if num_steps:
             self.profiler_target_forward_ct = self.forward_ct + num_steps
             # The caller will be notified when reaching profiler_target_forward_ct
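The new "CUDA_PROFILER" activity brackets the profiled steps with the CUDA profiler start/stop API, the same hook Nsight Systems uses for --capture-range=cudaProfilerApi. A standalone sketch of that bracket (the matmul is a stand-in workload):

import torch

# Run under: nsys profile --capture-range=cudaProfilerApi python script.py
a = torch.randn(1024, 1024, device="cuda")
torch.cuda.cudart().cudaProfilerStart()  # capture begins here
b = a @ a                                # region of interest
torch.cuda.synchronize()
torch.cuda.cudart().cudaProfilerStop()   # capture ends here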
@@ -1868,7 +1872,7 @@ class Scheduler(
         return ProfileReqOutput(success=True, message="Succeeded")
 
     def stop_profile(self) -> None:
-        if self.torch_profiler_activities is None:
+        if self.profiler_activities is None:
             return
 
         logger.info("Stop profiling...")
@@ -1881,21 +1885,24 @@ class Scheduler(
             )
         )
 
-        if "MEM" in self.torch_profiler_activities:
+        if "MEM" in self.profiler_activities:
             memory_profile_path = os.path.join(
-                self.torch_profiler_trace_dir,
+                self.torch_profiler_output_dir,
                 str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
             )
             torch.cuda.memory._dump_snapshot(memory_profile_path)
             torch.cuda.memory._record_memory_history(enabled=None)
 
+        if "CUDA_PROFILER" in self.profiler_activities:
+            torch.cuda.cudart().cudaProfilerStop()
+
         logger.info(
             "Profiling done. Traces are saved to: %s",
             self.torch_profiler_output_dir,
         )
         self.torch_profiler = None
         self.torch_profiler_output_dir = None
-        self.torch_profiler_activities = None
+        self.profiler_activities = None
 
         if self.profiler_target_forward_ct:
             self.send_to_tokenizer.send_pyobj(
@@ -1963,7 +1970,6 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-
     # Generate the prefix
     if dp_rank is None:
         prefix = f" TP{tp_rank}"
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -261,7 +261,6 @@ class TokenizerManager:
         self.start_profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
sglang/srt/managers/tp_worker.py
CHANGED
@@ -132,6 +132,9 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)
 
+        # A reference make this class has the same member as TpModelWorkerClient
+        self.worker = self
+
     def get_worker_info(self):
         return (
             self.max_total_num_tokens,
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -124,8 +124,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         # capture less.
         capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
 
-    if is_hip_:
-        capture_bs += [i * 8 for i in range(21, 33)]
+    if _is_hip:
+        capture_bs += [i * 8 for i in range(21, 33)]
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
@@ -245,8 +246,8 @@ class CudaGraphRunner:
             )
         else:
             self.encoder_lens = None
-
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_bs * self.dp_size * self.num_tokens_per_bs,
@@ -288,7 +289,7 @@ class CudaGraphRunner:
         self.model_runner.token_to_kv_pool.capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
 
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
@@ -369,7 +370,7 @@ class CudaGraphRunner:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
 
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
@@ -471,7 +472,7 @@ class CudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             index = bisect.bisect_left(
                 self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
@@ -497,7 +498,7 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
 
         if hasattr(forward_batch.spec_info, "hidden_states"):
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -64,6 +64,7 @@ from sglang.srt.model_loader.loader import (
 )
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -229,6 +230,10 @@ class ModelRunner:
         elif server_args.enable_flashmla:
             logger.info("MLA optimization is turned on. Use flashmla decode.")
             server_args.attention_backend = "flashmla"
+        elif server_args.attention_backend == "fa3":
+            logger.info(
+                f"MLA optimization is turned on. Use flash attention 3 backend."
+            )
         else:
             logger.info("MLA optimization is turned on. Use triton backend.")
             server_args.attention_backend = "triton"
@@ -280,9 +285,6 @@ class ModelRunner:
 
         if server_args.enable_deepep_moe:
             logger.info("DeepEP is turned on.")
-            assert (
-                server_args.enable_dp_attention == True
-            ), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -881,7 +883,7 @@ class ModelRunner:
                 "Please use `--attention-backend flashinfer`."
             )
             logger.warning(
-                "FlashAttention v3 Backend is in Beta. Multimodal,
+                "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
             )
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
@@ -1082,8 +1084,9 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
 
 def _unwrap_tensor(tensor, tp_rank):
     if isinstance(tensor, LocalSerializedTensor):
-        return tensor.get(tp_rank)
-    return tensor
+        monkey_patch_torch_reductions()
+        tensor = tensor.get(tp_rank)
+    return tensor.to(torch.cuda.current_device())
 
 
 @dataclass
sglang/srt/model_loader/loader.py
CHANGED
@@ -14,7 +14,6 @@ from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
 
-import gguf
 import huggingface_hub
 import numpy as np
 import torch
@@ -1155,6 +1154,17 @@ class GGUFModelLoader(BaseModelLoader):
         See "Standardized tensor names" in
         https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
         """
+
+        # only load the gguf module when needed
+        try:
+            import gguf
+
+            # FIXME: add version check for gguf
+        except ImportError as err:
+            raise ImportError(
+                "Please install gguf via `pip install gguf` to use gguf quantizer."
+            ) from err
+
         config = model_config.hf_config
         model_type = config.model_type
         # hack: ggufs have a different name than transformers
sglang/srt/model_loader/weight_utils.py
CHANGED
@@ -22,7 +22,6 @@ from typing import (
 )
 
 import filelock
-import gguf
 import huggingface_hub.constants
 import numpy as np
 import safetensors.torch
@@ -93,7 +92,7 @@ def convert_bin_to_safetensor_file(
     pt_filename: str,
     sf_filename: str,
 ) -> None:
-    loaded = torch.load(pt_filename, map_location="cpu")
+    loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
     shared = _shared_pointers(loaded)
@@ -381,7 +380,7 @@ def np_cache_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        state = torch.load(bin_file, map_location="cpu")
+        state = torch.load(bin_file, map_location="cpu", weights_only=True)
         for name, param in state.items():
             param_path = os.path.join(np_folder, name)
             with open(param_path, "wb") as f:
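Both torch.load call sites now pass weights_only=True, which restricts unpickling to tensors and primitive containers so that a crafted checkpoint cannot execute arbitrary code during load. The difference in isolation:

import torch

torch.save({"w": torch.zeros(2, 2)}, "/tmp/ckpt.bin")

# Safe path: only tensor/primitive types may be reconstructed.
state = torch.load("/tmp/ckpt.bin", map_location="cpu", weights_only=True)
print(state["w"].shape)  # torch.Size([2, 2])

# A checkpoint that pickles arbitrary objects would raise an
# UnpicklingError here instead of silently running their code.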
@@ -464,6 +463,8 @@ def pt_weights_iterator(
 def get_gguf_extra_tensor_names(
     gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
 ) -> List[str]:
+    import gguf
+
     reader = gguf.GGUFReader(gguf_file)
     expected_gguf_keys = set(gguf_to_hf_name_map.keys())
     exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
@@ -479,6 +480,8 @@ def gguf_quant_weights_iterator(
     them to torch tensors
     """
 
+    import gguf
+
     reader = gguf.GGUFReader(gguf_file)
 
     for tensor in reader.tensors:
|