sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/bench_serving.py +18 -1
  3. sglang/lang/interpreter.py +71 -1
  4. sglang/lang/ir.py +2 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/chatglm.py +78 -0
  7. sglang/srt/configs/dbrx.py +279 -0
  8. sglang/srt/configs/model_config.py +1 -1
  9. sglang/srt/hf_transformers_utils.py +9 -14
  10. sglang/srt/layers/attention/__init__.py +22 -6
  11. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  12. sglang/srt/layers/attention/flashinfer_backend.py +215 -83
  13. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  14. sglang/srt/layers/attention/triton_backend.py +20 -11
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  16. sglang/srt/layers/linear.py +159 -55
  17. sglang/srt/layers/logits_processor.py +170 -215
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
  41. sglang/srt/layers/parameter.py +431 -0
  42. sglang/srt/layers/quantization/__init__.py +3 -2
  43. sglang/srt/layers/quantization/fp8.py +3 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  45. sglang/srt/layers/sampler.py +57 -21
  46. sglang/srt/layers/torchao_utils.py +17 -3
  47. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  48. sglang/srt/managers/cache_controller.py +307 -0
  49. sglang/srt/managers/data_parallel_controller.py +2 -0
  50. sglang/srt/managers/io_struct.py +1 -2
  51. sglang/srt/managers/schedule_batch.py +33 -3
  52. sglang/srt/managers/schedule_policy.py +159 -90
  53. sglang/srt/managers/scheduler.py +68 -28
  54. sglang/srt/managers/session_controller.py +1 -1
  55. sglang/srt/managers/tokenizer_manager.py +27 -21
  56. sglang/srt/managers/tp_worker.py +16 -4
  57. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  58. sglang/srt/mem_cache/memory_pool.py +206 -1
  59. sglang/srt/metrics/collector.py +22 -30
  60. sglang/srt/model_executor/cuda_graph_runner.py +129 -77
  61. sglang/srt/model_executor/forward_batch_info.py +51 -21
  62. sglang/srt/model_executor/model_runner.py +72 -64
  63. sglang/srt/models/chatglm.py +1 -1
  64. sglang/srt/models/dbrx.py +1 -1
  65. sglang/srt/models/deepseek_v2.py +34 -7
  66. sglang/srt/models/grok.py +109 -29
  67. sglang/srt/models/llama.py +9 -2
  68. sglang/srt/openai_api/adapter.py +0 -17
  69. sglang/srt/openai_api/protocol.py +3 -3
  70. sglang/srt/sampling/sampling_batch_info.py +22 -0
  71. sglang/srt/sampling/sampling_params.py +9 -1
  72. sglang/srt/server.py +20 -13
  73. sglang/srt/server_args.py +120 -58
  74. sglang/srt/speculative/build_eagle_tree.py +347 -0
  75. sglang/srt/speculative/eagle_utils.py +626 -0
  76. sglang/srt/speculative/eagle_worker.py +184 -0
  77. sglang/srt/speculative/spec_info.py +5 -0
  78. sglang/srt/utils.py +47 -7
  79. sglang/test/test_programs.py +23 -1
  80. sglang/test/test_utils.py +36 -7
  81. sglang/version.py +1 -1
  82. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
  83. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
  84. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  85. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  86. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py
@@ -25,14 +25,15 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.logits_processor import (
-    LogitsMetadata,
-    LogitsProcessor,
-    LogitsProcessorOutput,
-)
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
+from sglang.srt.layers.torchao_utils import save_gemlite_cache
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -105,11 +106,6 @@ def set_torch_compile_config():
     torch._dynamo.config.cache_size_limit = 1024
 
 
-@maybe_torch_compile(dynamic=True)
-def clamp_position(seq_lens):
-    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
-
-
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
 
@@ -128,10 +124,12 @@ class CudaGraphRunner:
         self.tp_size = self.model_runner.tp_size
 
         # Batch sizes to capture
-        if model_runner.server_args.disable_cuda_graph_padding:
-            self.capture_bs = list(range(1, 33)) + [64, 128]
-        else:
-            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        self.capture_bs = self.model_runner.server_args.cuda_graph_bs
+        if self.capture_bs is None:
+            if model_runner.server_args.disable_cuda_graph_padding:
+                self.capture_bs = list(range(1, 33)) + [64, 128]
+            else:
+                self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
 
         if max(self.capture_bs) > model_runner.req_to_token_pool.size:
             # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
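The hunk above makes the capture list configurable: a user-supplied cuda_graph_bs now takes precedence, and the previous defaults only apply when it is unset. A minimal sketch of that selection logic, written as a standalone helper purely for illustration (the function name select_capture_bs is not part of sglang, and the real values are further filtered against req_to_token_pool.size and cuda_graph_max_bs as shown in the next hunk):

from typing import List, Optional

def select_capture_bs(
    cuda_graph_bs: Optional[List[int]], disable_cuda_graph_padding: bool
) -> List[int]:
    # An explicit cuda_graph_bs list wins; otherwise fall back to the defaults above.
    if cuda_graph_bs is not None:
        return cuda_graph_bs
    if disable_cuda_graph_padding:
        return list(range(1, 33)) + [64, 128]
    return [1, 2, 4] + [i * 8 for i in range(1, 21)]

# Default padded mode captures 1, 2, 4, 8, 16, ..., 160.
assert select_capture_bs(None, False)[-1] == 160
assert select_capture_bs([1, 8, 32], True) == [1, 8, 32]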
@@ -152,6 +150,21 @@ class CudaGraphRunner:
             if bs <= model_runner.req_to_token_pool.size
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
+
+        self.capture_forward_mode = ForwardMode.DECODE
+        self.num_tokens_per_bs = 1
+
+        if model_runner.spec_algorithm.is_eagle():
+            if self.model_runner.is_draft_worker:
+                self.num_tokens_per_bs = (
+                    self.model_runner.server_args.speculative_eagle_topk
+                )
+            else:
+                self.capture_forward_mode = ForwardMode.TARGET_VERIFY
+                self.num_tokens_per_bs = (
+                    self.model_runner.server_args.speculative_num_draft_tokens
+                )
+
         self.compile_bs = (
             [
                 bs
@@ -164,8 +177,8 @@ class CudaGraphRunner:
 
         # Attention backend
         self.max_bs = max(self.capture_bs)
-        self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
-
+        self.max_num_token = self.max_bs * self.num_tokens_per_bs
+        self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
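With EAGLE speculative decoding each captured batch entry carries several tokens per step, so the graph buffers are now sized in tokens (max_num_token) rather than in requests. A rough worked example of that arithmetic, assuming the largest captured batch size is the default 160 and assuming illustrative values for the speculative settings (the helper below is not sglang code):

def graph_buffer_tokens(max_bs: int, num_tokens_per_bs: int) -> int:
    # Mirrors "max_num_token = max_bs * num_tokens_per_bs" from the hunk above.
    return max_bs * num_tokens_per_bs

# Plain decode: one token per request, so buffers keep max_bs entries.
assert graph_buffer_tokens(160, 1) == 160
# Hypothetical EAGLE draft worker with speculative_eagle_topk = 4.
assert graph_buffer_tokens(160, 4) == 640
# Hypothetical target-verify worker with speculative_num_draft_tokens = 8.
assert graph_buffer_tokens(160, 8) == 1280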
@@ -178,14 +191,22 @@ class CudaGraphRunner:
 
         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
 
+            # Speculative_inference
+            if model_runner.spec_algorithm.is_eagle():
+                self.hidden_states = torch.zeros(
+                    (self.max_num_token, self.model_runner.model_config.hidden_size),
+                    dtype=self.model_runner.dtype,
+                )
+
             if self.is_encoder_decoder:
                 # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
                 self.encoder_lens = torch.full(
@@ -257,12 +278,12 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            capture_bs = (
+            capture_range = (
                 tqdm.tqdm(self.capture_bs)
                 if get_tensor_model_parallel_rank() == 0
                 else self.capture_bs
             )
-            for bs in capture_bs:
+            for bs in capture_range:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -276,21 +297,24 @@ class CudaGraphRunner:
                     self.graphs[bs] = graph
                     self.output_buffers[bs] = output_buffers
 
+                # Save gemlite cache after each capture
+                save_gemlite_cache()
+
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
+        num_tokens = bs * self.num_tokens_per_bs
 
         # Common inputs
-        input_ids = self.input_ids[:bs]
+        input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        out_cache_loc = self.out_cache_loc[:bs]
+        out_cache_loc = self.out_cache_loc[:num_tokens]
+        positions = self.positions[:num_tokens]
         if self.is_encoder_decoder:
             encoder_lens = self.encoder_lens[:bs]
         else:
            encoder_lens = None
-
-        seq_lens_sum = seq_lens.sum().item()
         mrope_positions = self.mrope_positions[:, :bs]
 
         if self.enable_dp_attention:
@@ -300,37 +324,48 @@ class CudaGraphRunner:
             global_num_tokens = None
             gathered_buffer = None
 
+        spec_info = self.get_spec_info(num_tokens, positions)
+
+        forward_batch = ForwardBatch(
+            forward_mode=self.capture_forward_mode,
+            batch_size=bs,
+            input_ids=input_ids,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            attn_backend=self.model_runner.attn_backend,
+            out_cache_loc=out_cache_loc,
+            seq_lens_sum=seq_lens.sum(),
+            encoder_lens=encoder_lens,
+            return_logprob=False,
+            top_logprobs_nums=[0] * bs,
+            positions=positions,
+            global_num_tokens=global_num_tokens,
+            gathered_buffer=gathered_buffer,
+            mrope_positions=mrope_positions,
+            spec_algorithm=self.model_runner.spec_algorithm,
+            spec_info=spec_info,
+            capture_hidden_mode=(
+                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+            ),
+        )
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
+            num_tokens,
             req_pool_indices,
             seq_lens,
             encoder_lens,
+            forward_batch.forward_mode,
+            forward_batch.spec_info,
         )
 
         # Run and capture
         def run_once():
-            forward_batch = ForwardBatch(
-                forward_mode=ForwardMode.DECODE,
-                batch_size=bs,
-                input_ids=input_ids,
-                req_pool_indices=req_pool_indices,
-                seq_lens=seq_lens,
-                req_to_token_pool=self.model_runner.req_to_token_pool,
-                token_to_kv_pool=self.model_runner.token_to_kv_pool,
-                attn_backend=self.model_runner.attn_backend,
-                out_cache_loc=out_cache_loc,
-                seq_lens_sum=seq_lens_sum,
-                encoder_lens=encoder_lens,
-                return_logprob=False,
-                top_logprobs_nums=[0] * bs,
-                positions=clamp_position(seq_lens),
-                mrope_positions=mrope_positions,
-                global_num_tokens=global_num_tokens,
-                gathered_buffer=gathered_buffer,
-            )
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
-            return logits_output.next_token_logits
+            return logits_output.next_token_logits, logits_output.hidden_states
 
         for _ in range(2):
             torch.cuda.synchronize()
@@ -356,6 +391,7 @@
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
+        raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
         if self.enable_dp_attention:
@@ -370,15 +406,20 @@
             self.out_cache_loc.zero_()
 
         # Common inputs
-        self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
+        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
-        self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
+        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
+        self.positions[:raw_num_token].copy_(forward_batch.positions)
+
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
 
+        if hasattr(forward_batch.spec_info, "hidden_states"):
+            self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
@@ -386,40 +427,51 @@
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
             self.encoder_lens,
+            forward_batch.forward_mode,
+            forward_batch.spec_info,
         )
 
         # Replay
         self.graphs[bs].replay()
-        next_token_logits = self.output_buffers[bs][:raw_bs]
+        next_token_logits, hidden_states = self.output_buffers[bs]
+
+        logits_output = LogitsProcessorOutput(
+            next_token_logits=next_token_logits[:raw_num_token],
+            hidden_states=(
+                hidden_states[:raw_num_token] if hidden_states is not None else None
+            ),
+        )
+        return logits_output
 
-        # Extract logprobs
-        if forward_batch.return_logprob:
-            logits_metadata = LogitsMetadata(
-                forward_mode=ForwardMode.DECODE,
-                top_logprobs_nums=forward_batch.top_logprobs_nums,
+    def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
+        spec_info = None
+        if self.model_runner.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_utils import (
+                EAGLEDraftInput,
+                EagleVerifyInput,
             )
-            next_token_logprobs = (
-                LogitsProcessor.compute_temp_top_p_normalized_logprobs(
-                    next_token_logits, logits_metadata
+
+            if self.model_runner.is_draft_worker:
+                spec_info = EAGLEDraftInput()
+                spec_info.load_server_args(self.model_runner.server_args)
+                spec_info.hidden_states = self.hidden_states[:num_tokens]
+                spec_info.positions = positions
+                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+            else:
+                spec_info = EagleVerifyInput(
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.model_runner.server_args.speculative_num_draft_tokens,
                 )
-            )
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=next_token_logits,
-                next_token_logprobs=next_token_logprobs,
-            )
-            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
-            if return_top_logprob:
-                (
-                    logits_output.output_top_logprobs_val,
-                    logits_output.output_top_logprobs_idx,
-                ) = LogitsProcessor.get_top_logprobs(
-                    next_token_logprobs, logits_metadata
-                )[
-                    2:4
-                ]
-            else:
-                logits_output = LogitsProcessorOutput(
-                    next_token_logits=next_token_logits,
-                )
+                spec_info.custom_mask = torch.zeros(
+                    (num_tokens * self.model_runner.model_config.context_len),
+                    dtype=torch.bool,
+                    device="cuda",
+                )
+                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
 
-        return logits_output
+        return spec_info
sglang/srt/model_executor/forward_batch_info.py
@@ -38,6 +38,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+from sglang.srt.utils import maybe_torch_compile
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
@@ -96,11 +97,33 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DRAFT_EXTEND
 
     def is_cuda_graph(self):
-        return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)
+        return (
+            self == ForwardMode.DECODE
+            or self == ForwardMode.TARGET_VERIFY
+            or self == ForwardMode.IDLE
+        )
 
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
 
+    def is_decode_or_idle(self):
+        return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+
+
+class CaptureHiddenMode(IntEnum):
+    NULL = auto()
+    FULL = auto()
+    LAST = auto()
+
+    def need_capture(self):
+        return self != CaptureHiddenMode.NULL
+
+    def is_full(self):
+        return self == CaptureHiddenMode.FULL
+
+    def is_last(self):
+        return self == CaptureHiddenMode.LAST
+
 
 @dataclass
 class ForwardBatch:
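CaptureHiddenMode (moved up here from the bottom of the file, see the last hunk below) tells the forward pass whether hidden states should be skipped, kept for every position, or kept only for the last position. A small usage sketch, assuming sglang 0.4.1.post5 is installed; the select_hidden helper and the toy tensor are illustrative, not package code:

import torch
from typing import Optional
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode

def select_hidden(hidden: torch.Tensor, mode: CaptureHiddenMode) -> Optional[torch.Tensor]:
    # NULL -> keep nothing, FULL -> keep every row, LAST -> keep only the final row.
    if not mode.need_capture():
        return None
    return hidden if mode.is_full() else hidden[-1:]

states = torch.zeros(5, 4096)  # 5 token positions, hidden size 4096
assert select_hidden(states, CaptureHiddenMode.NULL) is None
assert select_hidden(states, CaptureHiddenMode.FULL).shape[0] == 5
assert select_hidden(states, CaptureHiddenMode.LAST).shape[0] == 1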
@@ -161,15 +184,16 @@ class ForwardBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     attn_backend: AttentionBackend = None
 
-    # Speculative decoding
-    spec_info: SpecInfo = None
-    spec_algorithm: SpeculativeAlgorithm = None
-
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False
 
+    # Speculative decoding
+    spec_info: SpecInfo = None
+    spec_algorithm: SpeculativeAlgorithm = None
+    capture_hidden_mode: CaptureHiddenMode = None
+
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
 
@@ -258,6 +282,9 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            spec_algorithm=batch.spec_algorithm,
+            spec_info=batch.spec_info,
+            capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
         )
 
@@ -270,10 +297,21 @@ class ForwardBatch:
         )
 
         if ret.forward_mode.is_idle():
+            ret.positions = torch.empty((0,), device=device)
             return ret
 
+        # Override the positions with spec_info
+        if (
+            ret.spec_info is not None
+            and getattr(ret.spec_info, "positions", None) is not None
+        ):
+            ret.positions = ret.spec_info.positions
+
         # Init position information
-        if not ret.forward_mode.is_decode():
+        if ret.forward_mode.is_decode():
+            if ret.positions is None:
+                ret.positions = clamp_position(batch.seq_lens)
+        else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
@@ -282,13 +320,15 @@ class ForwardBatch:
             ).to(device, non_blocking=True)
             if model_runner.server_args.attention_backend != "torch_native":
                 ret.extend_num_tokens = batch.extend_num_tokens
-                ret.positions, ret.extend_start_loc = compute_position_triton(
+                positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
                 )
             else:
-                ret.positions, ret.extend_start_loc = compute_position_torch(
+                positions, ret.extend_start_loc = compute_position_torch(
                     ret.extend_prefix_lens, ret.extend_seq_lens
                 )
+            if ret.positions is None:
+                ret.positions = positions
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
@@ -377,16 +417,6 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc
 
 
-class CaptureHiddenMode(IntEnum):
-    NULL = auto()
-    FULL = auto()
-    LAST = auto()
-
-    def need_capture(self):
-        return self != CaptureHiddenMode.NULL
-
-    def is_full(self):
-        return self == CaptureHiddenMode.FULL
-
-    def is_last(self):
-        return self == CaptureHiddenMode.LAST
+@maybe_torch_compile(dynamic=True)
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
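clamp_position, relocated in this release from cuda_graph_runner.py to forward_batch_info.py, derives decode positions from sequence lengths: the next token is written at position seq_len - 1, floored at 0 so an empty sequence cannot yield a negative index. A quick standalone check of the formula (same expression as above, without the optional torch.compile wrapper):

import torch

def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
    # Decode position for each request: last index of the cached sequence, never below 0.
    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)

seq_lens = torch.tensor([1, 7, 0], dtype=torch.int32)
# A request with 7 cached tokens decodes at 0-based position 6;
# the clamp keeps a zero-length request at position 0 instead of -1.
print(clamp_position(seq_lens))  # tensor([0, 6, 0])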