sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/layers/activation.py +10 -5
  5. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  6. sglang/srt/layers/attention/triton_backend.py +71 -7
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  8. sglang/srt/layers/layernorm.py +1 -5
  9. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  18. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
  19. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  20. sglang/srt/layers/moe/topk.py +4 -0
  21. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  22. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/fp8_kernel.py +140 -2
  44. sglang/srt/layers/rotary_embedding.py +1 -3
  45. sglang/srt/layers/sampler.py +4 -4
  46. sglang/srt/lora/backend/__init__.py +8 -0
  47. sglang/srt/lora/backend/base_backend.py +95 -0
  48. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  49. sglang/srt/lora/backend/triton_backend.py +61 -0
  50. sglang/srt/lora/lora.py +127 -112
  51. sglang/srt/lora/lora_manager.py +50 -18
  52. sglang/srt/lora/triton_ops/__init__.py +5 -0
  53. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  54. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  55. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  57. sglang/srt/model_executor/forward_batch_info.py +58 -59
  58. sglang/srt/model_executor/model_runner.py +2 -2
  59. sglang/srt/models/qwen2_vl.py +1 -1
  60. sglang/srt/server_args.py +13 -2
  61. sglang/srt/speculative/build_eagle_tree.py +4 -2
  62. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  63. sglang/srt/speculative/eagle_utils.py +361 -372
  64. sglang/srt/speculative/eagle_worker.py +177 -45
  65. sglang/srt/utils.py +7 -0
  66. sglang/test/runners.py +2 -0
  67. sglang/version.py +1 -1
  68. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +15 -6
  69. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +72 -33
  70. sglang/srt/layers/custom_op_util.py +0 -25
  71. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
  72. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
  73. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -21,8 +21,8 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 import tqdm
-from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -103,69 +103,75 @@ def set_torch_compile_config():
     torch._dynamo.config.cache_size_limit = 1024
 
 
+def get_batch_sizes_to_capture(model_runner: ModelRunner):
+    server_args = model_runner.server_args
+    capture_bs = server_args.cuda_graph_bs
+    if capture_bs is None:
+        if server_args.disable_cuda_graph_padding:
+            capture_bs = list(range(1, 33)) + [64, 128]
+        else:
+            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+    if max(capture_bs) > model_runner.req_to_token_pool.size:
+        # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
+        # is very samll. We add more values here to make sure we capture the maximum bs.
+        capture_bs = list(
+            sorted(
+                set(
+                    capture_bs
+                    + [model_runner.req_to_token_pool.size - 1]
+                    + [model_runner.req_to_token_pool.size]
+                )
+            )
+        )
+    capture_bs = [
+        bs
+        for bs in capture_bs
+        if bs <= model_runner.req_to_token_pool.size
+        and bs <= server_args.cuda_graph_max_bs
+    ]
+    compile_bs = (
+        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
+        if server_args.enable_torch_compile
+        else []
+    )
+    return capture_bs, compile_bs
+
+
+# Reuse this memory pool across all cuda graph runners.
+global_graph_memory_pool = None
+
+
+def get_global_graph_memory_pool():
+    return global_graph_memory_pool
+
+
+def set_global_graph_memory_pool(val):
+    global global_graph_memory_pool
+    global_graph_memory_pool = val
+
+
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
 
-    def __init__(self, model_runner: "ModelRunner"):
+    def __init__(self, model_runner: ModelRunner):
         # Parse args
         self.model_runner = model_runner
         self.graphs = {}
-        self.input_buffers = {}
         self.output_buffers = {}
-        self.flashinfer_handlers = {}
-        self.graph_memory_pool = None
-        self.use_torch_compile = model_runner.server_args.enable_torch_compile
+        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
-        self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
-        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
-        self.tp_size = self.model_runner.tp_size
-        self.dp_size = self.model_runner.server_args.dp_size
+        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.tp_size = model_runner.server_args.tp_size
+        self.dp_size = model_runner.server_args.dp_size
 
         # Batch sizes to capture
-        self.capture_bs = self.model_runner.server_args.cuda_graph_bs
-        if self.capture_bs is None:
-            if model_runner.server_args.disable_cuda_graph_padding:
-                self.capture_bs = list(range(1, 33)) + [64, 128]
-            else:
-                self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
-        if max(self.capture_bs) > model_runner.req_to_token_pool.size:
-            # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
-            # is very samll. We add more values here to make sure we capture the maximum bs.
-            self.capture_bs = list(
-                sorted(
-                    set(
-                        self.capture_bs
-                        + [model_runner.req_to_token_pool.size - 1]
-                        + [model_runner.req_to_token_pool.size]
-                    )
-                )
-            )
-
-        self.capture_bs = [
-            bs
-            for bs in self.capture_bs
-            if bs <= model_runner.req_to_token_pool.size
-            and bs <= model_runner.server_args.cuda_graph_max_bs
-        ]
-
-        self.compile_bs = (
-            [
-                bs
-                for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.torch_compile_max_bs
-            ]
-            if self.use_torch_compile
-            else []
-        )
-
+        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1
         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
-                self.num_tokens_per_bs = (
-                    self.model_runner.server_args.speculative_eagle_topk
-                )
+                raise RuntimeError("This should not happen")
             else:
                 self.capture_forward_mode = ForwardMode.TARGET_VERIFY
                 self.num_tokens_per_bs = (
@@ -182,10 +188,10 @@ class CudaGraphRunner:
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
 
-        if self.use_torch_compile:
+        if self.enable_torch_compile:
             set_torch_compile_config()
 
-        # Common inputs
+        # Graph inputs
         with torch.device("cuda"):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
@@ -301,7 +307,7 @@ class CudaGraphRunner:
         stream = self.stream
         num_tokens = bs * self.num_tokens_per_bs
 
-        # Common inputs
+        # Graph inputs
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
@@ -320,7 +326,7 @@ class CudaGraphRunner:
             global_num_tokens = None
             gathered_buffer = None
 
-        spec_info = self.get_spec_info(num_tokens, positions)
+        spec_info = self.get_spec_info(num_tokens)
 
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -335,7 +341,6 @@ class CudaGraphRunner:
             seq_lens_sum=seq_lens.sum(),
             encoder_lens=encoder_lens,
             return_logprob=False,
-            top_logprobs_nums=[0] * bs,
             positions=positions,
             global_num_tokens=global_num_tokens,
             gathered_buffer=gathered_buffer,
@@ -375,13 +380,14 @@ class CudaGraphRunner:
         torch.cuda.synchronize()
         self.model_runner.tp_group.barrier()
 
-        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
+        global global_graph_memory_pool
+        with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
             out = run_once()
 
         torch.cuda.synchronize()
         self.model_runner.tp_group.barrier()
 
-        self.graph_memory_pool = graph.pool()
+        global_graph_memory_pool = graph.pool()
         return graph, out
 
     def replay(self, forward_batch: ForwardBatch):
@@ -439,35 +445,26 @@ class CudaGraphRunner:
         )
         return logits_output
 
-    def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
+    def get_spec_info(self, num_tokens: int):
         spec_info = None
         if self.model_runner.spec_algorithm.is_eagle():
-            from sglang.srt.speculative.eagle_utils import (
-                EAGLEDraftInput,
-                EagleVerifyInput,
-            )
+            from sglang.srt.speculative.eagle_utils import EagleVerifyInput
 
             if self.model_runner.is_draft_worker:
-                spec_info = EAGLEDraftInput()
-                spec_info.load_server_args(self.model_runner.server_args)
-                spec_info.hidden_states = self.hidden_states[:num_tokens]
-                spec_info.positions = positions
-                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+                raise RuntimeError("This should not happen.")
             else:
                 spec_info = EagleVerifyInput(
-                    None,
-                    None,
-                    None,
-                    None,
-                    None,
-                    None,
-                    self.model_runner.server_args.speculative_num_draft_tokens,
-                )
-                spec_info.custom_mask = torch.zeros(
-                    (num_tokens * self.model_runner.model_config.context_len),
-                    dtype=torch.bool,
-                    device="cuda",
+                    draft_token=None,
+                    custom_mask=torch.zeros(
+                        (num_tokens * self.model_runner.model_config.context_len),
+                        dtype=torch.bool,
+                        device="cuda",
+                    ),
+                    positions=None,
+                    retrive_index=None,
+                    retrive_cum_len=None,
+                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+                    capture_hidden_mode=CaptureHiddenMode.FULL,
                 )
-                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
 
         return spec_info
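
Note on the hunks above: batch-size selection and the CUDA graph memory pool are now module-level helpers (get_batch_sizes_to_capture, get_global_graph_memory_pool, set_global_graph_memory_pool) rather than CudaGraphRunner internals, so other graph runners in this release can reuse them (the EAGLEDraftCudaGraphRunner added later in this diff does exactly that). A minimal sketch of that sharing pattern, using only the helpers shown above; the MyRunner class and the run_once callable are illustrative placeholders, not part of the package:

    import torch

    from sglang.srt.model_executor.cuda_graph_runner import (
        get_batch_sizes_to_capture,
        get_global_graph_memory_pool,
        set_global_graph_memory_pool,
    )


    class MyRunner:
        def __init__(self, model_runner):
            # Reuse the same batch-size schedule as the main CudaGraphRunner.
            self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
            self.graphs = {}

        def capture_one(self, bs, run_once):
            graph = torch.cuda.CUDAGraph()
            # Capture into the shared pool so all runners' graphs draw from one allocation.
            with torch.cuda.graph(graph, pool=get_global_graph_memory_pool()):
                out = run_once()
            set_global_graph_memory_pool(graph.pool())
            self.graphs[bs] = graph
            return out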
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -197,64 +197,6 @@ class ForwardBatch:
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
 
-    def compute_mrope_positions(
-        self, model_runner: ModelRunner, batch: ModelWorkerBatch
-    ):
-        device = model_runner.device
-        hf_config = model_runner.model_config.hf_config
-        mrope_positions_list = [None] * self.seq_lens.shape[0]
-        if self.forward_mode.is_decode():
-            for i, _ in enumerate(mrope_positions_list):
-                mrope_position_delta = (
-                    0
-                    if batch.image_inputs[i] is None
-                    else batch.image_inputs[i].mrope_position_delta
-                )
-                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-                    mrope_position_delta,
-                    int(self.seq_lens[i]) - 1,
-                    int(self.seq_lens[i]),
-                )
-        elif self.forward_mode.is_extend():
-            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i, image_inputs in enumerate(batch.image_inputs):
-                extend_start_loc, extend_seq_len, extend_prefix_len = (
-                    extend_start_loc_cpu[i],
-                    batch.extend_seq_lens[i],
-                    batch.extend_prefix_lens[i],
-                )
-                if image_inputs is None:
-                    # text only
-                    mrope_positions = [
-                        [
-                            pos
-                            for pos in range(
-                                extend_prefix_len, extend_prefix_len + extend_seq_len
-                            )
-                        ]
-                    ] * 3
-                else:
-                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
-                    mrope_positions, mrope_position_delta = (
-                        MRotaryEmbedding.get_input_positions(
-                            input_tokens=self.input_ids[
-                                extend_start_loc : extend_start_loc + extend_seq_len
-                            ],
-                            image_grid_thw=image_inputs.image_grid_thws,
-                            vision_start_token_id=hf_config.vision_start_token_id,
-                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
-                            context_len=0,
-                        )
-                    )
-                    batch.image_inputs[i].mrope_position_delta = mrope_position_delta
-                mrope_positions_list[i] = mrope_positions
-
-        self.mrope_positions = torch.concat(
-            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
-            axis=1,
-        )
-        self.mrope_positions = self.mrope_positions.to(torch.int64)
-
     @classmethod
     def init_new(
         cls,
@@ -337,7 +279,7 @@ class ForwardBatch:
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
 
         if model_runner.model_is_mrope:
-            ret.compute_mrope_positions(model_runner, batch)
+            ret._compute_mrope_positions(model_runner, batch)
 
         # Init lora information
         if model_runner.server_args.lora_paths is not None:
@@ -345,6 +287,63 @@ class ForwardBatch:
 
         return ret
 
+    def _compute_mrope_positions(
+        self, model_runner: ModelRunner, batch: ModelWorkerBatch
+    ):
+        device = model_runner.device
+        hf_config = model_runner.model_config.hf_config
+        mrope_positions_list = [None] * self.seq_lens.shape[0]
+        if self.forward_mode.is_decode():
+            for i, _ in enumerate(mrope_positions_list):
+                mrope_position_delta = (
+                    0
+                    if batch.image_inputs[i] is None
+                    else batch.image_inputs[i].mrope_position_delta
+                )
+                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
+                    mrope_position_delta,
+                    int(self.seq_lens[i]) - 1,
+                    int(self.seq_lens[i]),
+                )
+        elif self.forward_mode.is_extend():
+            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
+            for i, image_inputs in enumerate(batch.image_inputs):
+                extend_start_loc, extend_seq_len, extend_prefix_len = (
+                    extend_start_loc_cpu[i],
+                    batch.extend_seq_lens[i],
+                    batch.extend_prefix_lens[i],
+                )
+                if image_inputs is None:
+                    # text only
+                    mrope_positions = [
+                        [
+                            pos
+                            for pos in range(
+                                extend_prefix_len, extend_prefix_len + extend_seq_len
+                            )
+                        ]
+                    ] * 3
+                else:
+                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
+                    mrope_positions, mrope_position_delta = (
+                        MRotaryEmbedding.get_input_positions(
+                            input_tokens=self.input_ids[
+                                extend_start_loc : extend_start_loc + extend_seq_len
+                            ],
+                            image_grid_thw=image_inputs.image_grid_thws,
+                            vision_start_token_id=hf_config.vision_start_token_id,
+                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
+                            context_len=0,
+                        )
+                    )
+                    batch.image_inputs[i].mrope_position_delta = mrope_position_delta
+                mrope_positions_list[i] = mrope_positions
+        self.mrope_positions = torch.concat(
+            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
+            axis=1,
+        )
+        self.mrope_positions = self.mrope_positions.to(torch.int64)
+
 
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
sglang/srt/model_executor/model_runner.py CHANGED
@@ -52,6 +52,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
+from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
@@ -529,6 +530,7 @@ class ModelRunner:
             max_loras_per_batch=self.server_args.max_loras_per_batch,
             load_config=self.load_config,
             dtype=self.dtype,
+            lora_backend=self.server_args.lora_backend,
         )
         logger.info("LoRA manager ready.")
 
@@ -714,8 +716,6 @@ class ModelRunner:
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
-        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-
         self.cuda_graph_runner = None
 
         if not self.is_generation:
sglang/srt/models/qwen2_vl.py CHANGED
@@ -31,10 +31,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from vllm.model_executor.layers.activation import QuickGELU
 
 from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
 from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
sglang/srt/server_args.py CHANGED
@@ -113,6 +113,7 @@ class ServerArgs:
     # LoRA
     lora_paths: Optional[List[str]] = None
     max_loras_per_batch: int = 8
+    lora_backend: str = "triton"
 
     # Kernel backend
     attention_backend: Optional[str] = None
@@ -273,6 +274,10 @@ class ServerArgs:
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"
 
+        # AMD-specific Triton attention KV splits default number
+        if is_hip():
+            self.triton_attention_num_kv_splits = 16
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
@@ -649,13 +654,19 @@ class ServerArgs:
             nargs="*",
             default=None,
             action=LoRAPathAction,
-            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
         )
         parser.add_argument(
             "--max-loras-per-batch",
             type=int,
             default=8,
-            help="Maximum number of adapters for a running batch, include base-only request",
+            help="Maximum number of adapters for a running batch, include base-only request.",
+        )
+        parser.add_argument(
+            "--lora-backend",
+            type=str,
+            default="triton",
+            help="Choose the kernel backend for multi-LoRA serving.",
         )
 
         # Kernel backend
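
The new lora_backend server argument defaults to "triton" and is threaded through to the LoRA manager in the model_runner.py hunk above; the release also ships a FlashInfer LoRA backend (files 48 and 49 in the list), which presumably corresponds to the value "flashinfer". A hedged sketch of setting these options programmatically; the model path and adapter path are placeholders, and on the command line the equivalent flag is --lora-backend triton:

    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        lora_paths=["/path/to/adapter"],  # placeholder; the CLI also accepts {name}={path}
        max_loras_per_batch=8,
        lora_backend="triton",  # new field in this release; see the hunk above
    )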
sglang/srt/speculative/build_eagle_tree.py CHANGED
@@ -79,11 +79,13 @@ __global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected
 )
 
 
-def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token):
+def build_tree_kernel(
+    parent_list, top_score_index, seq_lens, seq_lens_sum, topk, depth, draft_token
+):
     bs = seq_lens.numel()
     device = parent_list.device
     tree_mask = torch.full(
-        (torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,),
+        (seq_lens_sum * draft_token + draft_token * draft_token * bs,),
         True,
         device=device,
     )
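
Passing seq_lens_sum into build_tree_kernel lets the caller supply a total it already tracks on the host, so the wrapper no longer computes torch.sum(seq_lens).item(), which forces a device-to-host sync. A small sketch of the allocation this changes, with illustrative values only; the real sizes come from the scheduler:

    import torch

    seq_lens = torch.tensor([7, 12], device="cuda", dtype=torch.int32)
    seq_lens_sum = 19  # tracked on the host by the caller, no .item() sync needed
    bs = seq_lens.numel()
    draft_token = 64  # illustrative number of draft tokens

    # The flattened tree mask allocated by build_tree_kernel, sized with the host-side total.
    tree_mask = torch.full(
        (seq_lens_sum * draft_token + draft_token * draft_token * bs,),
        True,
        device=seq_lens.device,
    )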
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py ADDED
@@ -0,0 +1,213 @@
+from __future__ import annotations
+
+import bisect
+import time
+from typing import TYPE_CHECKING, Callable
+
+import torch
+
+from sglang.srt.model_executor.cuda_graph_runner import (
+    CudaGraphRunner,
+    get_batch_sizes_to_capture,
+    get_global_graph_memory_pool,
+    set_global_graph_memory_pool,
+    set_torch_compile_config,
+)
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.eagle_worker import EAGLEWorker
+
+
+class EAGLEDraftCudaGraphRunner:
+    def __init__(self, eagle_worker: EAGLEWorker):
+        # Parse args
+        self.eagle_worker = eagle_worker
+        self.model_runner = model_runner = eagle_worker.model_runner
+        self.graphs = {}
+        self.output_buffers = {}
+        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
+        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.tp_size = self.model_runner.tp_size
+        self.dp_size = model_runner.server_args.dp_size
+        self.topk = model_runner.server_args.speculative_eagle_topk
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
+        server_args = model_runner.server_args
+
+        assert self.disable_padding
+
+        # Batch sizes to capture
+        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
+        self.num_tokens_per_bs = server_args.speculative_eagle_topk
+
+        # Attention backend
+        self.max_bs = max(self.capture_bs)
+        self.max_num_token = self.max_bs * self.num_tokens_per_bs
+        self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
+            0
+        ].get_cuda_graph_seq_len_fill_value()
+
+        if self.enable_torch_compile:
+            set_torch_compile_config()
+
+        # Graph inputs
+        with torch.device("cuda"):
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.seq_lens = torch.full(
+                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+            )
+            self.out_cache_loc = torch.zeros(
+                (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
+            )
+            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
+            self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
+            self.hidden_states = torch.zeros(
+                (self.max_bs, self.model_runner.model_config.hidden_size),
+                dtype=self.model_runner.dtype,
+            )
+
+        # Capture
+        try:
+            self.capture()
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed: {e}\n"
+                "Possible solutions:\n"
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
+                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
+            )
+
+    def can_run(self, forward_batch: ForwardBatch):
+        is_bs_supported = (
+            forward_batch.batch_size in self.graphs
+            if self.disable_padding
+            else forward_batch.batch_size <= self.max_bs
+        )
+        return is_bs_supported
+
+    def capture(self):
+        CudaGraphRunner.capture(self)
+
+    def capture_one_batch_size(self, num_seqs: int, forward: Callable):
+        graph = torch.cuda.CUDAGraph()
+        stream = self.stream
+        num_tokens = num_seqs * self.num_tokens_per_bs
+
+        # Graph inputs
+        req_pool_indices = self.req_pool_indices[:num_seqs]
+        seq_lens = self.seq_lens[:num_seqs]
+        out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
+        positions = self.positions[:num_tokens]
+        topk_p = self.topk_p[:num_seqs]
+        topk_index = self.topk_index[:num_seqs]
+        hidden_states = self.hidden_states[:num_seqs]
+
+        spec_info = EagleDraftInput(
+            topk_p=topk_p,
+            topk_index=topk_index,
+            hidden_states=hidden_states,
+        )
+
+        # Forward batch
+        forward_batch = ForwardBatch(
+            forward_mode=ForwardMode.DECODE,
+            batch_size=num_seqs,
+            input_ids=None,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            out_cache_loc=out_cache_loc,
+            seq_lens_sum=seq_lens.sum(),
+            return_logprob=False,
+            positions=positions,
+            spec_algorithm=self.model_runner.spec_algorithm,
+            spec_info=spec_info,
+            capture_hidden_mode=(
+                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+            ),
+        )
+
+        # Attention backend
+        self.model_runner.draft_attn_backend.init_forward_metadata_capture_cuda_graph(
+            forward_batch
+        )
+
+        # Run and capture
+        def run_once():
+            # Backup two fileds, which will be modified in-place in `draft_forward`.
+            output_cache_loc_backup = forward_batch.out_cache_loc
+            hidden_states_backup = forward_batch.spec_info.hidden_states
+
+            ret = self.eagle_worker.draft_forward(forward_batch)
+
+            forward_batch.out_cache_loc = output_cache_loc_backup
+            forward_batch.spec_info.hidden_states = hidden_states_backup
+            return ret
+
+        for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
+            run_once()
+
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
+        torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
+        with torch.cuda.graph(
+            graph, pool=get_global_graph_memory_pool(), stream=stream
+        ):
+            out = run_once()
+
+        torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
+        set_global_graph_memory_pool(graph.pool())
+        return graph, out
+
+    def replay(self, forward_batch: ForwardBatch):
+        assert forward_batch.out_cache_loc is not None
+        raw_bs = forward_batch.batch_size
+        raw_num_token = raw_bs * self.num_tokens_per_bs
+
+        # Pad
+        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        bs = self.capture_bs[index]
+        if bs != raw_bs:
+            self.seq_lens.fill_(1)
+            self.out_cache_loc.zero_()
+
+        # Common inputs
+        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
+        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
+        self.out_cache_loc[: raw_num_token * self.speculative_num_steps].copy_(
+            forward_batch.out_cache_loc
+        )
+        self.positions[:raw_num_token].copy_(forward_batch.positions)
+        self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
+        self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
+        self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
+
+        # Attention backend
+        self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
+            forward_batch
+        )
+
+        # Replay
+        self.graphs[bs].replay()
+
+        return self.output_buffers[bs]
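
For orientation, the new runner above follows the same capture/replay contract as CudaGraphRunner. A hedged sketch of how a draft worker might dispatch to it; the cuda_graph_runner attribute name is an assumption for illustration, while can_run, replay, and draft_forward come from the code above:

    def draft_decode(eagle_worker, forward_batch):
        runner = getattr(eagle_worker, "cuda_graph_runner", None)  # assumed attribute
        if runner is not None and runner.can_run(forward_batch):
            # Replay the graph captured for this batch size.
            return runner.replay(forward_batch)
        # Fall back to the eager draft forward pass for uncaptured shapes.
        return eagle_worker.draft_forward(forward_batch)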