sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. sglang/bench_one_batch.py +1 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/lang/chat_template.py +44 -0
  4. sglang/srt/configs/deepseekvl2.py +3 -0
  5. sglang/srt/configs/device_config.py +1 -1
  6. sglang/srt/configs/internvl.py +696 -0
  7. sglang/srt/configs/janus_pro.py +3 -0
  8. sglang/srt/configs/model_config.py +17 -0
  9. sglang/srt/constrained/xgrammar_backend.py +11 -19
  10. sglang/srt/conversation.py +30 -3
  11. sglang/srt/disaggregation/decode.py +4 -1
  12. sglang/srt/disaggregation/mini_lb.py +74 -23
  13. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  14. sglang/srt/disaggregation/nixl/conn.py +241 -71
  15. sglang/srt/disaggregation/utils.py +44 -1
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  17. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  19. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  20. sglang/srt/distributed/parallel_state.py +22 -1
  21. sglang/srt/entrypoints/engine.py +14 -2
  22. sglang/srt/entrypoints/http_server.py +28 -1
  23. sglang/srt/entrypoints/verl_engine.py +3 -2
  24. sglang/srt/hf_transformers_utils.py +20 -1
  25. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  26. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  27. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  28. sglang/srt/layers/attention/merge_state.py +46 -0
  29. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  30. sglang/srt/layers/attention/vision.py +290 -163
  31. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  32. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  33. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
  37. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  38. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  39. sglang/srt/layers/quantization/deep_gemm.py +5 -0
  40. sglang/srt/layers/quantization/fp8.py +108 -95
  41. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  42. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  43. sglang/srt/layers/quantization/kv_cache.py +3 -10
  44. sglang/srt/layers/quantization/utils.py +0 -5
  45. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  46. sglang/srt/lora/lora_manager.py +10 -13
  47. sglang/srt/managers/cache_controller.py +115 -119
  48. sglang/srt/managers/io_struct.py +10 -0
  49. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  50. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  51. sglang/srt/managers/schedule_batch.py +19 -1
  52. sglang/srt/managers/schedule_policy.py +11 -5
  53. sglang/srt/managers/scheduler.py +28 -13
  54. sglang/srt/managers/tokenizer_manager.py +24 -13
  55. sglang/srt/managers/tp_worker.py +9 -12
  56. sglang/srt/mem_cache/chunk_cache.py +2 -0
  57. sglang/srt/mem_cache/memory_pool.py +2 -2
  58. sglang/srt/model_executor/model_runner.py +44 -33
  59. sglang/srt/model_loader/loader.py +18 -11
  60. sglang/srt/models/clip.py +4 -4
  61. sglang/srt/models/deepseek_janus_pro.py +1 -1
  62. sglang/srt/models/deepseek_nextn.py +1 -20
  63. sglang/srt/models/deepseek_v2.py +55 -20
  64. sglang/srt/models/gemma3_mm.py +1 -1
  65. sglang/srt/models/internlm2.py +3 -0
  66. sglang/srt/models/internvl.py +670 -0
  67. sglang/srt/models/llama.py +1 -1
  68. sglang/srt/models/llama4.py +53 -7
  69. sglang/srt/models/minicpmv.py +1 -1
  70. sglang/srt/models/mllama.py +1 -1
  71. sglang/srt/models/phi3_small.py +16 -2
  72. sglang/srt/models/qwen2_5_vl.py +8 -4
  73. sglang/srt/models/qwen2_vl.py +4 -4
  74. sglang/srt/models/xiaomi_mimo.py +171 -0
  75. sglang/srt/openai_api/adapter.py +24 -40
  76. sglang/srt/openai_api/protocol.py +28 -16
  77. sglang/srt/reasoning_parser.py +2 -2
  78. sglang/srt/sampling/sampling_batch_info.py +54 -2
  79. sglang/srt/sampling/sampling_params.py +2 -0
  80. sglang/srt/server_args.py +30 -6
  81. sglang/srt/utils.py +35 -1
  82. sglang/test/test_block_fp8.py +2 -2
  83. sglang/test/test_deepep_utils.py +219 -0
  84. sglang/test/test_utils.py +3 -1
  85. sglang/version.py +1 -1
  86. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
  87. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
  88. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  89. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  90. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/parallel_state.py

@@ -42,6 +42,7 @@ from torch.distributed import Backend, ProcessGroup
 from sglang.srt.utils import (
     direct_register_custom_op,
     is_cuda_alike,
+    is_npu,
     supports_custom_op,
 )
 
@@ -206,6 +207,7 @@ class GroupCoordinator:
         use_custom_allreduce: bool,
         use_hpu_communicator: bool,
         use_xpu_communicator: bool,
+        use_npu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -244,6 +246,7 @@ class GroupCoordinator:
         self.use_custom_allreduce = use_custom_allreduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
+        self.use_npu_communicator = use_npu_communicator
         self.use_message_queue_broadcaster = use_message_queue_broadcaster
 
         # lazy import to avoid documentation build error
@@ -291,6 +294,14 @@ class GroupCoordinator:
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)
 
+        from sglang.srt.distributed.device_communicators.npu_communicator import (
+            NpuCommunicator,
+        )
+
+        self.npu_communicator: Optional[NpuCommunicator] = None
+        if use_npu_communicator and self.world_size > 1:
+            self.npu_communicator = NpuCommunicator(group=self.device_group)
+
         from sglang.srt.distributed.device_communicators.shm_broadcast import (
             MessageQueue,
         )
@@ -418,6 +429,9 @@ class GroupCoordinator:
         if self.xpu_communicator is not None and not self.xpu_communicator.disabled:
             return self.xpu_communicator.all_reduce(input_)
 
+        if self.npu_communicator is not None and not self.npu_communicator.disabled:
+            return self.npu_communicator.all_reduce(input_)
+
         if (
             self.ca_comm is not None
             and not self.ca_comm.disabled
@@ -497,6 +511,11 @@ class GroupCoordinator:
         if hpu_comm is not None and not hpu_comm.disabled:
             return hpu_comm.all_gather(input_, dim)
 
+        # For NPUs, use NPU communicator.
+        npu_comm = self.npu_communicator
+        if npu_comm is not None and not npu_comm.disabled:
+            return npu_comm.all_gather(input_, dim)
+
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
@@ -941,6 +960,7 @@ def init_world_group(
         use_custom_allreduce=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
+        use_npu_communicator=False,
         group_name="world",
     )
 
@@ -959,10 +979,11 @@ def init_model_parallel_group(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=True,
+        use_pynccl=not is_npu(),
         use_custom_allreduce=use_custom_allreduce,
         use_hpu_communicator=True,
         use_xpu_communicator=True,
+        use_npu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
     )
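The hunks above register an NPU communicator on GroupCoordinator next to the existing HPU/XPU paths, and init_model_parallel_group now disables pynccl when running on NPU. A minimal standalone sketch of the resulting all-reduce dispatch order, assuming a toy _Comm stand-in (only the attribute names xpu_communicator/npu_communicator and their ordering come from this diff):

from typing import Optional

import torch


class _Comm:  # illustrative stand-in, not an sglang class
    def __init__(self, disabled: bool = False):
        self.disabled = disabled

    def all_reduce(self, t: torch.Tensor) -> torch.Tensor:
        return t  # a real communicator would reduce across ranks here


def all_reduce_dispatch(
    input_: torch.Tensor,
    xpu_communicator: Optional[_Comm],
    npu_communicator: Optional[_Comm],
) -> torch.Tensor:
    # Device-specific communicators are tried first; the NPU check sits right
    # after the XPU check, ahead of the custom all-reduce / pynccl fallbacks.
    if xpu_communicator is not None and not xpu_communicator.disabled:
        return xpu_communicator.all_reduce(input_)
    if npu_communicator is not None and not npu_communicator.disabled:
        return npu_communicator.all_reduce(input_)
    return input_  # the real code falls through to custom all-reduce / pynccl


print(all_reduce_dispatch(torch.ones(4), None, _Comm()))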
sglang/srt/entrypoints/engine.py

@@ -163,6 +163,9 @@ class Engine(EngineBase):
         custom_logit_processor: Optional[Union[List[str], str]] = None,
         return_hidden_states: bool = False,
         stream: bool = False,
+        bootstrap_host: Optional[Union[List[str], str]] = None,
+        bootstrap_port: Optional[Union[List[int], int]] = None,
+        bootstrap_room: Optional[Union[List[int], int]] = None,
     ) -> Union[Dict, Iterator[Dict]]:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
@@ -181,6 +184,9 @@ class Engine(EngineBase):
             custom_logit_processor=custom_logit_processor,
             return_hidden_states=return_hidden_states,
             stream=stream,
+            bootstrap_host=bootstrap_host,
+            bootstrap_port=bootstrap_port,
+            bootstrap_room=bootstrap_room,
         )
         loop = asyncio.get_event_loop()
         generator = self.tokenizer_manager.generate_request(obj, None)
@@ -227,6 +233,9 @@ class Engine(EngineBase):
         lora_path: Optional[List[Optional[str]]] = None,
         custom_logit_processor: Optional[Union[List[str], str]] = None,
         stream: bool = False,
+        bootstrap_host: Optional[Union[List[str], str]] = None,
+        bootstrap_port: Optional[Union[List[int], int]] = None,
+        bootstrap_room: Optional[Union[List[int], int]] = None,
     ) -> Union[Dict, AsyncIterator[Dict]]:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
@@ -244,6 +253,9 @@ class Engine(EngineBase):
             lora_path=lora_path,
             stream=stream,
             custom_logit_processor=custom_logit_processor,
+            bootstrap_host=bootstrap_host,
+            bootstrap_port=bootstrap_port,
+            bootstrap_room=bootstrap_room,
         )
         generator = self.tokenizer_manager.generate_request(obj, None)
 
@@ -348,8 +360,8 @@ class Engine(EngineBase):
         load_format: Optional[str] = None,
         flush_cache: bool = True,
     ):
-        """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
-        to avoid duplicated operations such as clearing cache."""
+        """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false
+        to avoid duplicated cache cleaning operation."""
        obj = UpdateWeightsFromTensorReqInput(
            serialized_named_tensors=[
                MultiprocessingSerializer.serialize(named_tensors)
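With the hunks above, the offline Engine API can carry the PD-disaggregation bootstrap fields straight through to GenerateReqInput. A hedged usage sketch; the model path, host, port, and room values are illustrative, only the keyword names come from this diff:

import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model

out = engine.generate(
    prompt="Hello, my name is",
    sampling_params={"max_new_tokens": 16, "temperature": 0.0},
    # New in 0.4.6.post3: forwarded to GenerateReqInput for PD disaggregation.
    bootstrap_host="10.0.0.1",  # prefill-side bootstrap host (illustrative)
    bootstrap_port=8998,        # prefill-side bootstrap port (illustrative)
    bootstrap_room=12345,       # room id pairing this request's prefill and decode halves
)
print(out["text"])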
sglang/srt/entrypoints/http_server.py

@@ -42,7 +42,10 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 
-from sglang.srt.disaggregation.utils import FakeBootstrapHost
+from sglang.srt.disaggregation.utils import (
+    FakeBootstrapHost,
+    register_disaggregation_server,
+)
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
@@ -59,6 +62,7 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     SeparateReasoningReqInput,
     SetInternalStateReq,
+    SlowDownReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -491,6 +495,19 @@ async def resume_memory_occupation(
         return _create_error_response(e)
 
 
+@app.api_route("/slow_down", methods=["GET", "POST"])
+async def slow_down(obj: SlowDownReqInput, request: Request):
+    """Slow down the system deliberately. Only for testing. Example scenario:
+    when we want to test performance of D in large-scale PD disaggregation and have no enough nodes for P,
+    we can use this to slow down D to let it have enough running sequences, and then disable slowdown
+    to let it run in full batch size.
+    """
+    try:
+        await _global_state.tokenizer_manager.slow_down(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
@@ -675,6 +692,8 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
         **(vertex_req.parameters or {}),
     )
     ret = await generate_request(req, raw_request)
+    if isinstance(ret, Response):
+        return ret
     return ORJSONResponse({"predictions": ret})
 
 
@@ -869,5 +888,13 @@ def _wait_and_warmup(
     if server_args.debug_tensor_dump_input_file:
         kill_process_tree(os.getpid())
 
+    if server_args.pdlb_url is not None:
+        register_disaggregation_server(
+            server_args.disaggregation_mode,
+            server_args.port,
+            server_args.disaggregation_bootstrap_port,
+            server_args.pdlb_url,
+        )
+
     if launch_callback is not None:
         launch_callback()
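For reference, a hedged sketch of exercising the new /slow_down route from a test script. The JSON field name forward_sleep_time is an assumption about SlowDownReqInput (it is not shown in this diff), as is the server address:

import requests

BASE = "http://127.0.0.1:30000"  # assumed sglang server address

# Deliberately slow the decode instance down (testing only). Field name assumed.
requests.post(f"{BASE}/slow_down", json={"forward_sleep_time": 0.5}).raise_for_status()

# ... run the large-scale PD-disaggregation measurement ...

# Turn the slowdown off again so the server runs at its full batch size.
requests.post(f"{BASE}/slow_down", json={"forward_sleep_time": None}).raise_for_status()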
sglang/srt/entrypoints/verl_engine.py

@@ -37,6 +37,7 @@ class VerlEngine:
         monkey_patch_torch_reductions()
         self._device_mesh_cpu = device_mesh_cpu
         self._tp_rank = device_mesh_cpu.get_local_rank()
+        self._rank = device_mesh_cpu.get_rank()
         self._tp_size = device_mesh_cpu.size()
         tp_size_per_node = self._tp_size // nnodes
         node_rank = self._tp_rank // tp_size_per_node
@@ -114,7 +115,7 @@ class VerlEngine:
         # Most naive implementation, can extract tensor and send via gloo if too slow
         [output] = broadcast_pyobj(
             data=[output],
-            rank=self._tp_rank,
+            rank=self._rank,
             dist_group=self._device_mesh_cpu.get_group(),
             src=self._device_mesh_cpu.mesh[0].item(),
             force_cpu_device=False,
@@ -157,7 +158,7 @@ class VerlEngine:
         )
 
         if self._tp_rank == 0:
-            self._engine.tokenizer_manager.flush_cache()
+            self._engine.flush_cache()
 
     def release_memory_occupation(self):
         if self._tp_rank == 0:
sglang/srt/hf_transformers_utils.py

@@ -19,6 +19,7 @@ import warnings
 from pathlib import Path
 from typing import Dict, Optional, Type, Union
 
+import transformers
 from huggingface_hub import snapshot_download
 from transformers import (
     AutoConfig,
@@ -26,6 +27,7 @@ from transformers import (
     AutoTokenizer,
     PretrainedConfig,
     PreTrainedTokenizer,
+    PreTrainedTokenizerBase,
     PreTrainedTokenizerFast,
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
@@ -38,6 +40,7 @@ from sglang.srt.configs import (
     KimiVLConfig,
     MultiModalityConfig,
 )
+from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
 from sglang.srt.utils import is_remote_url
 
@@ -48,6 +51,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     DeepseekVL2Config.model_type: DeepseekVL2Config,
     MultiModalityConfig.model_type: MultiModalityConfig,
     KimiVLConfig.model_type: KimiVLConfig,
+    InternVLChatConfig.model_type: InternVLChatConfig,
 }
 
 for name, cls in _CONFIG_REGISTRY.items():
@@ -90,6 +94,12 @@ def get_config(
         config = config_class.from_pretrained(model, revision=revision)
         # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
         setattr(config, "_name_or_path", model)
+
+    if isinstance(model, str) and config.model_type == "internvl_chat":
+        for key, val in config.llm_config.__dict__.items():
+            if not hasattr(config, key):
+                setattr(config, key, val)
+
     if model_override_args:
         config.update(model_override_args)
 
@@ -211,6 +221,13 @@ def get_tokenizer(
     return tokenizer
 
 
+# Some models doesn't have an available processor, e.g.: InternVL
+def get_tokenizer_from_processor(processor):
+    if isinstance(processor, PreTrainedTokenizerBase):
+        return processor
+    return processor.tokenizer
+
+
 def get_processor(
     tokenizer_name: str,
     *args,
@@ -246,7 +263,9 @@ def get_processor(
         **kwargs,
     )
 
-    attach_additional_stop_token_ids(processor.tokenizer)
+    tokenizer = get_tokenizer_from_processor(processor)
+
+    attach_additional_stop_token_ids(tokenizer)
     return processor
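The new get_tokenizer_from_processor helper above lets get_processor handle models whose processor object is already a tokenizer (e.g. InternVL has no separate processor). A small self-contained check of that fallback logic; using the gpt2 tokenizer here is just an illustration:

from transformers import AutoTokenizer, PreTrainedTokenizerBase


def get_tokenizer_from_processor(processor):
    # Same logic as the helper added above: tokenizers pass through unchanged,
    # real processors are unwrapped via their .tokenizer attribute.
    if isinstance(processor, PreTrainedTokenizerBase):
        return processor
    return processor.tokenizer


tok = AutoTokenizer.from_pretrained("gpt2")  # a bare tokenizer, no wrapping processor
assert get_tokenizer_from_processor(tok) is tok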
sglang/srt/layers/attention/flashattention_backend.py

@@ -338,7 +338,7 @@ class FlashAttentionBackend(AttentionBackend):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
         seqlens_in_batch = forward_batch.seq_lens
-        batch_size = len(seqlens_in_batch)
+        batch_size = forward_batch.batch_size
         device = seqlens_in_batch.device
 
         if forward_batch.forward_mode.is_decode_or_idle():
@@ -913,8 +913,10 @@ class FlashAttentionBackend(AttentionBackend):
         # Use precomputed metadata across all layers
         metadata = self.forward_metadata
         local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
-        use_local_attention = (
-            self.attention_chunk_size is not None and local_attn_metadata is not None
+        use_local_attn = (
+            self.attention_chunk_size is not None
+            and local_attn_metadata is not None
+            and (hasattr(layer, "use_irope") and layer.use_irope)
         )
         # We do cascade attention for Draft Decode with topk > 1
         use_cascade_attn = self.topk > 1
@@ -970,7 +972,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
             )
-        elif use_local_attention:
+        elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -979,7 +981,7 @@ class FlashAttentionBackend(AttentionBackend):
                 page_table=local_attn_metadata.local_block_table,
                 cache_seqlens=local_attn_metadata.local_seqused_k,
                 cu_seqlens_q=local_attn_metadata.local_query_start_loc,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                cu_seqlens_k_new=None,
                 max_seqlen_q=local_attn_metadata.local_max_query_len,
                 softmax_scale=layer.scaling,
                 causal=True,
@@ -1127,7 +1129,6 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
-
         # This is being used by normal decode and draft decode when topk == 1
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1154,6 +1155,34 @@ class FlashAttentionBackend(AttentionBackend):
             ),
         }
 
+        # Only allocate local attention buffers if local attention is enabled
+        # This prevents OOM errors when local attention is not being used
+        if self.attention_chunk_size is not None:
+            # Estimate maximum sizes for local attention metadata
+            max_seq_len = self.max_context_len
+            page_size = self.page_size or 1
+            attn_chunk_size = self.attention_chunk_size
+            max_virtual_batches = max_bs * (
+                (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+            )
+            max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+            max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
+
+            self.decode_cuda_graph_local_attn_metadata = {
+                "local_query_start_loc": torch.zeros(
+                    max_virtual_batches + 1, dtype=torch.int32, device=self.device
+                ),
+                "local_seqused_k": torch.zeros(
+                    max_virtual_batches, dtype=torch.int32, device=self.device
+                ),
+                "local_block_table": torch.zeros(
+                    max_virtual_batches,
+                    max_blocks_per_seq * max_pages_per_block,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
         # This is used by draft decode's first half of metadata when topk > 1
         if self.topk > 1:
             self.draft_decode_metadata_topk_normal = {
@@ -1405,6 +1434,21 @@ class FlashAttentionBackend(AttentionBackend):
                 )
             self.decode_cuda_graph_metadata[bs] = metadata
 
+            if self.attention_chunk_size is not None:
+                metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+                    local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
+                        "local_query_start_loc"
+                    ],
+                    local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
+                        "local_seqused_k"
+                    ],
+                    local_block_table=self.decode_cuda_graph_local_attn_metadata[
+                        "local_block_table"
+                    ],
+                    local_max_query_len=1,
+                    local_max_seq_len=1,
+                )
+
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata.cache_seqlens_int32 = self.target_verify_metadata[
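To get a feel for how large the preallocated replay buffers are, a worked example under assumed values (attention_chunk_size=8192 as in Llama 4, max_context_len=65536, page_size=1, max_bs=8; none of these numbers are sglang defaults, only the formulas mirror the allocation above):

max_bs = 8
max_seq_len = 65536        # assumed max_context_len
attn_chunk_size = 8192     # assumed attention_chunk_size
page_size = 1              # assumed page size

max_virtual_batches = max_bs * ((max_seq_len + attn_chunk_size - 1) // attn_chunk_size)
max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size

elems = max_virtual_batches * max_blocks_per_seq * max_pages_per_block
print(max_virtual_batches, max_blocks_per_seq, max_pages_per_block)  # 64 8 8192
print(f"local_block_table: {elems * 4 / 2**20:.1f} MiB of int32")    # 16.0 MiB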
@@ -1525,12 +1569,9 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
                     self.speculative_step_id + 1
                 )
-                metadata.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        ),
-                        (1, 0),
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     )
                 )
 
@@ -1554,12 +1595,9 @@ class FlashAttentionBackend(AttentionBackend):
                 # metadata.max_seq_len_q = self.topk, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
-                metadata.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        ),
-                        (1, 0),
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     )
                 )
 
@@ -1578,8 +1616,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata_expand.page_table[: cache_loc.shape[0]].copy_(
                     cache_loc[:, :decode_length].contiguous().to(torch.int32)
                 )
-                # TODO: we need to test this part for llama 4 eagle case
-                self._init_local_attn_metadata(metadata, device)
+                # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
         else:
             metadata = self.decode_cuda_graph_metadata[bs]
             # Normal Decode
@@ -1587,8 +1624,9 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k = max_len
 
             metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-            metadata.cu_seqlens_k = torch.nn.functional.pad(
-                torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+            # Optimize cumulative sequence length calculation
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
             )
 
             max_seq_pages = (
@@ -1604,7 +1642,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.page_table[:, :max_seq_pages].copy_(page_indices)
             metadata.page_table[:, max_seq_pages:].fill_(0)
 
-            self._init_local_attn_metadata(metadata, device)
+            self._update_local_attn_metadata_for_replay(metadata, bs)
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
@@ -1615,13 +1653,8 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k = (
                     seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
                 )
-                metadata.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        ),
-                        (1, 0),
-                    )
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
                 )
                 max_seq_pages = (
                     metadata.max_seq_len_k + self.page_size - 1
@@ -1640,13 +1673,8 @@ class FlashAttentionBackend(AttentionBackend):
                 # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
-                metadata.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        ),
-                        (1, 0),
-                    )
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
                 )
                 page_table = self.req_to_token[
                     req_pool_indices, : metadata.max_seq_len_k
@@ -1704,14 +1732,11 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata_expand.cache_seqlens_int32.copy_(
                     mask.sum(dim=1).to(torch.int32)
                 )
-                metadata_expand.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata_expand.cache_seqlens_int32,
-                            dim=0,
-                            dtype=torch.int32,
-                        ),
-                        (1, 0),
+                metadata_expand.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(
+                        metadata_expand.cache_seqlens_int32,
+                        dim=0,
+                        dtype=torch.int32,
                     )
                 )
                 metadata_expand.max_seq_len_k = (
@@ -1722,11 +1747,8 @@ class FlashAttentionBackend(AttentionBackend):
                 # Only support encoder size 1 for now
                 metadata.encoder_max_seq_len_k = encoder_lens[0]
                 metadata.encoder_lens_int32.copy_(encoder_lens[:1])
-                metadata.encoder_cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
-                        (1, 0),
-                    )
+                metadata.encoder_cu_seqlens_k[1:].copy_(
+                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32)
                 )
 
                 metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
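All of the replay-path rewrites above replace pad(cumsum(...), (1, 0)) with an in-place copy into the tail of a preallocated cu_seqlens buffer, so the tensor the CUDA graph captured keeps its address instead of being reallocated every step. A minimal check that the two forms produce the same values, assuming a zero-initialized int32 buffer of length batch_size + 1 as in the capture path:

import torch

cache_seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)

# Old form: allocate a new tensor with a leading zero.
old = torch.nn.functional.pad(
    torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
)

# New form: write the cumulative sums into an existing buffer, leaving slot 0 at zero.
cu_seqlens_k = torch.zeros(cache_seqlens.numel() + 1, dtype=torch.int32)  # preallocated
cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32))

assert torch.equal(old, cu_seqlens_k)  # both are [0, 3, 8, 10]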
@@ -1776,6 +1798,7 @@ class FlashAttentionBackend(AttentionBackend):
             page_table,
             self.page_size,
         )
+
         local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
             local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
             local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
@@ -1785,6 +1808,79 @@ class FlashAttentionBackend(AttentionBackend):
         )
         metadata.local_attn_metadata = local_metadata
 
+    def _update_local_attn_metadata_for_replay(
+        self, metadata: FlashAttentionMetadata, bs: int
+    ):
+        """Update preallocated local attention metadata in-place before CUDA graph replay."""
+        if self.attention_chunk_size is None:
+            return
+
+        # Access preallocated buffers
+        local_q_buf = self.decode_cuda_graph_local_attn_metadata[
+            "local_query_start_loc"
+        ]
+        local_k_buf = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"]
+        local_block_buf = self.decode_cuda_graph_local_attn_metadata[
+            "local_block_table"
+        ]
+        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"]
+
+        # Create a modified version for local attention that only processes the last token
+        # This mimics the normal decode pattern
+        cu_seqlens_q = torch.arange(
+            bs + 1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype
+        )
+        seqlens = metadata.cache_seqlens_int32[:bs]
+        # Slice the page_table to match the batch size and actual sequence length
+        # This serves three important purposes:
+        # 1. Ensures we only process the actual batch size (bs) and not the maximum batch size
+        # 2. Limits the sequence length to prevent processing padding tokens or garbage values
+        # 3. Prevents zeros in the block table which can cause garbage output during replay
+        #
+        # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
+        # beyond the actual sequence length, leading to incorrect attention calculations
+        max_seq_len = int(seqlens.max().item())
+        sliced_page_table = metadata.page_table[:bs, :max_seq_len]
+
+        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+        seqlens_np = seqlens.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seqlens_np,
+            sliced_page_table,
+            self.page_size,
+        )
+
+        # Convert back to tensors
+        device = local_q_buf.device
+        cu_seqlens_q_local = torch.from_numpy(cu_seqlens_q_local_np).to(device)
+        seqlens_k_local = torch.from_numpy(seqlens_k_local_np).to(device)
+        block_table_local = block_table_local.to(device)
+        # Get sizes
+        q_len = cu_seqlens_q_local.shape[0]
+        k_len = seqlens_k_local.shape[0]
+        b0, b1 = block_table_local.shape
+
+        # In-place updates into preallocated tensors and zero out the unused space
+        local_q_buf[:q_len].copy_(cu_seqlens_q_local)
+        local_q_buf[q_len:].fill_(0)
+        local_k_buf[:k_len].copy_(seqlens_k_local)
+        local_k_buf[k_len:].fill_(0)
+        local_block_buf[:b0, :b1].copy_(block_table_local)
+        local_block_buf[b0:, :].fill_(0)
+        local_block_buf[:b0, b1:].fill_(0)
+
+        if metadata.local_attn_metadata is not None:
+            lam = metadata.local_attn_metadata
+            lam.local_max_query_len = int(seqlens_q_local_np.max())
+            lam.local_max_seq_len = int(seqlens_k_local_np.max())
+
 
 class FlashAttentionMultiStepBackend:
 
sglang/srt/layers/attention/flashinfer_backend.py

@@ -16,8 +16,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 import torch
 
 if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import torch._dynamo
+    import logging
 
+    torch._logging.set_logs(dynamo=logging.ERROR)
     torch._dynamo.config.suppress_errors = True
 
 from sglang.global_config import global_config
@@ -107,6 +108,7 @@ class FlashInferAttnBackend(AttentionBackend):
         if (
             "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures
             or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
+            or "MiMoForCausalLM" in model_runner.model_config.hf_config.architectures
         ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
@@ -416,6 +418,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
         logits_soft_cap = layer.logit_cap
 
+        q = q.contiguous()
         if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
@@ -425,7 +428,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
 
             o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
@@ -435,20 +438,27 @@ class FlashInferAttnBackend(AttentionBackend):
                 v_scale=layer.v_scale,
             )
         else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
-
             if self.forward_metadata.extend_no_prefix:
-                o = o1
+                o = self.prefill_wrapper_ragged.forward(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
+
             else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                     causal=False,
                     sm_scale=layer.scaling,