sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/configs/internvl.py +3 -0
  3. sglang/srt/configs/model_config.py +7 -0
  4. sglang/srt/constrained/base_grammar_backend.py +10 -2
  5. sglang/srt/constrained/xgrammar_backend.py +7 -5
  6. sglang/srt/conversation.py +16 -1
  7. sglang/srt/debug_utils/__init__.py +0 -0
  8. sglang/srt/debug_utils/dump_comparator.py +131 -0
  9. sglang/srt/debug_utils/dumper.py +108 -0
  10. sglang/srt/debug_utils/text_comparator.py +172 -0
  11. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  12. sglang/srt/disaggregation/mooncake/conn.py +16 -0
  13. sglang/srt/disaggregation/prefill.py +13 -1
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +13 -1
  16. sglang/srt/entrypoints/openai/protocol.py +3 -1
  17. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  18. sglang/srt/entrypoints/openai/serving_chat.py +132 -79
  19. sglang/srt/function_call/ebnf_composer.py +10 -3
  20. sglang/srt/function_call/function_call_parser.py +2 -0
  21. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  22. sglang/srt/function_call/qwen3_coder_detector.py +1 -0
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  24. sglang/srt/layers/attention/vision.py +56 -8
  25. sglang/srt/layers/layernorm.py +26 -1
  26. sglang/srt/layers/logits_processor.py +14 -3
  27. sglang/srt/layers/moe/ep_moe/layer.py +323 -242
  28. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  33. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  34. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  35. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  36. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  37. sglang/srt/layers/moe/topk.py +90 -24
  38. sglang/srt/layers/multimodal.py +11 -8
  39. sglang/srt/layers/quantization/fp8.py +25 -247
  40. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  41. sglang/srt/layers/quantization/modelopt_quant.py +27 -10
  42. sglang/srt/layers/quantization/unquant.py +24 -76
  43. sglang/srt/layers/quantization/w4afp8.py +68 -17
  44. sglang/srt/lora/lora_registry.py +93 -29
  45. sglang/srt/managers/cache_controller.py +9 -7
  46. sglang/srt/managers/data_parallel_controller.py +4 -0
  47. sglang/srt/managers/io_struct.py +12 -0
  48. sglang/srt/managers/mm_utils.py +154 -35
  49. sglang/srt/managers/multimodal_processor.py +3 -14
  50. sglang/srt/managers/schedule_batch.py +14 -8
  51. sglang/srt/managers/scheduler.py +64 -1
  52. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  53. sglang/srt/managers/tokenizer_manager.py +80 -15
  54. sglang/srt/managers/tp_worker.py +8 -0
  55. sglang/srt/mem_cache/hiradix_cache.py +5 -2
  56. sglang/srt/model_executor/model_runner.py +83 -27
  57. sglang/srt/models/deepseek_v2.py +75 -84
  58. sglang/srt/models/glm4_moe.py +1035 -0
  59. sglang/srt/models/glm4_moe_nextn.py +167 -0
  60. sglang/srt/models/interns1.py +328 -0
  61. sglang/srt/models/internvl.py +143 -47
  62. sglang/srt/models/llava.py +9 -5
  63. sglang/srt/models/minicpmo.py +4 -1
  64. sglang/srt/models/qwen2_moe.py +2 -2
  65. sglang/srt/models/qwen3_moe.py +17 -71
  66. sglang/srt/multimodal/processors/base_processor.py +20 -6
  67. sglang/srt/multimodal/processors/clip.py +2 -2
  68. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  69. sglang/srt/multimodal/processors/gemma3.py +2 -2
  70. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  71. sglang/srt/multimodal/processors/internvl.py +21 -8
  72. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  73. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  74. sglang/srt/multimodal/processors/llava.py +4 -4
  75. sglang/srt/multimodal/processors/minicpm.py +2 -3
  76. sglang/srt/multimodal/processors/mlama.py +2 -2
  77. sglang/srt/multimodal/processors/mllama4.py +18 -111
  78. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  79. sglang/srt/multimodal/processors/pixtral.py +2 -2
  80. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  81. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  82. sglang/srt/multimodal/processors/vila.py +3 -1
  83. sglang/srt/poll_based_barrier.py +31 -0
  84. sglang/srt/reasoning_parser.py +2 -1
  85. sglang/srt/server_args.py +65 -6
  86. sglang/srt/two_batch_overlap.py +8 -3
  87. sglang/srt/utils.py +96 -1
  88. sglang/srt/weight_sync/utils.py +119 -0
  89. sglang/test/runners.py +4 -0
  90. sglang/test/test_utils.py +118 -5
  91. sglang/utils.py +19 -0
  92. sglang/version.py +1 -1
  93. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
  94. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
  95. sglang/srt/debug_utils.py +0 -74
  96. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
  97. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
  98. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py +80 -15

@@ -27,6 +27,7 @@ import threading
 import time
 import uuid
 from collections import deque
+from contextlib import nullcontext
 from datetime import datetime
 from http import HTTPStatus
 from typing import (
@@ -69,6 +70,7 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
+    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -112,7 +114,9 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
+from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
+from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -166,6 +170,16 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


+def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
+    is_cross_node = server_args.dist_init_addr
+
+    if is_cross_node:
+        # Fallback to default CPU transport for multi-node
+        return "default"
+    else:
+        return "cuda_ipc"
+
+
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""

@@ -216,12 +230,13 @@ class TokenizerManager:
                 revision=server_args.revision,
                 use_fast=not server_args.disable_fast_image_processor,
             )
+            transport_mode = _determine_tensor_transport_mode(self.server_args)

             # We want to parallelize the image pre-processing so we create an executor for it
             # We create mm_processor for any skip_tokenizer_init to make sure we still encode
             # images even with skip_tokenizer_init=False.
             self.mm_processor = get_mm_processor(
-                self.model_config.hf_config, server_args, _processor
+                self.model_config.hf_config, server_args, _processor, transport_mode
             )

         if server_args.skip_tokenizer_init:
@@ -270,6 +285,11 @@ class TokenizerManager:
             None
         )

+        # Lock to serialize LoRA update operations.
+        # Please note that, unlike `model_update_lock`, this does not block inference, allowing
+        # LoRA updates and inference to overlap.
+        self.lora_update_lock = asyncio.Lock()
+
         # For pd disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
@@ -525,7 +545,8 @@ class TokenizerManager:
             mm_inputs = None

         if self.server_args.enable_lora and obj.lora_path:
-            # Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
+            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
+            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
             obj.lora_path = await self.lora_registry.acquire(obj.lora_path)

         self._validate_one_request(obj, input_ids)
@@ -735,6 +756,10 @@ class TokenizerManager:
                     msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                     logger.info(msg)

+                # Mark ongoing LoRA request as finished.
+                if self.server_args.enable_lora and obj.lora_path:
+                    await self.lora_registry.release(obj.lora_path)
+
                 # Check if this was an abort/error created by scheduler
                 if isinstance(out["meta_info"].get("finish_reason"), dict):
                     finish_reason = out["meta_info"]["finish_reason"]
@@ -744,6 +769,19 @@ class TokenizerManager:
                     ):
                         raise ValueError(finish_reason["message"])

+                    if (
+                        finish_reason.get("type") == "abort"
+                        and finish_reason.get("status_code")
+                        == HTTPStatus.SERVICE_UNAVAILABLE
+                    ):
+                        # This is an abort request initiated by scheduler.
+                        # Delete the key to prevent resending abort request to the scheduler and
+                        # to ensure aborted request state is cleaned up.
+                        del self.rid_to_state[state.obj.rid]
+                        raise fastapi.HTTPException(
+                            status_code=finish_reason["status_code"],
+                            detail=finish_reason["message"],
+                        )
                 yield out
                 break

@@ -784,12 +822,21 @@ class TokenizerManager:
                     rids.append(tmp_obj.rid)
             else:
                 # Sequential tokenization and processing
-                for i in range(batch_size):
-                    tmp_obj = obj[i]
-                    tokenized_obj = await self._tokenize_one_request(tmp_obj)
-                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                    generators.append(self._wait_one_response(tmp_obj, state, request))
-                    rids.append(tmp_obj.rid)
+                with (
+                    input_blocker_guard_region(send_to_scheduler=self.send_to_scheduler)
+                    if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
+                    else nullcontext()
+                ):
+                    for i in range(batch_size):
+                        tmp_obj = obj[i]
+                        tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                        state = self._send_one_request(
+                            tmp_obj, tokenized_obj, created_time
+                        )
+                        generators.append(
+                            self._wait_one_response(tmp_obj, state, request)
+                        )
+                        rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
             if batch_size > 128:
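The guarded region above relies on a small Python idiom: a context manager chosen conditionally at runtime, with `contextlib.nullcontext()` as the no-op fallback, so the loop body stays identical whether or not the `SGLANG_ENABLE_COLOCATED_BATCH_GEN` flag is set. A minimal self-contained sketch of the same idiom (the `scheduler_guard` helper and the work inside it are hypothetical stand-ins, only the pattern mirrors the diff):

from contextlib import contextmanager, nullcontext
import os


@contextmanager
def scheduler_guard():
    # Hypothetical stand-in for input_blocker_guard_region: block upstream
    # input on entry, unblock once the whole batch has been submitted.
    print("blocking scheduler input")
    try:
        yield
    finally:
        print("unblocking scheduler input")


def submit_batch(items):
    # Enter the guard only when the env flag is set; otherwise nullcontext()
    # turns the `with` statement into a no-op.
    enabled = os.environ.get("SGLANG_ENABLE_COLOCATED_BATCH_GEN", "false").lower() == "true"
    with scheduler_guard() if enabled else nullcontext():
        for item in items:
            print(f"submitting {item}")


submit_batch(["req-0", "req-1"])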
@@ -1041,16 +1088,18 @@ class TokenizerManager:
             obj.lora_path,
         )

-        async with self.model_update_lock.writer_lock:
+        async with self.lora_update_lock:
             # Generate new uniquely identifiable LoRARef object.
             new_adapter = LoRARef(
                 lora_name=obj.lora_name,
                 lora_path=obj.lora_path,
             )

-            # Register the new adapter in the registry.
+            # Trigger the actual loading operation at the backend processes.
             obj.lora_id = new_adapter.lora_id
             result = (await self.update_lora_adapter_communicator(obj))[0]
+
+            # Register the LoRA adapter only after loading is successful.
             if result.success:
                 await self.lora_registry.register(new_adapter)

@@ -1081,8 +1130,15 @@ class TokenizerManager:
             obj.lora_name,
         )

-        async with self.model_update_lock.writer_lock:
-            obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
+        async with self.lora_update_lock:
+            # Unregister the LoRA adapter from the registry to stop new requests for this adapter
+            # from being started.
+            lora_id = await self.lora_registry.unregister(obj.lora_name)
+            obj.lora_id = lora_id
+
+            # Initiate the actual unloading operation at the backend processes only after all
+            # ongoing requests using this LoRA adapter are finished.
+            await self.lora_registry.wait_for_unload(lora_id)
             result = (await self.update_lora_adapter_communicator(obj))[0]

         return result
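Taken together, the `acquire`/`release` calls in request handling and the `unregister`/`wait_for_unload` sequence above imply a reference-counted registry: each in-flight request pins its adapter, and unloading only proceeds once the count drains to zero. The internals of sglang's `LoRARegistry` (see `sglang/srt/lora/lora_registry.py` in the file list) are not shown in this excerpt, so the following is only an illustrative asyncio sketch of that lifecycle with invented names, not the actual implementation:

import asyncio


class RefCountedAdapterRegistry:
    """Illustrative only: tracks in-flight users per adapter ID."""

    def __init__(self):
        self._counts = {}
        self._drained = {}
        self._lock = asyncio.Lock()

    async def acquire(self, adapter_id: str) -> str:
        # Called when a request referencing this adapter starts.
        async with self._lock:
            self._counts[adapter_id] = self._counts.get(adapter_id, 0) + 1
        return adapter_id

    async def release(self, adapter_id: str) -> None:
        # Called when the request finishes.
        async with self._lock:
            self._counts[adapter_id] -= 1
            event = self._drained.get(adapter_id)
            if event and self._counts[adapter_id] == 0:
                event.set()  # wake anyone waiting to unload

    async def unregister(self, adapter_id: str) -> None:
        # Stop new acquisitions (omitted here) and prepare the drain event.
        async with self._lock:
            event = asyncio.Event()
            if self._counts.get(adapter_id, 0) == 0:
                event.set()
            self._drained[adapter_id] = event

    async def wait_for_unload(self, adapter_id: str) -> None:
        # Block until every request that acquired this adapter has released it.
        await self._drained[adapter_id].wait()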
@@ -1674,8 +1730,15 @@ class TokenizerManager:
     def _handle_abort_req(self, recv_obj):
         state = self.rid_to_state[recv_obj.rid]
         state.finished = True
-        state.out_list.append(
-            {
+        if recv_obj.finished_reason:
+            out = {
+                "meta_info": {
+                    "id": recv_obj.rid,
+                    "finish_reason": recv_obj.finished_reason,
+                },
+            }
+        else:
+            out = {
                 "text": "",
                 "meta_info": {
                     "id": recv_obj.rid,
@@ -1687,7 +1750,7 @@ class TokenizerManager:
                     "completion_tokens": 0,
                 },
             }
-        )
+        state.out_list.append(out)
         state.event.set()

     def _handle_open_session_req_output(self, recv_obj):
@@ -1879,8 +1942,10 @@ class _Communicator(Generic[T]):
 #
 # | entrypoint | is_streaming | status          | abort engine    | cancel asyncio task   | rid_to_state                |
 # | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
+# | http       | yes          | validation      | background task | fast api              | del in _handle_abort_req    |
 # | http       | yes          | waiting queue   | background task | fast api              | del in _handle_abort_req    |
 # | http       | yes          | running         | background task | fast api              | del in _handle_batch_output |
+# | http       | no           | validation      | http exception  | http exception        | del in _handle_abort_req    |
 # | http       | no           | waiting queue   | type 1          | type 1 exception      | del in _handle_abort_req    |
 # | http       | no           | running         | type 3          | type 3 exception      | del in _handle_batch_output |
 #
sglang/srt/managers/tp_worker.py +8 -0

@@ -41,6 +41,7 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed

@@ -129,6 +130,10 @@ class TpModelWorker:
             self.model_runner.req_to_token_pool.size,
         )
         assert self.max_running_requests > 0, "max_running_request is zero"
+        self.max_queued_requests = server_args.max_queued_requests
+        assert (
+            self.max_running_requests > 0
+        ), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
         self.max_req_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
@@ -164,6 +169,7 @@ class TpModelWorker:
             self.max_total_num_tokens,
             self.max_prefill_tokens,
             self.max_running_requests,
+            self.max_queued_requests,
             self.max_req_len,
             self.max_req_input_len,
             self.random_seed,
@@ -278,6 +284,8 @@ class TpModelWorker:
         return success, message

     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+
+        monkey_patch_torch_reductions()
         success, message = self.model_runner.update_weights_from_tensor(
             named_tensors=MultiprocessingSerializer.deserialize(
                 recv_req.serialized_named_tensors[self.tp_rank]
sglang/srt/mem_cache/hiradix_cache.py +5 -2

@@ -365,10 +365,12 @@ class HiRadixCache(RadixCache):
         for _ in range(queue_size.item()):
             req_id = self.cache_controller.prefetch_revoke_queue.get()
             if req_id in self.ongoing_prefetch:
-                last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
+                last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
                 last_host_node.release_host()
-                self.cache_controller.mem_pool_host.free(host_indices)
                 del self.ongoing_prefetch[req_id]
+            else:
+                # the revoked operation already got terminated
+                pass

     def check_backup_progress(self):
         queue_size = torch.tensor(
@@ -403,6 +405,7 @@ class HiRadixCache(RadixCache):
         last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
             req_id
         ]
+
         completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
             operation
         )
sglang/srt/model_executor/model_runner.py +83 -27

@@ -285,11 +285,21 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True

-        self.start_layer = getattr(self.model, "start_layer", 0)
-        self.end_layer = getattr(
-            self.model, "end_layer", self.model_config.num_hidden_layers
+        # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
+        # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
+        # determine the number of layers.
+        model_has_mtp_layers = self.model_config.num_nextn_predict_layers is not None
+        model_num_layers = (
+            self.model_config.num_nextn_predict_layers
+            if self.is_draft_worker and model_has_mtp_layers
+            else self.model_config.num_hidden_layers
         )
+        self.start_layer = getattr(self.model, "start_layer", 0)
+        self.end_layer = getattr(self.model, "end_layer", model_num_layers)
         self.num_effective_layers = self.end_layer - self.start_layer
+        assert (not model_has_mtp_layers) or (
+            self.num_effective_layers == model_num_layers
+        ), "PP is not compatible with MTP models."

         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
@@ -1178,11 +1188,7 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=(
-                    self.model_config.num_hidden_layers
-                    if not self.is_draft_worker
-                    else self.model_config.hf_config.num_nextn_predict_layers
-                ),  # PP is not compatible with mla backend
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
@@ -1195,11 +1201,7 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=(
-                    self.model_config.num_hidden_layers
-                    if not self.is_draft_worker
-                    else self.model_config.hf_config.num_nextn_predict_layers
-                ),  # PP is not compatible with mla backend
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
@@ -1308,9 +1310,58 @@ class ModelRunner:
         else:
             self.attn_backend = self._get_attention_backend()

-    # TODO unify with 6338
     def _get_attention_backend(self):
-        if self.server_args.attention_backend == "flashinfer":
+        """Init attention kernel backend."""
+        self.decode_attention_backend_str = (
+            self.server_args.decode_attention_backend
+            if self.server_args.decode_attention_backend
+            else self.server_args.attention_backend
+        )
+        self.prefill_attention_backend_str = (
+            self.server_args.prefill_attention_backend
+            if self.server_args.prefill_attention_backend
+            else self.server_args.attention_backend
+        )
+        if self.decode_attention_backend_str != self.prefill_attention_backend_str:
+            assert (
+                self.server_args.speculative_algorithm is None
+            ), "Currently HybridAttentionBackend does not support speculative decoding."
+            from sglang.srt.layers.attention.hybrid_attn_backend import (
+                HybridAttnBackend,
+            )
+
+            attn_backend = HybridAttnBackend(
+                decode_backend=self._get_attention_backend_from_str(
+                    self.decode_attention_backend_str
+                ),
+                prefill_backend=self._get_attention_backend_from_str(
+                    self.prefill_attention_backend_str
+                ),
+            )
+            logger.info(
+                f"Using hybrid attention backend for decode and prefill: "
+                f"decode_backend={self.decode_attention_backend_str}, "
+                f"prefill_backend={self.prefill_attention_backend_str}."
+            )
+            logger.warning(
+                f"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
+                f"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
+            )
+        else:
+            attn_backend = self._get_attention_backend_from_str(
+                self.server_args.attention_backend
+            )
+
+        global_server_args_dict.update(
+            {
+                "decode_attention_backend": self.decode_attention_backend_str,
+                "prefill_attention_backend": self.prefill_attention_backend_str,
+            }
+        )
+        return attn_backend
+
+    def _get_attention_backend_from_str(self, backend_str: str):
+        if backend_str == "flashinfer":
             if not self.use_mla_backend:
                 from sglang.srt.layers.attention.flashinfer_backend import (
                     FlashInferAttnBackend,
@@ -1318,7 +1369,11 @@ class ModelRunner:

                 # Init streams
                 if self.server_args.speculative_algorithm == "EAGLE":
-                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
+                    if (
+                        not hasattr(self, "plan_stream_for_flashinfer")
+                        or not self.plan_stream_for_flashinfer
+                    ):
+                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
                 return FlashInferAttnBackend(self)
             else:
                 from sglang.srt.layers.attention.flashinfer_mla_backend import (
@@ -1326,15 +1381,15 @@ class ModelRunner:
                 )

                 return FlashInferMLAAttnBackend(self)
-        elif self.server_args.attention_backend == "aiter":
+        elif backend_str == "aiter":
            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

             return AiterAttnBackend(self)
-        elif self.server_args.attention_backend == "ascend":
+        elif backend_str == "ascend":
             from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend

             return AscendAttnBackend(self)
-        elif self.server_args.attention_backend == "triton":
+        elif backend_str == "triton":
             assert not self.model_config.is_encoder_decoder, (
                 "Cross attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
@@ -1349,17 +1404,17 @@ class ModelRunner:
             from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

             return TritonAttnBackend(self)
-        elif self.server_args.attention_backend == "torch_native":
+        elif backend_str == "torch_native":
             from sglang.srt.layers.attention.torch_native_backend import (
                 TorchNativeAttnBackend,
             )

             return TorchNativeAttnBackend(self)
-        elif self.server_args.attention_backend == "flashmla":
+        elif backend_str == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

             return FlashMLABackend(self)
-        elif self.server_args.attention_backend == "fa3":
+        elif backend_str == "fa3":
             assert (
                 torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
             ) or torch.cuda.get_device_capability()[0] == 9, (
@@ -1371,7 +1426,7 @@ class ModelRunner:
             )

             return FlashAttentionBackend(self)
-        elif self.server_args.attention_backend == "cutlass_mla":
+        elif backend_str == "cutlass_mla":
             from sglang.srt.layers.attention.cutlass_mla_backend import (
                 CutlassMLABackend,
             )
@@ -1385,9 +1440,7 @@ class ModelRunner:
             logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
         else:
-            raise ValueError(
-                f"Invalid attention backend: {self.server_args.attention_backend}"
-            )
+            raise ValueError(f"Invalid attention backend: {backend_str}")

     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
@@ -1475,7 +1528,10 @@ class ModelRunner:
         if self.support_pp:
             kwargs["pp_proxy_tensors"] = pp_proxy_tensors
         return self.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            **kwargs,
         )

     def forward_extend(
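The reworked `_get_attention_backend` wires up a `HybridAttnBackend` (the new file `sglang/srt/layers/attention/hybrid_attn_backend.py` in the list above) that wraps one backend for prefill and a different one for decode. Its internals are not part of this excerpt, so the sketch below only illustrates the general delegation pattern with invented class and method names; it is not the actual HybridAttnBackend API:

from dataclasses import dataclass


class DummyBackend:
    """Hypothetical stand-in for a concrete attention backend."""

    def __init__(self, name: str):
        self.name = name

    def run(self, batch) -> str:
        return f"{self.name} handled {batch.kind} batch"


@dataclass
class Batch:
    kind: str  # "prefill" or "decode"


class HybridBackendSketch:
    """Illustrative delegation: pick a wrapped backend based on the batch phase."""

    def __init__(self, prefill_backend: DummyBackend, decode_backend: DummyBackend):
        self.prefill_backend = prefill_backend
        self.decode_backend = decode_backend

    def run(self, batch: Batch) -> str:
        backend = self.decode_backend if batch.kind == "decode" else self.prefill_backend
        return backend.run(batch)


hybrid = HybridBackendSketch(DummyBackend("fa3"), DummyBackend("flashinfer"))
print(hybrid.run(Batch("prefill")))  # fa3 handled prefill batch
print(hybrid.run(Batch("decode")))   # flashinfer handled decode batch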