sglang 0.4.5.post3__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -9
- sglang/compile_deep_gemm.py +45 -4
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +9 -3
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +59 -11
- sglang/srt/disaggregation/mini_lb.py +45 -8
- sglang/srt/disaggregation/mooncake/conn.py +198 -31
- sglang/srt/disaggregation/prefill.py +24 -9
- sglang/srt/entrypoints/http_server.py +8 -2
- sglang/srt/function_call_parser.py +77 -5
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +28 -10
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/layernorm.py +38 -16
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/deep_gemm.py +17 -10
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +170 -126
- sglang/srt/managers/data_parallel_controller.py +10 -3
- sglang/srt/managers/io_struct.py +7 -0
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +29 -12
- sglang/srt/managers/scheduler.py +31 -20
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +87 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -3
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +11 -24
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +144 -70
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpmo.py +5 -1
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +18 -8
- sglang/srt/server_args.py +15 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +2 -1
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +36 -18
- sglang/version.py +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/METADATA +4 -5
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/RECORD +70 -68
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
@@ -176,17 +176,25 @@ class SchedulerDisaggregationPrefillMixin:
         """
 
     @torch.no_grad()
-    def event_loop_normal_disagg_prefill(self):
+    def event_loop_normal_disagg_prefill(self: Scheduler):
         """A normal scheduler loop for prefill worker in disaggregation mode."""
 
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             self.waiting_queue.extend(
-                self.
+                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+
+            # Handle DP attention
+            if (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            ):
+                batch, _ = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch
 
             if batch:
@@ -206,17 +214,25 @@ class SchedulerDisaggregationPrefillMixin:
             self.running_batch.batch_is_full = False
 
     @torch.no_grad()
-    def event_loop_overlap_disagg_prefill(self):
+    def event_loop_overlap_disagg_prefill(self: Scheduler):
         self.result_queue = deque()
 
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             self.waiting_queue.extend(
-                self.
+                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
            )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+
+            # Handle DP attention
+            if (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            ):
+                batch, _ = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch
 
             if batch:
@@ -310,7 +326,7 @@ class SchedulerDisaggregationPrefillMixin:
                 raise Exception("Transferring failed")
 
         for req in done_reqs:
-            self.
+            self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
                 req.metadata_buffer_index
             )
 
@@ -326,9 +342,8 @@ class SchedulerDisaggregationPrefillMixin:
             # only finished requests to running_batch.
             self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req)
-            if
-
-            ):  # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
+            if self.enable_overlap:
+                # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
                 self.chunked_req.tmp_end_idx = min(
                     len(self.chunked_req.fill_ids),
                     len(self.chunked_req.origin_input_ids),
@@ -374,7 +389,7 @@ class SchedulerDisaggregationPrefillMixin:
                 .numpy()
             )
             if last_chunk is True:
-                self.
+                self.disagg_prefill_bootstrap_queue.store_prefill_results(
                     req.metadata_buffer_index, token_id
                 )
             page_indices = kv_to_page_indices(kv_indices, page_size)
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -84,6 +84,7 @@ from sglang.srt.utils import (
     add_api_key_middleware,
     add_prometheus_middleware,
     delete_directory,
+    get_bool_env_var,
     kill_process_tree,
     set_uvicorn_logging_configs,
 )
@@ -126,7 +127,10 @@ async def lifespan(fast_api_app: FastAPI):
 
 
 # Fast API
-app = FastAPI(
+app = FastAPI(
+    lifespan=lifespan,
+    openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json",
+)
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -277,7 +281,9 @@ async def generate_from_file_request(file: UploadFile, request: Request):
     )
 
     try:
-        ret = await _global_state.generate_request(
+        ret = await _global_state.tokenizer_manager.generate_request(
+            obj, request
+        ).__anext__()
         return ret
     except ValueError as e:
         logger.error(f"Error: {e}")
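
For readers unfamiliar with the `openapi_url` knob: FastAPI removes the auto-generated `/openapi.json`, `/docs`, and `/redoc` endpoints whenever `openapi_url` is `None`, which is what the new `DISABLE_OPENAPI_DOC` check switches on. A minimal standalone sketch of the same pattern (the env-var parsing below is a simplified stand-in for sglang's `get_bool_env_var`, not the actual helper):

```python
import os

from fastapi import FastAPI


def _env_flag(name: str) -> bool:
    # Simplified stand-in for sglang.srt.utils.get_bool_env_var.
    return os.getenv(name, "false").lower() in ("1", "true", "yes")


# With DISABLE_OPENAPI_DOC set, openapi_url=None drops /openapi.json,
# /docs, and /redoc from the app; otherwise the default docs stay available.
app = FastAPI(
    openapi_url=None if _env_flag("DISABLE_OPENAPI_DOC") else "/openapi.json",
)


@app.get("/health")
def health() -> dict:
    return {"status": "ok"}
```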
sglang/srt/function_call_parser.py
CHANGED
@@ -491,6 +491,7 @@ class DeepSeekV3Detector(BaseFormatDetector):
         self.eot_token = "<|tool▁calls▁end|>"
         self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
         self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>"
+        self._last_arguments = ""
 
     def has_tool_call(self, text: str) -> bool:
         """Check if the text contains a deepseek format tool call."""
@@ -528,13 +529,84 @@ class DeepSeekV3Detector(BaseFormatDetector):
 
     def structure_info(self) -> _GetInfoFunc:
         return lambda name: StructureInfo(
-            begin="
-
-            + "\n```json\n",
-            end="\n```<|tool▁call▁end|><|tool▁calls▁end|>",
-            trigger="<|tool▁calls▁begin|>",
+            begin=">" + name + "\n```json\n",
+            end="\n```<",
+            trigger=">" + name + "\n```json\n",
         )
 
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing tool calls for DeepSeekV3 format.
+        """
+        self._buffer += new_text
+        current_text = self._buffer
+
+        if self.bot_token not in current_text:
+            self._buffer = ""
+            for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]:
+                if e_token in new_text:
+                    new_text = new_text.replace(e_token, "")
+            return StreamingParseResult(normal_text=new_text)
+
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function and tool.function.name
+            }
+
+        calls: list[ToolCallItem] = []
+        try:
+            partial_match = re.search(
+                pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)",
+                string=current_text,
+                flags=re.DOTALL,
+            )
+            if partial_match:
+                func_name = partial_match.group(2).strip()
+                func_args_raw = partial_match.group(3).strip()
+
+                if not self.current_tool_name_sent:
+                    calls.append(
+                        ToolCallItem(
+                            tool_index=self._tool_indices.get(func_name, 0),
+                            name=func_name,
+                            parameters="",
+                        )
+                    )
+                    self.current_tool_name_sent = True
+                else:
+                    argument_diff = (
+                        func_args_raw[len(self._last_arguments) :]
+                        if func_args_raw.startswith(self._last_arguments)
+                        else func_args_raw
+                    )
+
+                    if argument_diff:
+                        calls.append(
+                            ToolCallItem(
+                                tool_index=self._tool_indices.get(func_name, 0),
+                                name=None,
+                                parameters=argument_diff,
+                            )
+                        )
+                        self._last_arguments += argument_diff
+
+                    if _is_complete_json(func_args_raw):
+                        result = StreamingParseResult(normal_text="", calls=calls)
+                        self._buffer = ""
+                        self._last_arguments = ""
+                        self.current_tool_name_sent = False
+                        return result
+
+            return StreamingParseResult(normal_text="", calls=calls)
+
+        except Exception as e:
+            logger.error(f"Error in parse_streaming_increment: {e}")
+            return StreamingParseResult(normal_text=current_text)
+
 
 class MultiFormatParser:
     def __init__(self, detectors: List[BaseFormatDetector]):
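
The heart of the new streaming path is the argument-diff bookkeeping: the detector remembers what it has already emitted in `_last_arguments` and streams only the new suffix on each chunk. A self-contained sketch of just that step (plain strings, detached from the detector and `ToolCallItem` types above):

```python
def stream_argument_diff(last_sent: str, args_so_far: str) -> tuple[str, str]:
    """Return (suffix_to_emit, updated_last_sent).

    If the accumulated arguments no longer extend what was already sent
    (for example after the buffer is reset), fall back to emitting the
    full string, mirroring parse_streaming_increment above.
    """
    if args_so_far.startswith(last_sent):
        diff = args_so_far[len(last_sent):]
    else:
        diff = args_so_far
    return diff, last_sent + diff


# The model streams the JSON arguments in three snapshots.
sent = ""
for snapshot in ['{"city": "Par', '{"city": "Paris"', '{"city": "Paris"}']:
    diff, sent = stream_argument_diff(sent, snapshot)
    print(repr(diff))
# '{"city": "Par'
# 'is"'
# '}'
```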
sglang/srt/layers/attention/base_attn_backend.py
CHANGED
@@ -62,6 +62,7 @@ class AttentionBackend(ABC):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
@@ -72,6 +73,7 @@ class AttentionBackend(ABC):
                 layer,
                 forward_batch,
                 save_kv_cache=save_kv_cache,
+                **kwargs,
             )
         else:
             return self.forward_extend(
@@ -81,6 +83,7 @@ class AttentionBackend(ABC):
                 layer,
                 forward_batch,
                 save_kv_cache=save_kv_cache,
+                **kwargs,
             )
 
     def forward_decode(
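
The point of the bare `**kwargs` is extensibility: the base dispatcher can forward backend-specific keyword arguments (such as the `q_rope`/`k_rope` tensors the FlashAttention backend gains below) without touching every other backend's signature. A toy illustration of that dispatch shape (illustrative classes, not the sglang ones):

```python
from typing import Optional


class ToyAttentionBackend:
    """Extra kwargs pass straight through forward(), so subclasses may accept
    parameters the base class has never heard of."""

    def forward(self, q: str, is_decode: bool, **kwargs) -> str:
        if is_decode:
            return self.forward_decode(q, **kwargs)
        return self.forward_extend(q, **kwargs)

    def forward_decode(self, q: str, **kwargs) -> str:
        raise NotImplementedError()

    def forward_extend(self, q: str, **kwargs) -> str:
        raise NotImplementedError()


class ToyMlaBackend(ToyAttentionBackend):
    # Accepts an extra q_rope argument that other backends simply ignore.
    def forward_decode(self, q: str, q_rope: Optional[str] = None, **kwargs) -> str:
        return f"decode(q={q}, q_rope={q_rope})"

    def forward_extend(self, q: str, q_rope: Optional[str] = None, **kwargs) -> str:
        return f"extend(q={q}, q_rope={q_rope})"


print(ToyMlaBackend().forward("q", is_decode=True, q_rope="q_pe"))
# decode(q=q, q_rope=q_pe)
```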
sglang/srt/layers/attention/flashattention_backend.py
CHANGED
@@ -623,6 +623,9 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -637,11 +640,11 @@ class FlashAttentionBackend(AttentionBackend):
                     layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
             else:
-                forward_batch.token_to_kv_pool.
+                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                     layer,
                     cache_loc,
                     k,
-
+                    k_rope,
                 )
 
         # Use precomputed metadata across all layers
@@ -815,9 +818,15 @@ class FlashAttentionBackend(AttentionBackend):
             c_kv_cache = c_kv.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
-
-
-
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
 
             result = flash_attn_with_kvcache(
                 q=q_rope,
@@ -877,6 +886,9 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -891,11 +903,11 @@ class FlashAttentionBackend(AttentionBackend):
                     layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
             else:
-                forward_batch.token_to_kv_pool.
+                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                     layer,
                     cache_loc,
                     k,
-
+                    k_rope,
                 )
 
         # Use precomputed metadata across all layers
@@ -1047,9 +1059,15 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
 
-
-
-
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
             max_seqlen_q = metadata.max_seq_len_q
 
             result = flash_attn_with_kvcache(
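
When no separate rope tensor is supplied, the backend splits the packed query along its last dimension: the first `v_head_dim` channels become `q_nope` and the remaining `head_dim - v_head_dim` channels become `q_rope`. A shape-only sketch of that slicing with made-up sizes (the 512 + 64 split is just an illustrative DeepSeek-style layout):

```python
import torch

num_tokens, tp_q_head_num = 4, 8
v_head_dim, qk_rope_dim = 512, 64
head_dim = v_head_dim + qk_rope_dim  # 576

# Packed query: "nope" and rope channels concatenated on the last dim.
q_all = torch.randn(num_tokens, tp_q_head_num, head_dim)

q_nope = q_all[:, :, :v_head_dim]   # (4, 8, 512)
q_rope = q_all[:, :, v_head_dim:]   # (4, 8, 64)

assert q_nope.shape == (num_tokens, tp_q_head_num, v_head_dim)
assert q_rope.shape == (num_tokens, tp_q_head_num, qk_rope_dim)
```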
sglang/srt/layers/attention/flashmla_backend.py
CHANGED
@@ -68,9 +68,6 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         self.num_q_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
-        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.num_local_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -111,8 +108,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         )
         mla_metadata, num_splits = get_mla_metadata(
             forward_batch.seq_lens.to(torch.int32),
-            Q_LEN * self.num_q_heads
-
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.forward_metadata = FlashMLADecodeMetadata(
             mla_metadata,
@@ -141,8 +138,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
 
         self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
             torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
-            Q_LEN * self.num_q_heads
-
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_kv_indices = cuda_graph_kv_indices
 
@@ -171,8 +168,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         )
         mla_metadata, num_splits = get_mla_metadata(
             seq_lens.to(torch.int32),
-            Q_LEN * self.num_q_heads
-
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_mla_metadata.copy_(mla_metadata)
         self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -221,8 +218,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         )
         mla_metadata, num_splits = get_mla_metadata(
             seq_lens.to(torch.int32),
-            Q_LEN * self.num_q_heads
-
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_mla_metadata.copy_(mla_metadata)
         self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
sglang/srt/layers/layernorm.py
CHANGED
@@ -22,8 +22,6 @@ import torch.nn as nn
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda, is_hip
 
-logger = logging.getLogger(__name__)
-
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 
@@ -36,19 +34,9 @@ if _is_cuda:
     )
 
 if _is_hip:
+    from vllm._custom_ops import fused_add_rms_norm, rms_norm
 
-
-
-    rmsnorm = rms_norm
-
-    def fused_add_rmsnorm(
-        x: torch.Tensor,
-        residual: torch.Tensor,
-        w: torch.Tensor,
-        eps: float,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps)
-        return x, residual
+logger = logging.getLogger(__name__)
 
 
 class RMSNorm(CustomOp):
@@ -61,23 +49,49 @@ class RMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        elif _is_hip:
+            return self.forward_hip(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x, residual
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_hip(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            # NOTE: Romove this if aiter kernel supports discontinuous input
+            x = x.contiguous()
+        if residual is not None:
+            fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = torch.empty_like(x)
+        rms_norm(out, x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            x = x.contiguous()
         orig_dtype = x.dtype
         x = x.to(torch.float32)
         if residual is not None:
@@ -103,6 +117,14 @@ class GemmaRMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_native(
         self,
         x: torch.Tensor,
@@ -156,6 +178,6 @@ class Gemma3RMSNorm(nn.Module):
 
 if not (_is_cuda or _is_hip):
     logger.info(
-        "sgl-kernel is not available on
+        "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
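
The new `forward` override is a small dispatcher: while `torch.compile` is tracing it stays on `forward_native` (presumably so the compiler sees plain PyTorch ops rather than an opaque custom kernel), and otherwise it routes to the CUDA or HIP implementation. For reference, the math those kernels accelerate is ordinary RMS normalization; a minimal standalone version of the residual-free path (a sketch, not the exact sgl-kernel numerics):

```python
import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize in float32 for stability, then cast back, as forward_native does.
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = (x * torch.rsqrt(variance + eps)).to(orig_dtype)
    return x * weight


hidden = torch.randn(2, 4096, dtype=torch.float16)
weight = torch.ones(4096, dtype=torch.float16)
print(rms_norm_reference(hidden, weight).shape)  # torch.Size([2, 4096])
```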
sglang/srt/layers/logits_processor.py
CHANGED
@@ -335,13 +335,13 @@ class LogitsProcessor(nn.Module):
             aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
             hidden_states_to_store = (
                 aux_pruned_states[sample_indices]
-                if sample_indices
+                if sample_indices is not None
                 else aux_pruned_states
             )
         else:
             hidden_states_to_store = (
                 pruned_states[sample_indices]
-                if sample_indices
+                if sample_indices is not None
                 else pruned_states
             )
     else:
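
The switch to `is not None` matters because `sample_indices` is used as an index tensor here: truthiness of a multi-element tensor raises at runtime, and a single-element tensor holding index 0 is falsy, so the old check could silently skip the selection. A small reproduction of both pitfalls:

```python
import torch

pruned_states = torch.randn(5, 8)

# Multi-element index tensor: plain truthiness is ambiguous and raises.
sample_indices = torch.tensor([0, 3])
try:
    _ = pruned_states[sample_indices] if sample_indices else pruned_states
except RuntimeError as err:
    print("ambiguous truthiness:", err)

# Single index 0: the tensor is falsy, so the old check fell back to the
# unindexed states instead of selecting row 0.
sample_indices = torch.tensor([0])
old_behavior = pruned_states[sample_indices] if sample_indices else pruned_states
new_behavior = pruned_states[sample_indices] if sample_indices is not None else pruned_states
print(old_behavior.shape, new_behavior.shape)  # torch.Size([5, 8]) torch.Size([1, 8])
```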
sglang/srt/layers/moe/fused_moe_native.py
CHANGED
@@ -8,6 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F
 
+from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts
 
 
@@ -30,7 +31,7 @@ def fused_moe_forward_native(
 ) -> torch.Tensor:
 
     if apply_router_weight_on_input:
-        raise NotImplementedError
+        raise NotImplementedError()
 
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
@@ -75,9 +76,6 @@ def moe_forward_native(
     activation: str = "silu",
     routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
-
-    from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
-
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,