sglang 0.4.2__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/layers/activation.py +10 -5
  5. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  6. sglang/srt/layers/attention/triton_backend.py +71 -7
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  8. sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
  9. sglang/srt/layers/attention/vision.py +243 -40
  10. sglang/srt/layers/layernorm.py +1 -5
  11. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  22. sglang/srt/layers/moe/topk.py +4 -0
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/fp8.py +7 -0
  46. sglang/srt/layers/quantization/fp8_kernel.py +140 -2
  47. sglang/srt/layers/rotary_embedding.py +29 -15
  48. sglang/srt/layers/sampler.py +9 -6
  49. sglang/srt/lora/backend/__init__.py +8 -0
  50. sglang/srt/lora/backend/base_backend.py +95 -0
  51. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  52. sglang/srt/lora/backend/triton_backend.py +61 -0
  53. sglang/srt/lora/lora.py +127 -112
  54. sglang/srt/lora/lora_manager.py +50 -18
  55. sglang/srt/lora/triton_ops/__init__.py +5 -0
  56. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  57. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  59. sglang/srt/managers/image_processor.py +77 -38
  60. sglang/srt/managers/scheduler.py +17 -3
  61. sglang/srt/mem_cache/base_prefix_cache.py +4 -0
  62. sglang/srt/mem_cache/chunk_cache.py +3 -0
  63. sglang/srt/mem_cache/radix_cache.py +30 -1
  64. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  65. sglang/srt/model_executor/forward_batch_info.py +58 -59
  66. sglang/srt/model_executor/model_runner.py +2 -2
  67. sglang/srt/models/minicpmv.py +129 -76
  68. sglang/srt/models/mllama.py +16 -56
  69. sglang/srt/models/qwen2.py +4 -1
  70. sglang/srt/models/qwen2_vl.py +19 -9
  71. sglang/srt/server_args.py +19 -2
  72. sglang/srt/speculative/build_eagle_tree.py +4 -2
  73. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  74. sglang/srt/speculative/eagle_utils.py +361 -372
  75. sglang/srt/speculative/eagle_worker.py +177 -45
  76. sglang/srt/utils.py +7 -2
  77. sglang/test/runners.py +2 -0
  78. sglang/utils.py +42 -0
  79. sglang/version.py +1 -1
  80. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +16 -7
  81. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +84 -45
  82. sglang/srt/layers/custom_op_util.py +0 -25
  83. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
  84. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
  85. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/base_prefix_cache.py
@@ -41,6 +41,10 @@ class BasePrefixCache(ABC):
     def evictable_size(self):
         pass
 
+    @abstractmethod
+    def protected_size(self):
+        raise NotImplementedError()
+
     def total_size(self):
         raise NotImplementedError()
 
sglang/srt/mem_cache/chunk_cache.py
@@ -85,3 +85,6 @@ class ChunkCache(BasePrefixCache):
 
     def evictable_size(self):
         return 0
+
+    def protected_size(self):
+        return 0
sglang/srt/mem_cache/radix_cache.py
@@ -34,7 +34,10 @@ if TYPE_CHECKING:
 
 
 class TreeNode:
-    def __init__(self):
+
+    counter = 0
+
+    def __init__(self, id: Optional[int] = None):
         self.children = defaultdict(TreeNode)
         self.parent = None
         self.key = None
@@ -42,6 +45,23 @@ class TreeNode:
         self.lock_ref = 0
         self.last_access_time = time.time()
 
+        self.hit_count = 0
+        # indicating the node is loading KV cache from host
+        self.loading = False
+        # store the host indices of KV cache
+        self.host_value = None
+
+        self.id = TreeNode.counter if id is None else id
+        TreeNode.counter += 1
+
+    @property
+    def evicted(self):
+        return self.value is None
+
+    @property
+    def backuped(self):
+        return self.host_value is not None
+
     def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time
 
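The new `evicted`/`backuped` properties describe a node purely by where its KV data currently lives: on the device (`value`), on the host (`host_value`), or nowhere. A toy illustration of the states, using a stand-in class rather than the real TreeNode:

class Node:
    def __init__(self, value=None, host_value=None):
        self.value = value            # device KV indices (None once evicted)
        self.host_value = host_value  # host KV indices (None unless backed up)

    @property
    def evicted(self):
        return self.value is None

    @property
    def backuped(self):
        return self.host_value is not None

on_device = Node(value=[0, 1, 2])
backed_up = Node(value=[0, 1, 2], host_value=[10, 11, 12])
evicted_but_recoverable = Node(host_value=[10, 11, 12])
print(on_device.evicted, backed_up.backuped, evicted_but_recoverable.evicted)  # False True True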
@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache):
         self.root_node.value = []
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
+        self.protected_size_ = 0
 
     def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
         """Find the matching prefix from the radix tree.
@@ -203,6 +224,7 @@ class RadixCache(BasePrefixCache):
         while node != self.root_node:
             if node.lock_ref == 0:
                 self.evictable_size_ -= len(node.value)
+                self.protected_size_ += len(node.value)
                 delta -= len(node.value)
             node.lock_ref += 1
             node = node.parent
@@ -216,6 +238,7 @@ class RadixCache(BasePrefixCache):
         while node != self.root_node:
             if node.lock_ref == 1:
                 self.evictable_size_ += len(node.value)
+                self.protected_size_ -= len(node.value)
                 delta += len(node.value)
             node.lock_ref -= 1
             node = node.parent
@@ -224,6 +247,10 @@ class RadixCache(BasePrefixCache):
     def evictable_size(self):
         return self.evictable_size_
 
+    def protected_size(self):
+        # protected size refers to the size of the cache that is locked
+        return self.protected_size_
+
     ##### Internal Helper Functions #####
 
     def _match_prefix_helper(
@@ -303,6 +330,8 @@ class RadixCache(BasePrefixCache):
         self.evictable_size_ -= len(node.key)
 
     def _total_size_helper(self, node: TreeNode):
+        if node.evicted:
+            return 0
         x = len(node.value)
         for child in node.children.values():
             x += self._total_size_helper(child)
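The net effect of these RadixCache changes is a second counter that mirrors `evictable_size_`: tokens in nodes whose `lock_ref` rises above zero move from the evictable pool into the protected pool, and move back when the last reference is released. A minimal standalone sketch of that bookkeeping (the class and method names here are illustrative, not the actual sglang ones):

class Node:
    def __init__(self, num_tokens: int):
        self.num_tokens = num_tokens
        self.lock_ref = 0

class CacheAccounting:
    """Toy version of the evictable/protected counters added in this diff."""

    def __init__(self):
        self.evictable_size_ = 0  # tokens that may be evicted (lock_ref == 0)
        self.protected_size_ = 0  # tokens pinned by in-flight requests (lock_ref > 0)

    def insert(self, node: Node):
        # Newly cached tokens start out evictable.
        self.evictable_size_ += node.num_tokens

    def inc_lock_ref(self, node: Node):
        if node.lock_ref == 0:
            # First lock: tokens leave the evictable pool and become protected.
            self.evictable_size_ -= node.num_tokens
            self.protected_size_ += node.num_tokens
        node.lock_ref += 1

    def dec_lock_ref(self, node: Node):
        if node.lock_ref == 1:
            # Last unlock: tokens become evictable again.
            self.evictable_size_ += node.num_tokens
            self.protected_size_ -= node.num_tokens
        node.lock_ref -= 1

cache, node = CacheAccounting(), Node(num_tokens=16)
cache.insert(node)
cache.inc_lock_ref(node)  # evictable 0, protected 16
cache.dec_lock_ref(node)  # evictable 16, protected 0
assert (cache.evictable_size_, cache.protected_size_) == (16, 0)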
sglang/srt/model_executor/cuda_graph_runner.py
@@ -21,8 +21,8 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 import tqdm
-from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -103,69 +103,75 @@ def set_torch_compile_config():
     torch._dynamo.config.cache_size_limit = 1024
 
 
+def get_batch_sizes_to_capture(model_runner: ModelRunner):
+    server_args = model_runner.server_args
+    capture_bs = server_args.cuda_graph_bs
+    if capture_bs is None:
+        if server_args.disable_cuda_graph_padding:
+            capture_bs = list(range(1, 33)) + [64, 128]
+        else:
+            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+    if max(capture_bs) > model_runner.req_to_token_pool.size:
+        # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
+        # is very samll. We add more values here to make sure we capture the maximum bs.
+        capture_bs = list(
+            sorted(
+                set(
+                    capture_bs
+                    + [model_runner.req_to_token_pool.size - 1]
+                    + [model_runner.req_to_token_pool.size]
+                )
+            )
+        )
+    capture_bs = [
+        bs
+        for bs in capture_bs
+        if bs <= model_runner.req_to_token_pool.size
+        and bs <= server_args.cuda_graph_max_bs
+    ]
+    compile_bs = (
+        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
+        if server_args.enable_torch_compile
+        else []
+    )
+    return capture_bs, compile_bs
+
+
+# Reuse this memory pool across all cuda graph runners.
+global_graph_memory_pool = None
+
+
+def get_global_graph_memory_pool():
+    return global_graph_memory_pool
+
+
+def set_global_graph_memory_pool(val):
+    global global_graph_memory_pool
+    global_graph_memory_pool = val
+
+
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
 
-    def __init__(self, model_runner: "ModelRunner"):
+    def __init__(self, model_runner: ModelRunner):
         # Parse args
         self.model_runner = model_runner
         self.graphs = {}
-        self.input_buffers = {}
         self.output_buffers = {}
-        self.flashinfer_handlers = {}
-        self.graph_memory_pool = None
-        self.use_torch_compile = model_runner.server_args.enable_torch_compile
+        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
-        self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
-        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
-        self.tp_size = self.model_runner.tp_size
-        self.dp_size = self.model_runner.server_args.dp_size
+        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.tp_size = model_runner.server_args.tp_size
+        self.dp_size = model_runner.server_args.dp_size
 
         # Batch sizes to capture
-        self.capture_bs = self.model_runner.server_args.cuda_graph_bs
-        if self.capture_bs is None:
-            if model_runner.server_args.disable_cuda_graph_padding:
-                self.capture_bs = list(range(1, 33)) + [64, 128]
-            else:
-                self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
-        if max(self.capture_bs) > model_runner.req_to_token_pool.size:
-            # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
-            # is very samll. We add more values here to make sure we capture the maximum bs.
-            self.capture_bs = list(
-                sorted(
-                    set(
-                        self.capture_bs
-                        + [model_runner.req_to_token_pool.size - 1]
-                        + [model_runner.req_to_token_pool.size]
-                    )
-                )
-            )
-
-        self.capture_bs = [
-            bs
-            for bs in self.capture_bs
-            if bs <= model_runner.req_to_token_pool.size
-            and bs <= model_runner.server_args.cuda_graph_max_bs
-        ]
-
-        self.compile_bs = (
-            [
-                bs
-                for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.torch_compile_max_bs
-            ]
-            if self.use_torch_compile
-            else []
-        )
-
+        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1
         if model_runner.spec_algorithm.is_eagle():
            if self.model_runner.is_draft_worker:
-                self.num_tokens_per_bs = (
-                    self.model_runner.server_args.speculative_eagle_topk
-                )
+                raise RuntimeError("This should not happen")
             else:
                 self.capture_forward_mode = ForwardMode.TARGET_VERIFY
                 self.num_tokens_per_bs = (
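The batch-size selection logic was only moved out of `__init__` into the module-level `get_batch_sizes_to_capture()`, so the draft and target graph runners can share it. The following snippet reproduces the default list construction with plain values instead of `ServerArgs` (the concrete numbers are illustrative, not sglang defaults for any particular GPU):

# Illustrative only: mirrors the capture-list logic of get_batch_sizes_to_capture().
req_to_token_pool_size = 48        # e.g. a small --max-running-requests setting
cuda_graph_max_bs = 160
disable_cuda_graph_padding = False

if disable_cuda_graph_padding:
    capture_bs = list(range(1, 33)) + [64, 128]
else:
    capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 1, 2, 4, 8, 16, ..., 160

if max(capture_bs) > req_to_token_pool_size:
    # Make sure the largest runnable batch size is still captured.
    capture_bs = sorted(
        set(capture_bs + [req_to_token_pool_size - 1, req_to_token_pool_size])
    )

capture_bs = [
    bs for bs in capture_bs
    if bs <= req_to_token_pool_size and bs <= cuda_graph_max_bs
]
print(capture_bs)  # [1, 2, 4, 8, 16, 24, 32, 40, 47, 48]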
@@ -182,10 +188,10 @@ class CudaGraphRunner:
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
 
-        if self.use_torch_compile:
+        if self.enable_torch_compile:
             set_torch_compile_config()
 
-        # Common inputs
+        # Graph inputs
         with torch.device("cuda"):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
@@ -301,7 +307,7 @@ class CudaGraphRunner:
         stream = self.stream
         num_tokens = bs * self.num_tokens_per_bs
 
-        # Common inputs
+        # Graph inputs
         input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
@@ -320,7 +326,7 @@ class CudaGraphRunner:
             global_num_tokens = None
             gathered_buffer = None
 
-        spec_info = self.get_spec_info(num_tokens, positions)
+        spec_info = self.get_spec_info(num_tokens)
 
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -335,7 +341,6 @@ class CudaGraphRunner:
             seq_lens_sum=seq_lens.sum(),
             encoder_lens=encoder_lens,
             return_logprob=False,
-            top_logprobs_nums=[0] * bs,
             positions=positions,
             global_num_tokens=global_num_tokens,
             gathered_buffer=gathered_buffer,
@@ -375,13 +380,14 @@ class CudaGraphRunner:
         torch.cuda.synchronize()
         self.model_runner.tp_group.barrier()
 
-        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
+        global global_graph_memory_pool
+        with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
             out = run_once()
 
         torch.cuda.synchronize()
         self.model_runner.tp_group.barrier()
 
-        self.graph_memory_pool = graph.pool()
+        global_graph_memory_pool = graph.pool()
         return graph, out
 
     def replay(self, forward_batch: ForwardBatch):
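This change replaces the per-runner `graph_memory_pool` attribute with the module-level `global_graph_memory_pool`, so multiple graph runners (for example the new EAGLE draft CUDA graph runner and the main runner) can capture into the same CUDA graph memory pool. A minimal sketch of that pattern, independent of sglang and using only public torch.cuda APIs:

import torch

global_graph_memory_pool = None  # shared across every capture in the process


def capture(fn):
    """Capture fn() into a CUDA graph, reusing one memory pool for all graphs."""
    global global_graph_memory_pool
    # Warm up on a side stream before capture (the usual CUDA graph recipe).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        fn()
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=global_graph_memory_pool):
        out = fn()
    # graph.pool() hands back the pool so later captures allocate from it too.
    global_graph_memory_pool = graph.pool()
    return graph, out


if torch.cuda.is_available():
    x = torch.zeros(8, device="cuda")
    g1, y1 = capture(lambda: x + 1)
    g2, y2 = capture(lambda: x * 2)  # second graph shares the first graph's pool
    x.fill_(3.0)
    g1.replay()
    g2.replay()
    print(y1[0].item(), y2[0].item())  # 4.0 6.0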
@@ -439,35 +445,26 @@ class CudaGraphRunner:
         )
         return logits_output
 
-    def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
+    def get_spec_info(self, num_tokens: int):
         spec_info = None
         if self.model_runner.spec_algorithm.is_eagle():
-            from sglang.srt.speculative.eagle_utils import (
-                EAGLEDraftInput,
-                EagleVerifyInput,
-            )
+            from sglang.srt.speculative.eagle_utils import EagleVerifyInput
 
             if self.model_runner.is_draft_worker:
-                spec_info = EAGLEDraftInput()
-                spec_info.load_server_args(self.model_runner.server_args)
-                spec_info.hidden_states = self.hidden_states[:num_tokens]
-                spec_info.positions = positions
-                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+                raise RuntimeError("This should not happen.")
             else:
                 spec_info = EagleVerifyInput(
-                    None,
-                    None,
-                    None,
-                    None,
-                    None,
-                    None,
-                    self.model_runner.server_args.speculative_num_draft_tokens,
-                )
-                spec_info.custom_mask = torch.zeros(
-                    (num_tokens * self.model_runner.model_config.context_len),
-                    dtype=torch.bool,
-                    device="cuda",
+                    draft_token=None,
+                    custom_mask=torch.zeros(
+                        (num_tokens * self.model_runner.model_config.context_len),
+                        dtype=torch.bool,
+                        device="cuda",
+                    ),
+                    positions=None,
+                    retrive_index=None,
+                    retrive_cum_len=None,
+                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+                    capture_hidden_mode=CaptureHiddenMode.FULL,
                 )
-                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
 
         return spec_info
sglang/srt/model_executor/forward_batch_info.py
@@ -197,64 +197,6 @@ class ForwardBatch:
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
 
-    def compute_mrope_positions(
-        self, model_runner: ModelRunner, batch: ModelWorkerBatch
-    ):
-        device = model_runner.device
-        hf_config = model_runner.model_config.hf_config
-        mrope_positions_list = [None] * self.seq_lens.shape[0]
-        if self.forward_mode.is_decode():
-            for i, _ in enumerate(mrope_positions_list):
-                mrope_position_delta = (
-                    0
-                    if batch.image_inputs[i] is None
-                    else batch.image_inputs[i].mrope_position_delta
-                )
-                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-                    mrope_position_delta,
-                    int(self.seq_lens[i]) - 1,
-                    int(self.seq_lens[i]),
-                )
-        elif self.forward_mode.is_extend():
-            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i, image_inputs in enumerate(batch.image_inputs):
-                extend_start_loc, extend_seq_len, extend_prefix_len = (
-                    extend_start_loc_cpu[i],
-                    batch.extend_seq_lens[i],
-                    batch.extend_prefix_lens[i],
-                )
-                if image_inputs is None:
-                    # text only
-                    mrope_positions = [
-                        [
-                            pos
-                            for pos in range(
-                                extend_prefix_len, extend_prefix_len + extend_seq_len
-                            )
-                        ]
-                    ] * 3
-                else:
-                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
-                    mrope_positions, mrope_position_delta = (
-                        MRotaryEmbedding.get_input_positions(
-                            input_tokens=self.input_ids[
-                                extend_start_loc : extend_start_loc + extend_seq_len
-                            ],
-                            image_grid_thw=image_inputs.image_grid_thws,
-                            vision_start_token_id=hf_config.vision_start_token_id,
-                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
-                            context_len=0,
-                        )
-                    )
-                batch.image_inputs[i].mrope_position_delta = mrope_position_delta
-                mrope_positions_list[i] = mrope_positions
-
-        self.mrope_positions = torch.concat(
-            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
-            axis=1,
-        )
-        self.mrope_positions = self.mrope_positions.to(torch.int64)
-
     @classmethod
     def init_new(
         cls,
@@ -337,7 +279,7 @@ class ForwardBatch:
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
 
         if model_runner.model_is_mrope:
-            ret.compute_mrope_positions(model_runner, batch)
+            ret._compute_mrope_positions(model_runner, batch)
 
         # Init lora information
         if model_runner.server_args.lora_paths is not None:
@@ -345,6 +287,63 @@ class ForwardBatch:
 
         return ret
 
+    def _compute_mrope_positions(
+        self, model_runner: ModelRunner, batch: ModelWorkerBatch
+    ):
+        device = model_runner.device
+        hf_config = model_runner.model_config.hf_config
+        mrope_positions_list = [None] * self.seq_lens.shape[0]
+        if self.forward_mode.is_decode():
+            for i, _ in enumerate(mrope_positions_list):
+                mrope_position_delta = (
+                    0
+                    if batch.image_inputs[i] is None
+                    else batch.image_inputs[i].mrope_position_delta
+                )
+                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
+                    mrope_position_delta,
+                    int(self.seq_lens[i]) - 1,
+                    int(self.seq_lens[i]),
+                )
+        elif self.forward_mode.is_extend():
+            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
+            for i, image_inputs in enumerate(batch.image_inputs):
+                extend_start_loc, extend_seq_len, extend_prefix_len = (
+                    extend_start_loc_cpu[i],
+                    batch.extend_seq_lens[i],
+                    batch.extend_prefix_lens[i],
+                )
+                if image_inputs is None:
+                    # text only
+                    mrope_positions = [
+                        [
+                            pos
+                            for pos in range(
+                                extend_prefix_len, extend_prefix_len + extend_seq_len
+                            )
+                        ]
+                    ] * 3
+                else:
+                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
+                    mrope_positions, mrope_position_delta = (
+                        MRotaryEmbedding.get_input_positions(
+                            input_tokens=self.input_ids[
+                                extend_start_loc : extend_start_loc + extend_seq_len
+                            ],
+                            image_grid_thw=image_inputs.image_grid_thws,
+                            vision_start_token_id=hf_config.vision_start_token_id,
+                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
+                            context_len=0,
+                        )
+                    )
+                batch.image_inputs[i].mrope_position_delta = mrope_position_delta
+                mrope_positions_list[i] = mrope_positions
+        self.mrope_positions = torch.concat(
+            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
+            axis=1,
+        )
+        self.mrope_positions = self.mrope_positions.to(torch.int64)
+
 
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
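The method is only moved and renamed to the private `_compute_mrope_positions`; its logic is unchanged. For a text-only extend segment it builds three identical position rows (one per M-RoPE component), which are later concatenated along dim 1 across requests. A small shape check, assuming a prefix of 4 tokens followed by 3 new tokens:

import torch

extend_prefix_len, extend_seq_len = 4, 3
mrope_positions = [
    list(range(extend_prefix_len, extend_prefix_len + extend_seq_len))
] * 3
t = torch.tensor(mrope_positions, dtype=torch.int64)
print(t.shape)  # torch.Size([3, 3]) -> (3 rope components, extend_seq_len)
print(t)        # each row is [4, 5, 6]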
sglang/srt/model_executor/model_runner.py
@@ -52,6 +52,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
+from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
@@ -529,6 +530,7 @@ class ModelRunner:
             max_loras_per_batch=self.server_args.max_loras_per_batch,
             load_config=self.load_config,
             dtype=self.dtype,
+            lora_backend=self.server_args.lora_backend,
         )
         logger.info("LoRA manager ready.")
 
@@ -714,8 +716,6 @@ class ModelRunner:
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
-        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-
         self.cuda_graph_runner = None
 
         if not self.is_generation: