sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
@@ -556,7 +556,7 @@ class TokenizerManager:
         if self.server_args.enable_lora and obj.lora_path:
             # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
             # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
-            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
+            obj.lora_id = await self.lora_registry.acquire(obj.lora_path)

         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
@@ -665,7 +665,7 @@ class TokenizerManager:
                 bootstrap_host=obj.bootstrap_host,
                 bootstrap_port=obj.bootstrap_port,
                 bootstrap_room=obj.bootstrap_room,
-                lora_path=obj.lora_path,
+                lora_id=obj.lora_id,
                 input_embeds=input_embeds,
                 session_params=session_params,
                 custom_logit_processor=obj.custom_logit_processor,
@@ -750,7 +750,11 @@ class TokenizerManager:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
-                if request is not None and await request.is_disconnected():
+                if (
+                    request is not None
+                    and not obj.background
+                    and await request.is_disconnected()
+                ):
                     # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
                     # Use exception to kill the whole call stack and asyncio task
@@ -773,7 +777,7 @@ class TokenizerManager:

                 # Mark ongoing LoRA request as finished.
                 if self.server_args.enable_lora and obj.lora_path:
-                    await self.lora_registry.release(obj.lora_path)
+                    await self.lora_registry.release(obj.lora_id)

                 # Check if this was an abort/error created by scheduler
                 if isinstance(out["meta_info"].get("finish_reason"), dict):
@@ -805,7 +809,11 @@ class TokenizerManager:
             if obj.stream:
                 yield out
             else:
-                if request is not None and await request.is_disconnected():
+                if (
+                    request is not None
+                    and not obj.background
+                    and await request.is_disconnected()
+                ):
                     # Abort the request for disconnected requests (non-streaming, running)
                     self.abort_request(obj.rid)
                     # Use exception to kill the whole call stack and asyncio task
@@ -1121,6 +1129,7 @@ class TokenizerManager:
         new_adapter = LoRARef(
             lora_name=obj.lora_name,
             lora_path=obj.lora_path,
+            pinned=obj.pinned,
         )

         # Trigger the actual loading operation at the backend processes.
@@ -1178,7 +1187,7 @@ class TokenizerManager:

             return result
         except ValueError as e:
-            return UnloadLoRAAdapterReqOutput(success=False, rror_message=str(e))
+            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))

     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -1548,8 +1557,17 @@ class TokenizerManager:

             if isinstance(recv_obj, BatchStrOut):
                 state.text += recv_obj.output_strs[i]
+                if state.obj.stream:
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids[state.last_output_offset :]
+                    state.last_output_offset = len(state.output_ids)
+                else:
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids.copy()
+
                 out_dict = {
                     "text": state.text,
+                    "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
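Note on the TokenizerManager hunks above: `lora_registry.acquire()` now returns an internal LoRA ID, which is stored on `obj.lora_id`, forwarded to the scheduler, and handed back to `release()` once the request finishes. A minimal sketch of that acquire/release lifecycle, with `ToyLoRARegistry` as a purely illustrative stand-in (not sglang's actual registry):

    import asyncio
    import uuid


    class ToyLoRARegistry:
        """Illustrative stand-in: maps user-facing adapter names to internal IDs
        and counts in-flight requests per adapter."""

        def __init__(self):
            self._name_to_id = {}
            self._inflight = {}
            self._lock = asyncio.Lock()

        async def acquire(self, lora_path: str) -> str:
            async with self._lock:
                lora_id = self._name_to_id.setdefault(lora_path, uuid.uuid4().hex)
                self._inflight[lora_id] = self._inflight.get(lora_id, 0) + 1
                return lora_id

        async def release(self, lora_id: str) -> None:
            async with self._lock:
                self._inflight[lora_id] -= 1


    async def handle_request(registry: ToyLoRARegistry, lora_path: str):
        lora_id = await registry.acquire(lora_path)  # before the request is scheduled
        try:
            pass                                     # the request runs under lora_id
        finally:
            await registry.release(lora_id)          # mark the request as finished


    asyncio.run(handle_request(ToyLoRARegistry(), "my-adapter"))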
sglang/srt/managers/tp_worker.py
@@ -311,3 +311,6 @@ class TpModelWorker:
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
         result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
         return result
+
+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
sglang/srt/managers/tp_worker_overlap_thread.py
@@ -288,6 +288,9 @@ class TpModelWorkerClient:
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
         return self.worker.unload_lora_adapter(recv_req)

+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.worker.can_run_lora_batch(lora_ids)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
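Note on the two worker hunks above: the scheduler can now ask a worker whether a candidate batch's set of LoRA IDs is servable before forming the batch; the call is delegated down to `LoRAManager.validate_lora_batch`. A sketch of what such a check might look like, assuming a fixed number of adapter slots per batch (the real validation lives in the LoRA manager and may be more involved):

    # Illustrative only: a capacity check in the spirit of validate_lora_batch,
    # assuming the GPU pool holds at most `max_loras_per_batch` distinct adapters.
    def can_fit_lora_batch(lora_ids: list, max_loras_per_batch: int) -> bool:
        distinct = {lora_id for lora_id in lora_ids if lora_id is not None}
        return len(distinct) <= max_loras_per_batch


    print(can_fit_lora_batch(["a", "a", None, "b"], max_loras_per_batch=2))  # True
    print(can_fit_lora_batch(["a", "b", "c"], max_loras_per_batch=2))        # False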
sglang/srt/mem_cache/hiradix_cache.py
@@ -2,11 +2,12 @@ import heapq
 import logging
 import threading
 import time
+from queue import Queue
 from typing import List, Optional

 import torch

-from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
         hicache_io_backend: str,
         hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
+        hicache_storage_prefetch_policy: Optional[str] = "best_effort",
     ):

         if hicache_io_backend == "direct":
@@ -85,6 +87,13 @@ class HiRadixCache(RadixCache):
             prefetch_threshold=self.prefetch_threshold,
         )

+        self.prefetch_stop_policy = hicache_storage_prefetch_policy
+        # todo: customizable storage prefetch timeout
+        self.prefetch_timeout = 3  # seconds
+        logger.info(
+            f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
+        )
+
         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
@@ -385,9 +394,10 @@ class HiRadixCache(RadixCache):
         for _ in range(queue_size.item()):
             req_id = self.cache_controller.prefetch_revoke_queue.get()
             if req_id in self.ongoing_prefetch:
-                last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
+                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
                 last_host_node.release_host()
                 del self.ongoing_prefetch[req_id]
+                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
             else:
                 # the revoked operation already got terminated
                 pass
@@ -419,10 +429,41 @@ class HiRadixCache(RadixCache):
             host_node.release_host()
             del self.ongoing_backup[ack_id]

-    def check_prefetch_progress(self, req_id: str):
+    def can_terminate_prefetch(self, operation: PrefetchOperation):
+        can_terminate = True
+
+        if self.prefetch_stop_policy == "best_effort":
+            return can_terminate
+
+        completed = (
+            operation.completed_tokens == len(operation.hash_value) * self.page_size
+        )
+
+        if self.prefetch_stop_policy == "wait_complete":
+            can_terminate = completed
+        elif self.prefetch_stop_policy == "timeout":
+            can_terminate = completed or (
+                time.monotonic() - operation.start_time > self.prefetch_timeout
+            )
+        else:
+            # unknown prefetch stop policy, just return True
+            return True
+
+        if self.tp_world_size > 1:
+            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
+            torch.distributed.all_reduce(
+                can_terminate,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+            can_terminate = bool(can_terminate.item())
+
+        return can_terminate
+
+    def check_prefetch_progress(self, req_id: str) -> bool:
         if req_id not in self.ongoing_prefetch:
             # there is no ongoing prefetch for this request or it has been revoked
-            return
+            return True

         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
@@ -430,13 +471,16 @@ class HiRadixCache(RadixCache):
             req_id
         ]

+        if not self.can_terminate_prefetch(operation):
+            return False
+
         completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
             operation
         )
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

         min_completed_tokens = completed_tokens
-        if self.tp_world_size > 1:
+        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
             # synchrnoize TP workers to make the same update to hiradix cache
             completed_tokens_tensor = torch.tensor(
                 min_completed_tokens, dtype=torch.int
@@ -464,6 +508,9 @@ class HiRadixCache(RadixCache):
         )
         last_host_node.release_host()
         del self.ongoing_prefetch[req_id]
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
+
+        return True

     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
@@ -531,6 +578,7 @@ class HiRadixCache(RadixCache):
             host_indices,
             operation,
         )
+        self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)

     def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
         node.last_access_time = time.monotonic()
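Note on the HiRadixCache hunks above: `can_terminate_prefetch` implements three storage prefetch stop policies selected by `hicache_storage_prefetch_policy`. Stripped of the tensor-parallel MIN all-reduce (which only makes all TP ranks agree on the decision), the logic reduces to the following restatement:

    import time

    # Restatement of the three hicache_storage_prefetch_policy options.
    # expected_tokens corresponds to len(operation.hash_value) * page_size.
    def can_terminate(policy: str, completed_tokens: int, expected_tokens: int,
                      start_time: float, timeout_s: float = 3.0) -> bool:
        if policy == "best_effort":
            return True                 # terminate whenever the scheduler asks
        completed = completed_tokens == expected_tokens
        if policy == "wait_complete":
            return completed            # only terminate once every page has landed
        if policy == "timeout":
            return completed or (time.monotonic() - start_time > timeout_s)
        return True                     # unknown policy: fall back to best effort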
sglang/srt/mem_cache/memory_pool_host.py
@@ -618,7 +618,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         elif self.layout == "page_first":
             transfer_kv_all_layer_mla_lf_pf(
                 src_layers=device_pool.data_ptrs,
-                dst_k=self.kv_buffer,
+                dst=self.kv_buffer,
                 src_indices=device_indices,
                 dst_indices=host_indices,
                 item_size=self.token_stride_size,
sglang/srt/mem_cache/multimodal_cache.py
@@ -1,24 +1,46 @@
+import logging
+from collections import OrderedDict
 from typing import Dict

 import torch

+# Set up logging for cache behavior
+logger = logging.getLogger(__name__)
+

 class MultiModalCache:
-    """MultiModalCache is used to store vlm encoder results"""
+    """MultiModalCache is used to store vlm encoder results with LRU eviction"""

     def __init__(
         self,
         max_size: int,
     ):
         self.max_size = max_size
-        self.mm_cache: Dict[int, torch.Tensor] = {}
+        self.mm_cache: OrderedDict[int, torch.Tensor] = OrderedDict()
         self.current_size = 0

+    def _allocate(self, embedding_size: int) -> bool:
+        """Allocate space by evicting least recently used entries"""
+        evictions = 0
+        while self.current_size + embedding_size > self.max_size and self.mm_cache:
+            _, old_embedding = self.mm_cache.popitem(last=False)
+            evicted_size = self._get_tensor_size(old_embedding)
+            self.current_size -= evicted_size
+            evictions += evicted_size
+
+        if evictions > 0:
+            logger.debug(
+                f"Cache eviction: evicted {evictions} bytes, remaining size: {self.current_size}/{self.max_size} bytes"
+            )
+
+        if self.current_size + embedding_size > self.max_size:
+            return False
+        return True
+
     def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
-        if mm_hash in self.mm_cache:
-            return True
         data_size = self._get_tensor_size(embedding)
-        if self.current_size + data_size > self.max_size:
+        # Lazy free cache if not enough space
+        if not self._allocate(data_size):
             return False
         self.mm_cache[mm_hash] = embedding
         self.current_size += data_size
@@ -28,14 +50,12 @@ class MultiModalCache:
         return mm_hash in self.mm_cache

     def get(self, mm_hash: int) -> torch.Tensor:
-        return self.mm_cache.get(mm_hash)
-
-    def free(self, mm_hash: int) -> bool:
-        if mm_hash not in self.mm_cache:
-            return False
-        old_embedding = self.mm_cache.pop(mm_hash)
-        self.current_size -= self._get_tensor_size(old_embedding)
-        return True
+        """Get embedding and update LRU order"""
+        if mm_hash in self.mm_cache:
+            # Move to end (most recently used)
+            self.mm_cache.move_to_end(mm_hash)
+            return self.mm_cache[mm_hash]
+        return None

     def clear(self):
         self.mm_cache.clear()
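Note on the MultiModalCache hunks above: `put()` no longer refuses new entries when the budget is exhausted; it evicts least-recently-used embeddings first, and `get()` refreshes recency. A usage sketch (the byte budget, tensor shapes, and the assumption that entry sizes are counted as element_size × numel are illustrative):

    import torch

    from sglang.srt.mem_cache.multimodal_cache import MultiModalCache

    cache = MultiModalCache(max_size=2 * 1024 * 1024)  # 2 MiB budget

    emb_a = torch.zeros(256, 1024, dtype=torch.float16)  # ~512 KiB each
    emb_b = torch.zeros(256, 1024, dtype=torch.float16)
    cache.put(hash("image_a"), emb_a)
    cache.put(hash("image_b"), emb_b)

    _ = cache.get(hash("image_a"))  # touching A makes B the least recently used entry

    emb_c = torch.zeros(768, 1024, dtype=torch.float16)  # ~1.5 MiB, needs eviction
    cache.put(hash("image_c"), emb_c)  # B is evicted; A and C stay cached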
sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py
@@ -96,6 +96,8 @@ class Hf3fsClient:
         )
         self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
         self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
+        self.shm_r.unlink()
+        self.shm_w.unlink()

         self.rlock = threading.RLock()
         self.wlock = threading.RLock()
@@ -176,8 +178,6 @@ class Hf3fsClient:
         del self.iov_w
         self.shm_r.close()
         self.shm_w.close()
-        self.shm_r.unlink()
-        self.shm_w.unlink()

     def flush(self) -> None:
         os.fsync(self.file)
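Note on the Hf3fsClient hunks above: the shared-memory segments are now unlinked right after their iovecs are registered instead of at close time. On POSIX this is the usual "unlink early" idiom: the name disappears immediately, the mapping stays usable, and the memory is reclaimed once the last handle closes, even if the process dies before a clean close(). A standalone illustration with Python's standard shared_memory module:

    from multiprocessing import shared_memory

    shm = shared_memory.SharedMemory(create=True, size=1 << 20)
    shm.unlink()            # name removed now; the segment survives while mapped

    shm.buf[:5] = b"hello"  # the mapping is still fully usable after unlink
    print(bytes(shm.buf[:5]))

    shm.close()             # last close releases the memory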
sglang/srt/model_executor/cuda_graph_runner.py
@@ -576,11 +576,11 @@ class CudaGraphRunner:
         )

         if self.model_runner.server_args.enable_lora:
-            # It is safe to capture CUDA graph using empty LoRA path, as the LoRA kernels will always be launched whenever
-            # `--enable-lora` is set to True (and return immediately if the LoRA path is empty for perf optimization).
-            lora_paths = [None] * bs
+            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
+            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
+            lora_ids = [None] * bs
         else:
-            lora_paths = None
+            lora_ids = None

         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -589,6 +589,7 @@ class CudaGraphRunner:
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
             next_token_logits_buffer=next_token_logits_buffer,
+            orig_seq_lens=seq_lens,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
@@ -607,11 +608,11 @@ class CudaGraphRunner:
             capture_hidden_mode=self.capture_hidden_mode,
             num_token_non_padded=self.num_token_non_padded,
             global_forward_mode=self.capture_forward_mode,
-            lora_paths=lora_paths,
+            lora_ids=lora_ids,
         )

         self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
-        if lora_paths is not None:
+        if lora_ids is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

         # Attention backend
sglang/srt/model_executor/forward_batch_info.py
@@ -180,6 +180,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int

+    # The original sequence length without being chunked. Qwen-1M related.
+    orig_seq_lens: Optional[torch.Tensor] = None
+
     # Optional seq_lens on cpu
     seq_lens_cpu: Optional[torch.Tensor] = None

@@ -248,7 +251,7 @@ class ForwardBatch:
     encoder_out_cache_loc: Optional[torch.Tensor] = None

     # For LoRA
-    lora_paths: Optional[List[str]] = None
+    lora_ids: Optional[List[str]] = None

     # For input embeddings
     input_embeds: Optional[torch.Tensor] = None
@@ -321,13 +324,14 @@ class ForwardBatch:
             encoder_out_cache_loc=batch.encoder_out_cache_loc,
             seq_lens_sum=batch.seq_lens_sum,
             seq_lens_cpu=batch.seq_lens_cpu,
+            orig_seq_lens=batch.orig_seq_lens,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
             is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
-            lora_paths=batch.lora_paths,
+            lora_ids=batch.lora_ids,
             sampling_info=batch.sampling_info,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
@@ -420,16 +424,12 @@ class ForwardBatch:
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
             ret.extend_num_tokens = batch.extend_num_tokens
-            if support_triton(model_runner.server_args.attention_backend):
-                positions, ret.extend_start_loc = compute_position_triton(
-                    ret.extend_prefix_lens,
-                    ret.extend_seq_lens,
-                    ret.extend_num_tokens,
-                )
-            else:
-                positions, ret.extend_start_loc = compute_position_torch(
-                    ret.extend_prefix_lens, ret.extend_seq_lens
-                )
+            positions, ret.extend_start_loc = compute_position(
+                model_runner.server_args.attention_backend,
+                ret.extend_prefix_lens,
+                ret.extend_seq_lens,
+                ret.extend_num_tokens,
+            )
             if ret.positions is None:
                 ret.positions = positions
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
@@ -632,8 +632,10 @@ class ForwardBatch:
         self.dp_padding_mode = dp_padding_mode

         if dp_padding_mode.is_max_len():
-            # when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
-            # where transferred tokens should be padded to the same length.
+            # when DP gather mode is all gather, we will use
+            # all_gather_into_tensor to gather hidden states, where transferred
+            # tokens should be padded to the same length. We will also use
+            # reduce-scatter instead of all-reduce after MLP.
             max_num_tokens = max(global_num_tokens)
             global_num_tokens = [max_num_tokens] * sync_group_size
             buffer_len = max_num_tokens * sync_group_size
@@ -882,6 +884,25 @@ class PPProxyTensors:
         return f"PPProxyTensors(tensors={self.tensors})"


+def compute_position(
+    attn_backend: str,
+    extend_prefix_lens: torch.Tensor,
+    extend_seq_lens: torch.Tensor,
+    extend_seq_lens_sum: int,
+):
+    if support_triton(attn_backend):
+        positions, extend_start_loc = compute_position_triton(
+            extend_prefix_lens,
+            extend_seq_lens,
+            extend_seq_lens_sum,
+        )
+    else:
+        positions, extend_start_loc = compute_position_torch(
+            extend_prefix_lens, extend_seq_lens
+        )
+    return positions, extend_start_loc
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):
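Note on the ForwardBatch hunks above: the Triton-vs-Torch choice for position computation now lives in a single `compute_position()` dispatcher. For reference, what it produces can be written with plain torch ops as below (a sketch equivalent in spirit to the `compute_position_torch` fallback, not the Triton kernel):

    import torch

    def compute_position_reference(extend_prefix_lens, extend_seq_lens):
        # positions: for each request, prefix_len .. prefix_len + extend_len - 1
        positions = torch.cat([
            torch.arange(p, p + s)
            for p, s in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
        ])
        # extend_start_loc: offset of each request's new tokens in the flattened batch
        extend_start_loc = torch.cumsum(extend_seq_lens, dim=0) - extend_seq_lens
        return positions, extend_start_loc


    pos, start = compute_position_reference(torch.tensor([3, 0]), torch.tensor([2, 4]))
    # pos   -> tensor([3, 4, 0, 1, 2, 3])
    # start -> tensor([0, 2])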
sglang/srt/model_executor/model_runner.py
@@ -1443,19 +1443,36 @@ class ModelRunner:
             )

             return CutlassMLABackend(self)
-        elif self.server_args.attention_backend == "trtllm_mla":
+        elif backend_str == "trtllm_mla":
             if not self.use_mla_backend:
                 raise ValueError("trtllm_mla backend can only be used with MLA models.")
             from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

             return TRTLLMMLABackend(self)
-        elif self.server_args.attention_backend == "intel_amx":
+        elif backend_str == "trtllm_mha":
+            if self.use_mla_backend:
+                raise ValueError(
+                    "trtllm_mha backend can only be used with non-MLA models."
+                )
+            from sglang.srt.layers.attention.trtllm_mha_backend import (
+                TRTLLMHAAttnBackend,
+            )
+
+            return TRTLLMHAAttnBackend(self)
+
+        elif backend_str == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
             )

             logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
+        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
+            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+                DualChunkFlashAttentionBackend,
+            )
+
+            return DualChunkFlashAttentionBackend(self)
         else:
             raise ValueError(f"Invalid attention backend: {backend_str}")

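Note on the ModelRunner hunk above: two backend names become selectable, "trtllm_mha" (non-MLA models only) and "dual_chunk_flash_attn". Assuming the value reaches ModelRunner through the existing attention_backend server argument (as with the other backends), usage would look roughly like the sketch below; the model name is only a placeholder and the right backend depends on the model and hardware:

    import sglang as sgl

    # Sketch: pick one of the newly wired attention backends at engine start-up.
    llm = sgl.Engine(
        model_path="Qwen/Qwen2.5-7B-Instruct",   # placeholder model
        attention_backend="trtllm_mha",
    )
    out = llm.generate("Hello", {"max_new_tokens": 8})
    print(out)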
sglang/srt/model_loader/weight_utils.py
@@ -843,6 +843,16 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
                 return None
             return remapped_name

+    quark_scale_names = {
+        ".q_proj.output_scale": ".attn.q_scale",
+        ".k_proj.output_scale": ".attn.k_scale",
+        ".v_proj.output_scale": ".attn.v_scale",
+        "self_attn.prob_output_scale": ".attn.prob_scale",
+    }
+    for quark_scale_name, sglang_scale_name in quark_scale_names.items():
+        if name.endswith(quark_scale_name):
+            return name.replace(quark_scale_name, sglang_scale_name)
+
     # If there were no matches, return the untouched param name
     return name