sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (99)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -8
  3. sglang/compile_deep_gemm.py +177 -0
  4. sglang/lang/backend/openai.py +5 -1
  5. sglang/lang/backend/runtime_endpoint.py +5 -1
  6. sglang/srt/code_completion_parser.py +1 -1
  7. sglang/srt/configs/deepseekvl2.py +1 -1
  8. sglang/srt/configs/model_config.py +11 -2
  9. sglang/srt/constrained/llguidance_backend.py +78 -61
  10. sglang/srt/constrained/xgrammar_backend.py +1 -0
  11. sglang/srt/conversation.py +34 -1
  12. sglang/srt/disaggregation/decode.py +96 -5
  13. sglang/srt/disaggregation/mini_lb.py +113 -15
  14. sglang/srt/disaggregation/mooncake/conn.py +199 -32
  15. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  16. sglang/srt/disaggregation/nixl/conn.py +622 -0
  17. sglang/srt/disaggregation/prefill.py +119 -20
  18. sglang/srt/disaggregation/utils.py +17 -0
  19. sglang/srt/entrypoints/engine.py +4 -0
  20. sglang/srt/entrypoints/http_server.py +11 -9
  21. sglang/srt/function_call_parser.py +132 -0
  22. sglang/srt/layers/activation.py +2 -2
  23. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +809 -160
  25. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  26. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  28. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  29. sglang/srt/layers/attention/vision.py +2 -0
  30. sglang/srt/layers/dp_attention.py +1 -1
  31. sglang/srt/layers/layernorm.py +42 -5
  32. sglang/srt/layers/logits_processor.py +2 -2
  33. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  34. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  38. sglang/srt/layers/pooler.py +6 -0
  39. sglang/srt/layers/quantization/awq.py +5 -1
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  41. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  42. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  43. sglang/srt/layers/quantization/deep_gemm.py +385 -0
  44. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/quantization/gptq.py +13 -7
  47. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  48. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  49. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  50. sglang/srt/layers/radix_attention.py +13 -3
  51. sglang/srt/layers/rotary_embedding.py +176 -132
  52. sglang/srt/layers/sampler.py +2 -2
  53. sglang/srt/managers/data_parallel_controller.py +17 -4
  54. sglang/srt/managers/io_struct.py +21 -3
  55. sglang/srt/managers/mm_utils.py +85 -28
  56. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  57. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  58. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  59. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  60. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  61. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  62. sglang/srt/managers/schedule_batch.py +42 -12
  63. sglang/srt/managers/scheduler.py +47 -26
  64. sglang/srt/managers/tokenizer_manager.py +120 -30
  65. sglang/srt/managers/tp_worker.py +1 -0
  66. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  67. sglang/srt/mem_cache/memory_pool.py +118 -13
  68. sglang/srt/model_executor/cuda_graph_runner.py +16 -10
  69. sglang/srt/model_executor/forward_batch_info.py +51 -95
  70. sglang/srt/model_executor/model_runner.py +29 -27
  71. sglang/srt/models/deepseek.py +12 -2
  72. sglang/srt/models/deepseek_nextn.py +101 -6
  73. sglang/srt/models/deepseek_v2.py +153 -76
  74. sglang/srt/models/deepseek_vl2.py +9 -4
  75. sglang/srt/models/gemma3_causal.py +1 -1
  76. sglang/srt/models/llama4.py +0 -1
  77. sglang/srt/models/minicpm3.py +2 -2
  78. sglang/srt/models/minicpmo.py +22 -7
  79. sglang/srt/models/mllama4.py +2 -2
  80. sglang/srt/models/qwen2_5_vl.py +3 -6
  81. sglang/srt/models/qwen2_vl.py +3 -7
  82. sglang/srt/models/roberta.py +178 -0
  83. sglang/srt/openai_api/adapter.py +87 -10
  84. sglang/srt/openai_api/protocol.py +6 -1
  85. sglang/srt/server_args.py +65 -60
  86. sglang/srt/speculative/build_eagle_tree.py +2 -2
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +2 -2
  89. sglang/srt/speculative/eagle_worker.py +2 -7
  90. sglang/srt/torch_memory_saver_adapter.py +10 -1
  91. sglang/srt/utils.py +48 -6
  92. sglang/test/runners.py +6 -13
  93. sglang/test/test_utils.py +39 -19
  94. sglang/version.py +1 -1
  95. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
  96. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
  97. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  98. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/mem_cache/hiradix_cache.py
+++ b/sglang/srt/mem_cache/hiradix_cache.py
@@ -29,15 +29,17 @@ class HiRadixCache(RadixCache):
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,
+        hicache_size: int,
+        hicache_write_policy: str,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         else:
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -50,6 +52,7 @@ class HiRadixCache(RadixCache):
             self.token_to_kv_pool_host,
             page_size,
             load_cache_event=self.load_cache_event,
+            write_policy=hicache_write_policy,
         )

         # record the nodes with ongoing write through
@@ -57,7 +60,9 @@ class HiRadixCache(RadixCache):
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
         # todo: dynamically adjust the threshold
-        self.write_through_threshold = 1
+        self.write_through_threshold = (
+            1 if hicache_write_policy == "write_through" else 3
+        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -76,7 +81,7 @@ class HiRadixCache(RadixCache):
             height += 1
         return height

-    def write_backup(self, node: TreeNode):
+    def write_backup(self, node: TreeNode, write_back=False):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
             node_id=node.id,
@@ -90,21 +95,29 @@ class HiRadixCache(RadixCache):
         if host_indices is not None:
             node.host_value = host_indices
             self.ongoing_write_through[node.id] = node
-            self.inc_lock_ref(node)
+            if not write_back:
+                # no need to lock nodes if write back
+                self.inc_lock_ref(node)
         else:
             return 0

         return len(host_indices)

     def inc_hit_count(self, node: TreeNode):
-        if self.cache_controller.write_policy != "write_through_selective":
+        if node.backuped or self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-        if node.host_value is None and node.hit_count > self.write_through_threshold:
+        if node.hit_count >= self.write_through_threshold:
             self.write_backup(node)
             node.hit_count = 0

-    def writing_check(self):
+    def writing_check(self, write_back=False):
+        if write_back:
+            # blocking till all write back complete
+            while len(self.ongoing_write_through) > 0:
+                ack_id = self.cache_controller.ack_write_queue.get()
+                del self.ongoing_write_through[ack_id]
+            return
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
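For reference, the promotion rule above can be illustrated with a small self-contained sketch (toy names, not sglang code): a node is written through to host memory once its hit count reaches 1 under the "write_through" policy, 3 under the selective policy, and never under "write_back".

    # Illustrative sketch only; mirrors the hit-count promotion rule in the diff.
    class ToyNode:
        def __init__(self):
            self.hit_count = 0
            self.backuped = False  # True once a host copy exists

    def inc_hit_count(node: ToyNode, write_policy: str) -> bool:
        """Return True when this access triggers a write-through to host."""
        if node.backuped or write_policy == "write_back":
            return False
        threshold = 1 if write_policy == "write_through" else 3
        node.hit_count += 1
        if node.hit_count >= threshold:
            node.backuped = True
            node.hit_count = 0
            return True
        return False

    node = ToyNode()
    print([inc_hit_count(node, "write_through_selective") for _ in range(3)])
    # -> [False, False, True]: only the third hit promotes the node to host.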
@@ -143,29 +156,25 @@ class HiRadixCache(RadixCache):
         heapq.heapify(leaves)

         num_evicted = 0
-        pending_nodes = []
+        write_back_nodes = []
         while num_evicted < num_tokens and len(leaves):
             x = heapq.heappop(leaves)

             if x.lock_ref > 0:
                 continue

-            if x.host_value is None:
+            if not x.backuped:
                 if self.cache_controller.write_policy == "write_back":
-                    num_evicted += self.write_backup(x)
-                    pending_nodes.append(x)
-                elif self.cache_controller.write_policy == "write_through_selective":
-                    num_evicted += self._evict_write_through_selective(x)
+                    # write to host if the node is not backuped
+                    num_evicted += self.write_backup(x, write_back=True)
+                    write_back_nodes.append(x)
                 else:
-                    assert (
-                        self.cache_controller.write_policy != "write_through"
-                    ), "write_through should be inclusive"
-                    raise NotImplementedError
+                    num_evicted += self._evict_regular(x)
             else:
-                num_evicted += self._evict_write_through(x)
+                num_evicted += self._evict_backuped(x)

             for child in x.parent.children.values():
-                if child in pending_nodes:
+                if child in write_back_nodes:
                     continue
                 if not child.evicted:
                     break
@@ -174,15 +183,12 @@ class HiRadixCache(RadixCache):
             heapq.heappush(leaves, x.parent)

         if self.cache_controller.write_policy == "write_back":
-            # blocking till all write back complete
-            while len(self.ongoing_write_through) > 0:
-                self.writing_check()
-                time.sleep(0.1)
-            for node in pending_nodes:
-                assert node.host_value is not None
-                self._evict_write_through(node)
+            self.writing_check(write_back=True)
+            for node in write_back_nodes:
+                assert node.backuped
+                self._evict_backuped(node)

-    def _evict_write_through(self, node: TreeNode):
+    def _evict_backuped(self, node: TreeNode):
         # evict a node already written to host
         num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
         assert num_evicted > 0
@@ -190,7 +196,7 @@ class HiRadixCache(RadixCache):
         node.value = None
         return num_evicted

-    def _evict_write_through_selective(self, node: TreeNode):
+    def _evict_regular(self, node: TreeNode):
         # evict a node not initiated write to host
         self.cache_controller.mem_pool_device_allocator.free(node.value)
         num_evicted = len(node.value)
@@ -339,11 +345,13 @@ class HiRadixCache(RadixCache):
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
+                self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
                 node = new_node
                 break
             else:
+                self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
                 node = child
@@ -369,7 +377,7 @@ class HiRadixCache(RadixCache):
         else:
             new_node.value = child.value[:split_len]
             child.value = child.value[split_len:]
-            if child.host_value is not None:
+            if child.backuped:
                 new_node.host_value = child.host_value[:split_len]
                 child.host_value = child.host_value[split_len:]
         child.parent = new_node
@@ -426,8 +434,8 @@ class HiRadixCache(RadixCache):
         node.children[child_key] = new_node
         self.evictable_size_ += len(value)

-        if self.cache_controller.write_policy == "write_through":
-            self.write_backup(new_node)
+        if self.cache_controller.write_policy != "write_back":
+            self.inc_hit_count(new_node)
         return total_prefix_length

     def _collect_leaves_device(self):
--- a/sglang/srt/mem_cache/memory_pool.py
+++ b/sglang/srt/mem_cache/memory_pool.py
@@ -34,6 +34,8 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import psutil
 import torch
+import triton
+import triton.language as tl

 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import debug_timing, get_compiler_backend
@@ -405,6 +407,72 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)


+@triton.jit
+def set_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    pid_blk = tl.program_id(1)
+
+    base = pid_blk * BLOCK
+    offs = base + tl.arange(0, BLOCK)
+    total_dim = nope_dim + rope_dim
+    mask = offs < total_dim
+
+    loc = tl.load(loc_ptr + pid_loc)
+    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs
+
+    if base + BLOCK <= nope_dim:
+        src = tl.load(
+            cache_k_nope_ptr + pid_loc * nope_stride + offs,
+            mask=mask,
+        )
+    else:
+        offs_rope = offs - nope_dim
+        src = tl.load(
+            cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
+            mask=mask,
+        )
+
+    tl.store(dst_ptr, src, mask=mask)
+
+
+def set_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    nope_dim = cache_k_nope.shape[-1]
+    rope_dim = cache_k_rope.shape[-1]
+    total_dim = nope_dim + rope_dim
+    BLOCK = 128
+    n_loc = loc.numel()
+    grid = (n_loc, triton.cdiv(total_dim, BLOCK))
+
+    set_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+        BLOCK=BLOCK,
+    )
+
+
 class MLATokenToKVPool(KVCache):
     def __init__(
         self,
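The net effect of the new kernel can be described with a plain-PyTorch reference (a minimal sketch, not part of sglang; it ignores the per-layer, paged layout and uses a flat 2-D buffer, with dimensions chosen only for illustration): the nope part of each token's compressed KV goes into the first nope_dim channels of its slot, the rope part into the remaining channels.

    import torch

    def set_mla_kv_buffer_reference(
        kv_buffer: torch.Tensor,     # [num_slots, nope_dim + rope_dim]
        loc: torch.Tensor,           # [num_tokens] destination slot indices
        cache_k_nope: torch.Tensor,  # [num_tokens, nope_dim]
        cache_k_rope: torch.Tensor,  # [num_tokens, rope_dim]
    ) -> None:
        nope_dim = cache_k_nope.shape[-1]
        # What the Triton kernel does block-by-block with masked loads/stores:
        # scatter both halves of each token into its slot in one pass.
        kv_buffer[loc, :nope_dim] = cache_k_nope
        kv_buffer[loc, nope_dim:] = cache_k_rope

    buf = torch.zeros(16, 576)   # e.g. 512 "nope" + 64 "rope" channels (assumed dims)
    loc = torch.tensor([3, 7])
    k_nope, k_rope = torch.randn(2, 512), torch.randn(2, 64)
    set_mla_kv_buffer_reference(buf, loc, k_nope, k_rope)
    assert torch.equal(buf[3, :512], k_nope[0]) and torch.equal(buf[7, 512:], k_rope[1])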
@@ -446,13 +514,28 @@ class MLATokenToKVPool(KVCache):
         ]

         self.layer_transfer_counter = None
+        self.page_size = page_size
+
+        kv_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
+        )
+
+    def get_kv_size_bytes(self):
+        assert hasattr(self, "kv_buffer")
+        kv_size_bytes = 0
+        for kv_cache in self.kv_buffer:
+            kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
+        return kv_size_bytes

     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
         kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
         kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
-        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+        kv_item_lens = [
+            self.kv_buffer[i][0].nbytes * self.page_size for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens

     def get_key_buffer(self, layer_id: int):
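get_kv_size_bytes simply sums prod(shape) * itemsize over the per-layer buffers, which is what feeds the new "KV Cache is allocated" log line. A rough worked example (the layer count and MLA channel width below are assumptions for illustration, not values taken from this diff):

    tokens = 100_000
    layers = 61                # assumed model depth
    channels = 512 + 64        # assumed kv_lora_rank + qk_rope_head_dim
    bytes_per_elem = 2         # bfloat16 storage
    kv_size_bytes = tokens * layers * channels * bytes_per_elem
    print(f"KV size: {kv_size_bytes / (1 << 30):.2f} GB")   # -> KV size: 6.54 GB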
@@ -489,6 +572,25 @@ class MLATokenToKVPool(KVCache):
         else:
             self.kv_buffer[layer_id][loc] = cache_k

+    def set_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k_nope: torch.Tensor,
+        cache_k_rope: torch.Tensor,
+    ):
+        layer_id = layer.layer_id
+        if cache_k_nope.dtype != self.dtype:
+            cache_k_nope = cache_k_nope.to(self.dtype)
+            cache_k_rope = cache_k_rope.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            cache_k_nope = cache_k_nope.view(self.store_dtype)
+            cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+        set_mla_kv_buffer_triton(
+            self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
+        )
+
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
         return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
@@ -621,26 +723,27 @@ class HostKVCache(abc.ABC):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         pin_memory: bool,
         device: str,
         page_size: int,
     ):
-        assert (
-            host_to_device_ratio >= 1
-        ), "The host memory should be larger than the device memory with the current protocol"
-        # todo, other ways of configuring the size
-
         self.device_pool = device_pool
-        self.host_to_device_ratio = host_to_device_ratio
+        self.dtype = device_pool.store_dtype
         self.pin_memory = pin_memory
         self.device = device
         self.page_size = page_size
-
-        self.size = int(device_pool.size * host_to_device_ratio)
+        self.size_per_token = self.get_size_per_token()
+        if host_size > 0:
+            self.size = int(host_size * 1e9 // self.size_per_token)
+        else:
+            self.size = int(device_pool.size * host_to_device_ratio)
         # Align the host memory pool size to the page size
         self.size = self.size - (self.size % self.page_size)
-        self.dtype = device_pool.store_dtype
-        self.size_per_token = self.get_size_per_token()
+
+        assert (
+            self.size > device_pool.size
+        ), "The host memory should be larger than the device memory with the current protocol"

         # Verify there is enough available host memory.
         host_mem = psutil.virtual_memory()
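With this change an absolute host budget (in GB) takes precedence and the host-to-device ratio is only the fallback. A standalone sketch of the sizing rule (function and argument names are illustrative, not sglang API):

    def host_pool_tokens(
        device_tokens: int,
        size_per_token: int,           # bytes of host KV storage per token
        page_size: int,
        host_to_device_ratio: float,
        host_size_gb: int = 0,         # 0 means "fall back to the ratio"
    ) -> int:
        if host_size_gb > 0:
            size = int(host_size_gb * 1e9 // size_per_token)
        else:
            size = int(device_tokens * host_to_device_ratio)
        size -= size % page_size       # align down to the page size
        # The host pool must still be strictly larger than the device pool.
        assert size > device_tokens
        return size

    # Example: a 4 GB host budget with ~70 KB of host KV per token, page size 64.
    print(host_pool_tokens(10_000, 70_000, 64, 2.0, host_size_gb=4))   # -> 57088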
@@ -792,12 +895,13 @@ class MHATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )

     def get_size_per_token(self):
@@ -866,12 +970,13 @@ class MLATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )

     def get_size_per_token(self):
--- a/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/sglang/srt/model_executor/cuda_graph_runner.py
@@ -35,7 +35,11 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
-from sglang.srt.utils import get_available_gpu_memory, is_hip
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    get_device_memory_capacity,
+    is_hip,
+)

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -129,7 +133,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
             list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
         )

-    if _is_hip:
+    gpu_mem = get_device_memory_capacity()
+    # Batch size of each rank will not become so large when DP is on
+    if gpu_mem is not None and gpu_mem > 81920 and server_args.dp_size == 1:
         capture_bs += list(range(160, 257, 8))

     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -140,12 +146,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         ]

     capture_bs = list(sorted(set(capture_bs)))
-    capture_bs = [
-        bs
-        for bs in capture_bs
-        if bs <= model_runner.req_to_token_pool.size
-        and bs <= server_args.cuda_graph_max_bs
-    ]
+
+    assert len(capture_bs) > 0 and capture_bs[0] > 0
+    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+    if server_args.cuda_graph_max_bs:
+        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
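Putting the two hunks above together, the capture list is built from a base schedule, extended on large-memory GPUs when data parallelism is off, and then clamped; a simplified standalone sketch (the helper name and the simplifications are illustrative, not sglang's API):

    from typing import List, Optional

    def cuda_graph_capture_batch_sizes(
        req_pool_size: int,
        cuda_graph_max_bs: Optional[int],
        gpu_mem_mb: Optional[int],
        dp_size: int = 1,
    ) -> List[int]:
        # Base schedule: dense for small batches, coarser steps for larger ones.
        bs = list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
        # GPUs with more than ~80 GB capture up to 256, but only without DP.
        if gpu_mem_mb is not None and gpu_mem_mb > 81920 and dp_size == 1:
            bs += list(range(160, 257, 8))
        bs = sorted(set(bs))
        bs = [b for b in bs if b <= req_pool_size]
        if cuda_graph_max_bs:
            bs = [b for b in bs if b <= cuda_graph_max_bs]
        return bs

    print(cuda_graph_capture_batch_sizes(4096, 64, gpu_mem_mb=143_360))
    # -> [1, 2, ..., 8, 10, 12, ..., 32, 40, 56]: clamped by the max batch size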
@@ -186,6 +191,7 @@ class CudaGraphRunner:

         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
+
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
@@ -273,9 +279,9 @@ class CudaGraphRunner:
                 f"Capture cuda graph failed: {e}\n"
                 "Possible solutions:\n"
                 "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "2. set --cuda-graph-max-bs to a smaller value (e.g., 32)\n"
+                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
                 "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable cuda graph by --disable-cuda-graph\n"
+                "4. disable cuda graph by --disable-cuda-graph. (Not recommonded. Huge perf loss)\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )

--- a/sglang/srt/model_executor/forward_batch_info.py
+++ b/sglang/srt/model_executor/forward_batch_info.py
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -364,23 +364,23 @@ class ForwardBatch:

     def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
         """
-        Merge all image inputs in the batch into a single MultiModalInputs object.
+        Merge all multimodal inputs in the batch into a single MultiModalInputs object.

         Returns:
-            if none, current batch contains no image input
+            if none, current batch contains no multimodal input

         """
         if not self.mm_inputs or all(x is None for x in self.mm_inputs):
             return None
-
         # Filter out None values
         valid_inputs = [x for x in self.mm_inputs if x is not None]

-        # Start with the first valid image input
-        merged = valid_inputs[0]
+        # TODO: is it expensive?
+        # a workaround to avoid importing `MultimodalInputs`
+        merged = valid_inputs[0].__class__(mm_items=[])

         # Merge remaining inputs
-        for mm_input in valid_inputs[1:]:
+        for mm_input in valid_inputs:
             merged.merge(mm_input)

         return merged
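Starting from a fresh, empty instance (instead of mutating the first valid input in place) avoids aliasing a per-request object when the batch is merged; a generic illustration of that pattern (toy class, not the real MultimodalInputs):

    class ToyInputs:
        def __init__(self, mm_items):
            self.mm_items = list(mm_items)

        def merge(self, other: "ToyInputs") -> None:
            self.mm_items.extend(other.mm_items)

    valid_inputs = [ToyInputs(["img0"]), ToyInputs(["img1"])]
    merged = valid_inputs[0].__class__(mm_items=[])   # fresh accumulator
    for mm_input in valid_inputs:
        merged.merge(mm_input)
    assert valid_inputs[0].mm_items == ["img0"]       # the originals stay untouched
    assert merged.mm_items == ["img0", "img1"]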
@@ -407,104 +407,60 @@ class ForwardBatch:
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
-        device = model_runner.device
-        hf_config = model_runner.model_config.hf_config
-        mrope_positions_list = [None] * self.seq_lens.shape[0]
-        if self.forward_mode.is_decode():
-            for i, _ in enumerate(mrope_positions_list):
-                mrope_position_delta = (
-                    0
-                    if batch.multimodal_inputs[i] is None
-                    else batch.multimodal_inputs[i].mrope_position_delta
-                )
-                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-                    mrope_position_delta,
-                    int(self.seq_lens[i]) - 1,
-                    int(self.seq_lens[i]),
+        # batch_size * [3 * seq_len]
+        batch_size = self.seq_lens.shape[0]
+        mrope_positions_list = [[]] * batch_size
+        for batch_idx in range(batch_size):
+            mm_input = batch.multimodal_inputs[batch_idx]
+            if self.forward_mode.is_decode():
+                mrope_position_deltas = (
+                    [0]
+                    if mm_input is None
+                    else flatten_nested_list(mm_input.mrope_position_delta.tolist())
                 )
-        elif self.forward_mode.is_extend():
-            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i, mm_input in enumerate(batch.multimodal_inputs):
-                extend_start_loc, extend_seq_len, extend_prefix_len = (
-                    extend_start_loc_cpu[i],
-                    batch.extend_seq_lens[i],
-                    batch.extend_prefix_lens[i],
+                next_input_positions = []
+                for mrope_position_delta in mrope_position_deltas:
+                    # batched deltas needs to be processed separately
+                    # Convert list of lists to tensor with shape [3, seq_len]
+                    next_input_positions += [
+                        MRotaryEmbedding.get_next_input_positions(
+                            mrope_position_delta,
+                            int(self.seq_lens[batch_idx]) - 1,
+                            int(self.seq_lens[batch_idx]),
+                        )
+                    ]
+                # 3 * N
+                mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
+            elif self.forward_mode.is_extend():
+                extend_seq_len, extend_prefix_len = (
+                    batch.extend_seq_lens[batch_idx],
+                    batch.extend_prefix_lens[batch_idx],
                 )
                 if mm_input is None:
                     # text only
-                    mrope_positions = [
+                    mrope_positions = torch.tensor(
                         [
-                            pos
-                            for pos in range(
-                                extend_prefix_len, extend_prefix_len + extend_seq_len
-                            )
+                            [
+                                pos
+                                for pos in range(
+                                    extend_prefix_len,
+                                    extend_prefix_len + extend_seq_len,
+                                )
+                            ]
                         ]
-                    ] * 3
-                else:
-                    image_grid_thws_list = [
-                        item.image_grid_thws
-                        for item in mm_input.mm_items
-                        if item.image_grid_thws is not None
-                    ]
-                    image_grid_thw = (
-                        None
-                        if len(image_grid_thws_list) == 0
-                        else torch.cat(image_grid_thws_list, dim=0)
-                    )
-
-                    video_grid_thws_list = [
-                        item.video_grid_thws
-                        for item in mm_input.mm_items
-                        if item.video_grid_thws is not None
-                    ]
-                    video_grid_thw = (
-                        None
-                        if len(video_grid_thws_list) == 0
-                        else torch.cat(video_grid_thws_list, dim=0)
+                        * 3
                     )
-
-                    second_per_grid_ts_list = [
-                        item.second_per_grid_ts
-                        for item in mm_input.mm_items
-                        if item.second_per_grid_ts is not None
+                else:
+                    mrope_positions = mm_input.mrope_positions[
+                        :,
+                        extend_prefix_len : extend_prefix_len + extend_seq_len,
                     ]
-                    second_per_grid_ts = (
-                        None
-                        if len(second_per_grid_ts_list) == 0
-                        else torch.cat(second_per_grid_ts_list, dim=0)
-                    )
-
-                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
-                    mrope_positions, mrope_position_delta = (
-                        MRotaryEmbedding.get_input_positions(
-                            input_tokens=self.input_ids[
-                                extend_start_loc : extend_start_loc + extend_seq_len
-                            ].tolist(),
-                            image_grid_thw=image_grid_thw,
-                            video_grid_thw=video_grid_thw,
-                            image_token_id=hf_config.image_token_id,
-                            video_token_id=hf_config.video_token_id,
-                            vision_start_token_id=hf_config.vision_start_token_id,
-                            vision_end_token_id=hf_config.vision_end_token_id,
-                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
-                            context_len=0,
-                            seq_len=len(self.input_ids),
-                            second_per_grid_ts=second_per_grid_ts,
-                            tokens_per_second=getattr(
-                                hf_config.vision_config, "tokens_per_second", None
-                            ),
-                        )
-                    )
-                    batch.multimodal_inputs[i].mrope_position_delta = (
-                        mrope_position_delta
-                    )
-                    mrope_positions_list[i] = mrope_positions
+                mrope_positions_list[batch_idx] = mrope_positions

         self.mrope_positions = torch.cat(
-            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
-            axis=1,
-        )
-        self.mrope_positions = self.mrope_positions.to(torch.int64)
+            [pos.to(device=model_runner.device) for pos in mrope_positions_list],
+            dim=1,
+        ).to(dtype=torch.int64, device=model_runner.device)

     def get_max_chunk_capacity(self):
         # Maximum number of tokens in each chunk
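Shape-wise, the rewritten _compute_mrope_positions builds one [3, seq_len] position block per request (from mm_input.mrope_positions, a slice of it, or a plain text-only range repeated on all three rope channels) and concatenates the blocks along the sequence dimension; a minimal shape sketch, not sglang code:

    import torch

    # Two requests with extend lengths 5 and 2; the values here are placeholders.
    per_request = [torch.zeros(3, 5), torch.zeros(3, 2)]
    mrope_positions = torch.cat(per_request, dim=1).to(torch.int64)
    print(mrope_positions.shape)   # -> torch.Size([3, 7])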