sglang 0.4.5.post3__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
Files changed (70)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -9
  3. sglang/compile_deep_gemm.py +45 -4
  4. sglang/srt/code_completion_parser.py +1 -1
  5. sglang/srt/configs/deepseekvl2.py +1 -1
  6. sglang/srt/configs/model_config.py +9 -3
  7. sglang/srt/constrained/llguidance_backend.py +78 -61
  8. sglang/srt/conversation.py +34 -1
  9. sglang/srt/disaggregation/decode.py +59 -11
  10. sglang/srt/disaggregation/mini_lb.py +45 -8
  11. sglang/srt/disaggregation/mooncake/conn.py +198 -31
  12. sglang/srt/disaggregation/prefill.py +24 -9
  13. sglang/srt/entrypoints/http_server.py +8 -2
  14. sglang/srt/function_call_parser.py +77 -5
  15. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  16. sglang/srt/layers/attention/flashattention_backend.py +28 -10
  17. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  18. sglang/srt/layers/attention/vision.py +2 -0
  19. sglang/srt/layers/layernorm.py +38 -16
  20. sglang/srt/layers/logits_processor.py +2 -2
  21. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  25. sglang/srt/layers/pooler.py +6 -0
  26. sglang/srt/layers/quantization/awq.py +5 -1
  27. sglang/srt/layers/quantization/deep_gemm.py +17 -10
  28. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  29. sglang/srt/layers/radix_attention.py +13 -3
  30. sglang/srt/layers/rotary_embedding.py +170 -126
  31. sglang/srt/managers/data_parallel_controller.py +10 -3
  32. sglang/srt/managers/io_struct.py +7 -0
  33. sglang/srt/managers/mm_utils.py +85 -28
  34. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  35. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  36. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  37. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  38. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  40. sglang/srt/managers/schedule_batch.py +29 -12
  41. sglang/srt/managers/scheduler.py +31 -20
  42. sglang/srt/managers/tokenizer_manager.py +5 -1
  43. sglang/srt/mem_cache/memory_pool.py +87 -0
  44. sglang/srt/model_executor/cuda_graph_runner.py +4 -3
  45. sglang/srt/model_executor/forward_batch_info.py +51 -95
  46. sglang/srt/model_executor/model_runner.py +11 -24
  47. sglang/srt/models/deepseek.py +12 -2
  48. sglang/srt/models/deepseek_nextn.py +101 -6
  49. sglang/srt/models/deepseek_v2.py +144 -70
  50. sglang/srt/models/deepseek_vl2.py +9 -4
  51. sglang/srt/models/gemma3_causal.py +1 -1
  52. sglang/srt/models/llama4.py +0 -1
  53. sglang/srt/models/minicpmo.py +5 -1
  54. sglang/srt/models/mllama4.py +2 -2
  55. sglang/srt/models/qwen2_5_vl.py +3 -6
  56. sglang/srt/models/qwen2_vl.py +3 -7
  57. sglang/srt/models/roberta.py +178 -0
  58. sglang/srt/openai_api/adapter.py +18 -8
  59. sglang/srt/server_args.py +15 -22
  60. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  61. sglang/srt/torch_memory_saver_adapter.py +10 -1
  62. sglang/srt/utils.py +2 -1
  63. sglang/test/runners.py +6 -13
  64. sglang/test/test_utils.py +36 -18
  65. sglang/version.py +1 -1
  66. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/METADATA +4 -5
  67. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/RECORD +70 -68
  68. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  69. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  70. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0

sglang/srt/disaggregation/prefill.py

@@ -176,17 +176,25 @@ class SchedulerDisaggregationPrefillMixin:
     """
 
     @torch.no_grad()
-    def event_loop_normal_disagg_prefill(self):
+    def event_loop_normal_disagg_prefill(self: Scheduler):
         """A normal scheduler loop for prefill worker in disaggregation mode."""
 
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             self.waiting_queue.extend(
-                self.disagg_prefill_pending_queue.pop_bootstrapped()
+                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+
+            # Handle DP attention
+            if (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            ):
+                batch, _ = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch
 
             if batch:
@@ -206,17 +214,25 @@ class SchedulerDisaggregationPrefillMixin:
             self.running_batch.batch_is_full = False
 
     @torch.no_grad()
-    def event_loop_overlap_disagg_prefill(self):
+    def event_loop_overlap_disagg_prefill(self: Scheduler):
         self.result_queue = deque()
 
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             self.waiting_queue.extend(
-                self.disagg_prefill_pending_queue.pop_bootstrapped()
+                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+
+            # Handle DP attention
+            if (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            ):
+                batch, _ = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch
 
             if batch:
@@ -310,7 +326,7 @@ class SchedulerDisaggregationPrefillMixin:
                 raise Exception("Transferring failed")
 
         for req in done_reqs:
-            self.disagg_prefill_pending_queue.req_to_metadata_buffer_idx_allocator.free(
+            self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
                 req.metadata_buffer_index
            )
 
@@ -326,9 +342,8 @@ class SchedulerDisaggregationPrefillMixin:
                 # only finished requests to running_batch.
                 self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
                 self.tree_cache.cache_unfinished_req(self.chunked_req)
-                if (
-                    self.enable_overlap
-                ):  # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
+                if self.enable_overlap:
+                    # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
                     self.chunked_req.tmp_end_idx = min(
                         len(self.chunked_req.fill_ids),
                         len(self.chunked_req.origin_input_ids),
@@ -374,7 +389,7 @@ class SchedulerDisaggregationPrefillMixin:
            .numpy()
        )
        if last_chunk is True:
-           self.disagg_prefill_pending_queue.store_prefill_results(
+           self.disagg_prefill_bootstrap_queue.store_prefill_results(
                req.metadata_buffer_index, token_id
            )
        page_indices = kv_to_page_indices(kv_indices, page_size)
sglang/srt/entrypoints/http_server.py

@@ -84,6 +84,7 @@ from sglang.srt.utils import (
     add_api_key_middleware,
     add_prometheus_middleware,
     delete_directory,
+    get_bool_env_var,
     kill_process_tree,
     set_uvicorn_logging_configs,
 )
@@ -126,7 +127,10 @@ async def lifespan(fast_api_app: FastAPI):
 
 
 # Fast API
-app = FastAPI(lifespan=lifespan)
+app = FastAPI(
+    lifespan=lifespan,
+    openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json",
+)
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -277,7 +281,9 @@ async def generate_from_file_request(file: UploadFile, request: Request):
     )
 
     try:
-        ret = await _global_state.generate_request(obj, request).__anext__()
+        ret = await _global_state.tokenizer_manager.generate_request(
+            obj, request
+        ).__anext__()
         return ret
     except ValueError as e:
         logger.error(f"Error: {e}")
sglang/srt/function_call_parser.py

@@ -491,6 +491,7 @@ class DeepSeekV3Detector(BaseFormatDetector):
         self.eot_token = "<|tool▁calls▁end|>"
         self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
         self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>"
+        self._last_arguments = ""
 
     def has_tool_call(self, text: str) -> bool:
         """Check if the text contains a deepseek format tool call."""
@@ -528,13 +529,84 @@ class DeepSeekV3Detector(BaseFormatDetector):
 
     def structure_info(self) -> _GetInfoFunc:
         return lambda name: StructureInfo(
-            begin="<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>"
-            + name
-            + "\n```json\n",
-            end="\n```<|tool▁call▁end|><|tool▁calls▁end|>",
-            trigger="<|tool▁calls▁begin|>",
+            begin=">" + name + "\n```json\n",
+            end="\n```<",
+            trigger=">" + name + "\n```json\n",
         )
 
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing tool calls for DeepSeekV3 format.
+        """
+        self._buffer += new_text
+        current_text = self._buffer
+
+        if self.bot_token not in current_text:
+            self._buffer = ""
+            for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]:
+                if e_token in new_text:
+                    new_text = new_text.replace(e_token, "")
+            return StreamingParseResult(normal_text=new_text)
+
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function and tool.function.name
+            }
+
+        calls: list[ToolCallItem] = []
+        try:
+            partial_match = re.search(
+                pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)",
+                string=current_text,
+                flags=re.DOTALL,
+            )
+            if partial_match:
+                func_name = partial_match.group(2).strip()
+                func_args_raw = partial_match.group(3).strip()
+
+                if not self.current_tool_name_sent:
+                    calls.append(
+                        ToolCallItem(
+                            tool_index=self._tool_indices.get(func_name, 0),
+                            name=func_name,
+                            parameters="",
+                        )
+                    )
+                    self.current_tool_name_sent = True
+                else:
+                    argument_diff = (
+                        func_args_raw[len(self._last_arguments) :]
+                        if func_args_raw.startswith(self._last_arguments)
+                        else func_args_raw
+                    )
+
+                    if argument_diff:
+                        calls.append(
+                            ToolCallItem(
+                                tool_index=self._tool_indices.get(func_name, 0),
+                                name=None,
+                                parameters=argument_diff,
+                            )
+                        )
+                        self._last_arguments += argument_diff
+
+                    if _is_complete_json(func_args_raw):
+                        result = StreamingParseResult(normal_text="", calls=calls)
+                        self._buffer = ""
+                        self._last_arguments = ""
+                        self.current_tool_name_sent = False
+                        return result
+
+            return StreamingParseResult(normal_text="", calls=calls)
+
+        except Exception as e:
+            logger.error(f"Error in parse_streaming_increment: {e}")
+            return StreamingParseResult(normal_text=current_text)
+
 
 class MultiFormatParser:
     def __init__(self, detectors: List[BaseFormatDetector]):
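
The streaming path above emits only the not-yet-sent suffix of the JSON arguments, tracked in self._last_arguments. A self-contained sketch of that diffing step (illustrative names, not sglang API):

def argument_delta(previous: str, current: str) -> str:
    # Send only the new suffix; fall back to the full string if the
    # partial JSON was rewritten, mirroring the detector's diff logic.
    return current[len(previous):] if current.startswith(previous) else current


streamed = ""
for partial_args in ['{"city": "Par', '{"city": "Paris"', '{"city": "Paris", "unit": "C"}']:
    delta = argument_delta(streamed, partial_args)
    streamed += delta
    print(repr(delta))  # '{"city": "Par'  ->  'is"'  ->  ', "unit": "C"}'
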
sglang/srt/layers/attention/base_attn_backend.py

@@ -62,6 +62,7 @@ class AttentionBackend(ABC):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
@@ -72,6 +73,7 @@ class AttentionBackend(ABC):
                 layer,
                 forward_batch,
                 save_kv_cache=save_kv_cache,
+                **kwargs,
             )
         else:
             return self.forward_extend(
@@ -81,6 +83,7 @@ class AttentionBackend(ABC):
                 layer,
                 forward_batch,
                 save_kv_cache=save_kv_cache,
+                **kwargs,
             )
 
     def forward_decode(
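
Accepting and forwarding **kwargs here is what lets a concrete backend add extra, backend-specific tensors (such as the q_rope/k_rope arguments introduced below) without touching this dispatcher or its callers. A stripped-down sketch of the pattern with a hypothetical backend class:

from typing import Optional

import torch


class TinyAttnBackend:
    # Hypothetical mini-backend; only the kwargs plumbing mirrors the diff above.
    def forward(self, q: torch.Tensor, is_decode: bool, **kwargs) -> torch.Tensor:
        if is_decode:
            return self.forward_decode(q, **kwargs)
        return self.forward_extend(q, **kwargs)

    def forward_extend(
        self, q: torch.Tensor, q_rope: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Only the concrete backend interprets the extra tensor.
        return q if q_rope is None else torch.cat([q, q_rope], dim=-1)

    def forward_decode(
        self, q: torch.Tensor, q_rope: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return self.forward_extend(q, q_rope=q_rope)


out = TinyAttnBackend().forward(
    torch.randn(2, 512), is_decode=False, q_rope=torch.randn(2, 64)
)
print(out.shape)  # torch.Size([2, 576])
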
sglang/srt/layers/attention/flashattention_backend.py

@@ -623,6 +623,9 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -637,11 +640,11 @@ class FlashAttentionBackend(AttentionBackend):
                         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
                 else:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                         layer,
                         cache_loc,
                         k,
-                        v,
+                        k_rope,
                     )
 
         # Use precomputed metadata across all layers
@@ -815,9 +818,15 @@ class FlashAttentionBackend(AttentionBackend):
             c_kv_cache = c_kv.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
-            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            q_nope = q_all[:, :, : layer.v_head_dim]
-            q_rope = q_all[:, :, layer.v_head_dim :]
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
 
             result = flash_attn_with_kvcache(
                 q=q_rope,
@@ -877,6 +886,9 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -891,11 +903,11 @@ class FlashAttentionBackend(AttentionBackend):
                         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
                 else:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                         layer,
                         cache_loc,
                         k,
-                        v,
+                        k_rope,
                     )
 
         # Use precomputed metadata across all layers
@@ -1047,9 +1059,15 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
 
-            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            q_nope = q_all[:, :, : layer.v_head_dim]
-            q_rope = q_all[:, :, layer.v_head_dim :]
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
             max_seqlen_q = metadata.max_seq_len_q
 
             result = flash_attn_with_kvcache(
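
When no separate q_rope is passed, the packed query is split along the head dimension: the first v_head_dim channels become q_nope and the remainder carries the rotary positions. A shape-only sketch with illustrative DeepSeek-like sizes (576 = 512 + 64 is an assumption for the example, not read from any config):

import torch

tp_q_head_num, head_dim, v_head_dim = 16, 576, 512  # illustrative MLA sizes
q_all = torch.randn(8, tp_q_head_num, head_dim)

q_nope = q_all[:, :, :v_head_dim]   # latent ("nope") part
q_rope = q_all[:, :, v_head_dim:]   # rotary part handed to the kernel separately

print(q_nope.shape, q_rope.shape)   # torch.Size([8, 16, 512]) torch.Size([8, 16, 64])
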
sglang/srt/layers/attention/flashmla_backend.py

@@ -68,9 +68,6 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         self.num_q_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
-        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.num_local_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -111,8 +108,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             )
             mla_metadata, num_splits = get_mla_metadata(
                 forward_batch.seq_lens.to(torch.int32),
-                Q_LEN * self.num_q_heads // self.num_kv_heads,
-                self.num_kv_heads,
+                Q_LEN * self.num_q_heads,
+                1,
             )
             self.forward_metadata = FlashMLADecodeMetadata(
                 mla_metadata,
@@ -141,8 +138,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
 
         self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
             torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_kv_indices = cuda_graph_kv_indices
 
@@ -171,8 +168,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             )
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                Q_LEN * self.num_q_heads // self.num_kv_heads,
-                self.num_kv_heads,
+                Q_LEN * self.num_q_heads,
+                1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -221,8 +218,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             )
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                Q_LEN * self.num_q_heads // self.num_kv_heads,
-                self.num_kv_heads,
+                Q_LEN * self.num_q_heads,
+                1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
sglang/srt/layers/attention/vision.py

@@ -271,6 +271,8 @@ class VisionSdpaAttention(nn.Module):
         Returns:
             [b * s, h, head_size]
         """
+        if self.flatten_batch:
+            assert bsz == 1, "flatten_batch is True, bsz must be 1"
 
         s = q.shape[0] // bsz
 
sglang/srt/layers/layernorm.py

@@ -22,8 +22,6 @@ import torch.nn as nn
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda, is_hip
 
-logger = logging.getLogger(__name__)
-
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 
@@ -36,19 +34,9 @@ if _is_cuda:
     )
 
 if _is_hip:
+    from vllm._custom_ops import fused_add_rms_norm, rms_norm
 
-    from aiter.ops.rmsnorm import rms_norm, rmsnorm2d_fwd_with_add
-
-    rmsnorm = rms_norm
-
-    def fused_add_rmsnorm(
-        x: torch.Tensor,
-        residual: torch.Tensor,
-        w: torch.Tensor,
-        eps: float,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps)
-        return x, residual
+logger = logging.getLogger(__name__)
 
 
 class RMSNorm(CustomOp):
@@ -61,23 +49,49 @@ class RMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        elif _is_hip:
+            return self.forward_hip(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x, residual
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_hip(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            # NOTE: Romove this if aiter kernel supports discontinuous input
+            x = x.contiguous()
+        if residual is not None:
+            fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = torch.empty_like(x)
+        rms_norm(out, x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            x = x.contiguous()
         orig_dtype = x.dtype
         x = x.to(torch.float32)
         if residual is not None:
@@ -103,6 +117,14 @@ class GemmaRMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_native(
         self,
         x: torch.Tensor,
@@ -156,6 +178,6 @@ class Gemma3RMSNorm(nn.Module):
 
 if not (_is_cuda or _is_hip):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
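
The new forward wrapper routes to forward_native under torch.compile and to the platform kernels otherwise. For reference, the native path is plain RMSNorm math; a minimal standalone version under that assumption (not the sglang class itself):

from typing import Optional, Tuple, Union

import torch


def rms_norm_native(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)
        residual = x.to(orig_dtype)
    # Scale by the reciprocal root mean square over the hidden dimension.
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    x = (x * weight.float()).to(orig_dtype)
    return x if residual is None else (x, residual)
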
sglang/srt/layers/logits_processor.py

@@ -335,13 +335,13 @@ class LogitsProcessor(nn.Module):
                 aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
                 hidden_states_to_store = (
                     aux_pruned_states[sample_indices]
-                    if sample_indices
+                    if sample_indices is not None
                     else aux_pruned_states
                 )
             else:
                 hidden_states_to_store = (
                     pruned_states[sample_indices]
-                    if sample_indices
+                    if sample_indices is not None
                     else pruned_states
                 )
         else:
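
The is not None fix matters because sample_indices is an index tensor (or None), and tensor truthiness does not mean "is present": a single zero index evaluates to False and a multi-element tensor raises. A quick illustration:

import torch

sample_indices = torch.tensor([0])  # one request, sampling position 0
print(bool(sample_indices))         # False -> the old `if sample_indices` skipped the indexing

try:
    bool(torch.tensor([0, 2]))      # more than one element
except RuntimeError as err:
    print(err)                      # "Boolean value of Tensor with more than one element is ambiguous"
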
sglang/srt/layers/moe/fused_moe_native.py

@@ -8,6 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F
 
+from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts
 
 
@@ -30,7 +31,7 @@ def fused_moe_forward_native(
 ) -> torch.Tensor:
 
     if apply_router_weight_on_input:
-        raise NotImplementedError
+        raise NotImplementedError()
 
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
@@ -75,9 +76,6 @@ def moe_forward_native(
     activation: str = "silu",
     routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
-
-    from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
-
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,