sglang 0.4.1.post7__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. sglang/bench_offline_throughput.py +17 -11
  2. sglang/bench_one_batch.py +14 -6
  3. sglang/bench_serving.py +47 -44
  4. sglang/lang/chat_template.py +31 -0
  5. sglang/srt/configs/load_config.py +1 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
  7. sglang/srt/entrypoints/engine.py +5 -2
  8. sglang/srt/entrypoints/http_server.py +24 -0
  9. sglang/srt/function_call_parser.py +494 -0
  10. sglang/srt/layers/activation.py +5 -5
  11. sglang/srt/layers/dp_attention.py +3 -1
  12. sglang/srt/layers/layernorm.py +5 -5
  13. sglang/srt/layers/linear.py +24 -9
  14. sglang/srt/layers/logits_processor.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/layer.py +20 -12
  16. sglang/srt/layers/moe/fused_moe_native.py +17 -3
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  18. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
  19. sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
  20. sglang/srt/layers/parameter.py +16 -7
  21. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  22. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  23. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  25. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  27. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  29. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/fp8.py +4 -1
  31. sglang/srt/layers/rotary_embedding.py +6 -1
  32. sglang/srt/layers/sampler.py +28 -8
  33. sglang/srt/layers/torchao_utils.py +12 -6
  34. sglang/srt/managers/detokenizer_manager.py +1 -0
  35. sglang/srt/managers/io_struct.py +36 -5
  36. sglang/srt/managers/schedule_batch.py +31 -25
  37. sglang/srt/managers/scheduler.py +61 -35
  38. sglang/srt/managers/tokenizer_manager.py +4 -0
  39. sglang/srt/model_executor/cuda_graph_runner.py +23 -25
  40. sglang/srt/model_executor/forward_batch_info.py +5 -7
  41. sglang/srt/model_executor/model_runner.py +7 -4
  42. sglang/srt/model_loader/loader.py +75 -0
  43. sglang/srt/model_loader/weight_utils.py +91 -5
  44. sglang/srt/models/commandr.py +14 -2
  45. sglang/srt/models/dbrx.py +9 -1
  46. sglang/srt/models/deepseek_v2.py +3 -3
  47. sglang/srt/models/gemma2.py +9 -1
  48. sglang/srt/models/grok.py +1 -0
  49. sglang/srt/models/minicpm3.py +3 -3
  50. sglang/srt/models/torch_native_llama.py +17 -4
  51. sglang/srt/openai_api/adapter.py +139 -37
  52. sglang/srt/openai_api/protocol.py +5 -4
  53. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  54. sglang/srt/sampling/sampling_batch_info.py +4 -14
  55. sglang/srt/server.py +2 -2
  56. sglang/srt/server_args.py +20 -1
  57. sglang/srt/speculative/eagle_utils.py +37 -15
  58. sglang/srt/speculative/eagle_worker.py +11 -13
  59. sglang/srt/utils.py +62 -65
  60. sglang/test/test_programs.py +1 -0
  61. sglang/test/test_utils.py +81 -22
  62. sglang/version.py +1 -1
  63. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/METADATA +7 -7
  64. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/RECORD +67 -56
  65. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  66. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  67. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py

@@ -1,17 +1,19 @@
  import logging
- from typing import Dict, List
+ from typing import List

  import torch
+ import torch.distributed as dist
  from torch import nn

+ from sglang.srt.distributed import get_tensor_model_parallel_group
+ from sglang.srt.layers.dp_attention import get_attention_tp_group
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
- from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
+ from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available

- if is_flashinfer_available():
-     from flashinfer.sampling import (
+ if is_cuda_available():
+     from sgl_kernel import (
          min_p_sampling_from_probs,
          top_k_renorm_prob,
          top_k_top_p_sampling_from_probs,
@@ -21,11 +23,17 @@ if is_flashinfer_available():

  logger = logging.getLogger(__name__)

+ SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+

  class Sampler(nn.Module):
      def __init__(self):
          super().__init__()
          self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
+         self.tp_sync_group = get_tensor_model_parallel_group().device_group
+
+         if global_server_args_dict["enable_dp_attention"]:
+             self.tp_sync_group = get_attention_tp_group().device_group

      def forward(
          self,
@@ -109,8 +117,6 @@ class Sampler(nn.Module):
                  f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
              )

-         batch_next_token_ids = batch_next_token_ids.to(torch.int32)
-
          # Attach logprobs to logits_output (in-place modification)
          if return_logprob:
              if any(x > 0 for x in top_logprobs_nums):
@@ -124,7 +130,21 @@ class Sampler(nn.Module):
                  batch_next_token_ids,
              ]

-         return batch_next_token_ids
+         if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
+             # For performance reasons, SGLang does not sync the final token IDs across TP ranks by default.
+             # This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators:
+             # the last all-reduce, the last lm_head matmul, and all sampling kernels.
+             # These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic.
+             # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
+             # When using xgrammar, this becomes more likely so we also do the sync when grammar is used.
+
+             torch.distributed.all_reduce(
+                 batch_next_token_ids,
+                 op=dist.ReduceOp.MIN,
+                 group=self.tp_sync_group,
+             )
+
+         return batch_next_token_ids.to(torch.int32)

      def _apply_custom_logit_processor(
          self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
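The TP-sync branch added above is the interesting part of this change. A minimal, self-contained sketch of the gating logic (the function name maybe_sync_token_ids and the plain os.environ check standing in for get_bool_env_var are illustrative, not package code):

import os

import torch
import torch.distributed as dist


def maybe_sync_token_ids(
    batch_next_token_ids: torch.Tensor,
    tp_sync_group,  # a torch.distributed process group (TP or attention-TP)
    has_grammar: bool,
) -> torch.Tensor:
    # Mirrors the new branch in Sampler.forward: the sync is skipped by default
    # to save an all-reduce, and enabled via the env var or whenever
    # grammar-constrained decoding (xgrammar) is active.
    sync_requested = os.environ.get("SYNC_TOKEN_IDS_ACROSS_TP", "").lower() in ("1", "true")
    if sync_requested or has_grammar:
        # ReduceOp.MIN forces every TP rank to agree on a single token id even
        # if a rare non-deterministic kernel produced different samples per rank.
        dist.all_reduce(batch_next_token_ids, op=dist.ReduceOp.MIN, group=tp_sync_group)
    return batch_next_token_ids.to(torch.int32)

Any deterministic reduction would keep the ranks in lockstep; MIN simply picks one of the candidate ids consistently on all ranks.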
sglang/srt/layers/torchao_utils.py

@@ -5,6 +5,7 @@ Common utilities for torchao.
  import logging
  import os
  import pwd
+ from typing import Callable, Optional

  import torch

@@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool:
      return True


+ def proj_filter(
+     module: torch.nn.Module,
+     fqn: str,
+ ):
+     """Filter function for quantizing projection layers."""
+     return "proj" in fqn
+
+
  def apply_torchao_config_to_model(
-     model: torch.nn.Module, torchao_config: str, filter_fn=None
+     model: torch.nn.Module,
+     torchao_config: str,
+     filter_fn: Optional[Callable] = proj_filter,
  ):
      """Quantize a modelwith torchao quantization specified by torchao_config

@@ -49,11 +60,6 @@ def apply_torchao_config_to_model(
      )
      from torchao.quantization.observer import PerRow, PerTensor

-     if filter_fn is None:
-
-         def filter_fn(module, fqn):
-             return "proj" in fqn
-
      if torchao_config == "" or torchao_config is None:
          return model
      elif "int8wo" in torchao_config:
sglang/srt/managers/detokenizer_manager.py

@@ -201,6 +201,7 @@ class DetokenizerManager:
                      prompt_tokens=recv_obj.prompt_tokens,
                      completion_tokens=recv_obj.completion_tokens,
                      cached_tokens=recv_obj.cached_tokens,
+                     spec_verify_ct=recv_obj.spec_verify_ct,
                      input_token_logprobs_val=recv_obj.input_token_logprobs_val,
                      input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
                      output_token_logprobs_val=recv_obj.output_token_logprobs_val,
sglang/srt/managers/io_struct.py

@@ -17,7 +17,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
  """

  import uuid
- from dataclasses import dataclass
+ from dataclasses import dataclass, field
  from enum import Enum
  from typing import Dict, List, Optional, Union

@@ -69,8 +69,10 @@ class GenerateReqInput:

      # Session info for continual prompting
      session_params: Optional[Union[List[Dict], Dict]] = None
-     # Custom logit processor (serialized function)
-     custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None
+     # Custom logit processor for advanced sampling control. Must be a serialized instance
+     # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+     # Use the processor's `to_str()` method to generate the serialized string.
+     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

      def normalize_batch_and_arguments(self):
          if (
@@ -248,8 +250,9 @@ class TokenizedGenerateReqInput:
      # Session info for continual prompting
      session_params: Optional[SessionParams] = None

-     # Custom logit processor (serialized function)
-     # TODO (hpguo): Add an example and update doc string here
+     # Custom logit processor for advanced sampling control. Must be a serialized instance
+     # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+     # Use the processor's `to_str()` method to generate the serialized string.
      custom_logit_processor: Optional[str] = None

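A hedged sketch of producing the serialized string these fields expect. Only the to_str() round-trip is stated by the docstrings above; the __call__ signature used below is an assumption made for illustration:

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class DisallowTokenZero(CustomLogitProcessor):
    # The callback signature is assumed; the class only needs to be
    # serializable via to_str() per the comments in the diff.
    def __call__(self, logits, custom_params=None):
        logits[..., 0] = float("-inf")  # never sample token id 0
        return logits


serialized = DisallowTokenZero().to_str()
# GenerateReqInput(..., custom_logit_processor=serialized)          # single request
# GenerateReqInput(..., custom_logit_processor=[serialized, None])  # per-request in a batch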
@@ -351,10 +354,13 @@ class BatchTokenIDOut:
      skip_special_tokens: List[bool]
      spaces_between_special_tokens: List[bool]
      no_stop_trim: List[bool]
+
      # Token counts
      prompt_tokens: List[int]
      completion_tokens: List[int]
      cached_tokens: List[int]
+     spec_verify_ct: List[int]
+
      # Logprobs
      input_token_logprobs_val: List[float]
      input_token_logprobs_idx: List[int]
@@ -379,6 +385,7 @@ class BatchStrOut:
      prompt_tokens: List[int]
      completion_tokens: List[int]
      cached_tokens: List[int]
+     spec_verify_ct: List[int]

      # Logprobs
      input_token_logprobs_val: List[float]
@@ -533,3 +540,27 @@ class CloseSessionReqInput:
  class OpenSessionReqOutput:
      session_id: Optional[str]
      success: bool
+
+
+ @dataclass
+ class Function:
+     description: Optional[str] = None
+     name: Optional[str] = None
+     parameters: Optional[object] = None
+
+
+ @dataclass
+ class Tool:
+     function: Function
+     type: Optional[str] = "function"
+
+
+ @dataclass
+ class FunctionCallReqInput:
+     text: str  # The text to parse.
+     tools: List[Tool] = field(
+         default_factory=list
+     )  # A list of available function tools (name, parameters, etc.).
+     tool_call_parser: Optional[str] = (
+         None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
+     )
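A short construction example for the new tool-calling dataclasses (the weather tool and its JSON-schema parameters are invented for illustration):

from sglang.srt.managers.io_struct import Function, FunctionCallReqInput, Tool

get_weather = Tool(
    function=Function(
        name="get_weather",
        description="Look up the current weather for a city.",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}},
    )
)

req = FunctionCallReqInput(
    text='{"name": "get_weather", "parameters": {"city": "Paris"}}',
    tools=[get_weather],
    tool_call_parser="llama3",  # or 'qwen25' / 'mistral'; None tries all parsers
)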
sglang/srt/managers/schedule_batch.py

@@ -247,12 +247,12 @@ class Req:
          # Each decode stage's output ids
          self.output_ids = []
          # fill_ids = origin_input_ids + output_ids. Updated if chunked.
+         self.fill_ids = None
          self.session_id = session_id
          self.input_embeds = input_embeds

          # Sampling info
          self.sampling_params = sampling_params
-         self.lora_path = lora_path
          self.custom_logit_processor = custom_logit_processor

          # Memory pool info
@@ -300,7 +300,7 @@ class Req:
          self.logprob_start_len = 0
          self.top_logprobs_num = top_logprobs_num

-         # Logprobs (return value)
+         # Logprobs (return values)
          self.input_token_logprobs_val: Optional[List[float]] = None
          self.input_token_logprobs_idx: Optional[List[int]] = None
          self.input_top_logprobs_val: Optional[List[float]] = None
@@ -329,8 +329,14 @@ class Req:
          # Constrained decoding
          self.grammar: Optional[BaseGrammarObject] = None

-         # The number of cached tokens, that were already cached in the KV cache
+         # The number of cached tokens that were already cached in the KV cache
          self.cached_tokens = 0
+         self.already_computed = 0
+
+         # The number of verification forward passes in the speculative decoding.
+         # This is used to compute the average acceptance length per request.
+         self.spec_verify_ct = 0
+         self.lora_path = lora_path

      def extend_image_inputs(self, image_inputs):
          if self.image_inputs is None:
@@ -550,13 +556,13 @@ class ScheduleBatch:
      next_batch_sampling_info: SamplingBatchInfo = None

      # Batched arguments to model runner
-     input_ids: torch.Tensor = None
-     input_embeds: torch.Tensor = None
-     req_pool_indices: torch.Tensor = None
-     seq_lens: torch.Tensor = None
+     input_ids: torch.Tensor = None  # shape: [b], int32
+     input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+     req_pool_indices: torch.Tensor = None  # shape: [b], int32
+     seq_lens: torch.Tensor = None  # shape: [b], int64
      # The output locations of the KV cache
-     out_cache_loc: torch.Tensor = None
-     output_ids: torch.Tensor = None
+     out_cache_loc: torch.Tensor = None  # shape: [b], int32
+     output_ids: torch.Tensor = None  # shape: [b], int32

      # The sum of all sequence lengths
      seq_lens_sum: int = None
@@ -750,13 +756,6 @@ class ScheduleBatch:

          pt = 0
          for i, req in enumerate(reqs):
-             already_computed = (
-                 req.extend_logprob_start_len + 1 + req.cached_tokens
-                 if req.extend_logprob_start_len > 0
-                 else 0
-             )
-             req.cached_tokens += len(req.prefix_indices) - already_computed
-
              req.req_pool_idx = req_pool_indices[i]
              pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
              seq_lens.append(seq_len)
@@ -772,15 +771,20 @@ class ScheduleBatch:
                  # If req.input_embeds is already a list, append its content directly
                  input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

-             # Compute the relative logprob_start_len in an extend batch
-             if req.logprob_start_len >= pre_len:
-                 extend_logprob_start_len = min(
-                     req.logprob_start_len - pre_len, req.extend_input_len - 1
-                 )
-             else:
-                 extend_logprob_start_len = req.extend_input_len - 1
+             if req.return_logprob:
+                 # Compute the relative logprob_start_len in an extend batch
+                 if req.logprob_start_len >= pre_len:
+                     extend_logprob_start_len = min(
+                         req.logprob_start_len - pre_len, req.extend_input_len - 1
+                     )
+                 else:
+                     raise RuntimeError(
+                         f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
+                     )
+                 req.extend_logprob_start_len = extend_logprob_start_len

-             req.extend_logprob_start_len = extend_logprob_start_len
+             req.cached_tokens += pre_len - req.already_computed
+             req.already_computed = seq_len
              req.is_retracted = False
              pre_lens.append(pre_len)

@@ -1026,7 +1030,7 @@ class ScheduleBatch:
          self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
          self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
          self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
-         self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
+         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
          self.seq_lens_sum = 0
          self.extend_num_tokens = 0
          self.sampling_info = SamplingBatchInfo.from_schedule_batch(
@@ -1112,6 +1116,8 @@ class ScheduleBatch:
          self.has_grammar = any(req.grammar for req in self.reqs)

          self.sampling_info.filter_batch(keep_indices, new_indices)
+         if self.spec_info:
+             self.spec_info.filter_batch(new_indices)

      def merge_batch(self, other: "ScheduleBatch"):
          # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
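The reworked cached-token accounting in prepare_for_extend is easiest to see with concrete numbers. A minimal sketch (FakeReq and account are illustrative stand-ins for Req and the loop body; the token counts are made up):

class FakeReq:
    def __init__(self):
        self.cached_tokens = 0      # tokens served from the KV/radix cache
        self.already_computed = 0   # tokens this request has already accounted for


def account(req: FakeReq, pre_len: int, seq_len: int) -> None:
    # Same two lines as the diff: only the prefix that was not yet accounted
    # for counts as cached, then the watermark advances to this step's length.
    req.cached_tokens += pre_len - req.already_computed
    req.already_computed = seq_len


req = FakeReq()
account(req, pre_len=100, seq_len=512)    # chunk 1: 100-token prefix-cache hit
account(req, pre_len=512, seq_len=1024)   # chunk 2: chunk 1 is in KV but was computed here
print(req.cached_tokens)  # 100 (this request's own computed chunks are not counted as cached)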
sglang/srt/managers/scheduler.py

@@ -281,6 +281,7 @@ class Scheduler:
          # Print debug info
          logger.info(
              f"max_total_num_tokens={self.max_total_num_tokens}, "
+             f"chunked_prefill_size={server_args.chunked_prefill_size}, "
              f"max_prefill_tokens={self.max_prefill_tokens}, "
              f"max_running_requests={self.max_running_requests}, "
              f"context_len={self.model_config.context_len}"
@@ -408,6 +409,11 @@ class Scheduler:
              },
          )

+         # The largest prefill length of a single request
+         self._largest_prefill_len: int = 0
+         # The largest context length (prefill + generation) of a single request
+         self._largest_prefill_decode_len: int = 0
+
          # Init request dispatcher
          self._request_dispatcher = TypeBasedDispatcher(
              [
@@ -480,7 +486,7 @@ class Scheduler:
      @torch.no_grad()
      def event_loop_overlap(self):
          """A scheduler loop that overlaps the CPU processing and GPU computation."""
-         result_queue = deque()
+         self.result_queue = deque()

          while True:
              recv_reqs = self.recv_requests()
@@ -491,7 +497,7 @@ class Scheduler:

              if batch:
                  result = self.run_batch(batch)
-                 result_queue.append((batch.copy(), result))
+                 self.result_queue.append((batch.copy(), result))

                  if self.last_batch is None:
                      # Create a dummy first batch to start the pipeline for overlap schedule.
@@ -505,7 +511,7 @@ class Scheduler:

              if self.last_batch:
                  # Process the results of the last batch
-                 tmp_batch, tmp_result = result_queue.popleft()
+                 tmp_batch, tmp_result = self.result_queue.popleft()
                  tmp_batch.next_batch_sampling_info = (
                      self.tp_worker.cur_sampling_info if batch else None
                  )
@@ -636,7 +642,7 @@ class Scheduler:
              self.waiting_queue.append(req)
              return

-         # Handle image inputs
+         # Handle multimodal inputs
          if recv_req.image_inputs is not None:
              image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
              # Expand a single image token into multiple dummy tokens for receiving image embeddings
@@ -660,24 +666,23 @@ class Scheduler:
              self.waiting_queue.append(req)
              return

-         # Copy more attributes
-         req.logprob_start_len = recv_req.logprob_start_len
-
-         if req.logprob_start_len == -1:
-             # By default, only return the logprobs for output tokens
-             req.logprob_start_len = len(req.origin_input_ids) - 1
-
          # Validate prompts length
          error_msg = validate_input_length(
              req,
              self.max_req_input_len,
              self.server_args.allow_auto_truncate,
          )
-
          if error_msg:
              self.waiting_queue.append(req)
              return

+         # Copy more attributes
+         if recv_req.logprob_start_len == -1:
+             # By default, only return the logprobs for output tokens
+             req.logprob_start_len = len(req.origin_input_ids) - 1
+         else:
+             req.logprob_start_len = recv_req.logprob_start_len
+
          req.sampling_params.max_new_tokens = min(
              (
                  req.sampling_params.max_new_tokens
@@ -725,15 +730,26 @@ class Scheduler:
          req.tokenizer = self.tokenizer

          # Validate prompts length
-         validate_input_length(
+         error_msg = validate_input_length(
              req,
              self.max_req_input_len,
              self.server_args.allow_auto_truncate,
          )
+         if error_msg:
+             self.waiting_queue.append(req)
+             return

+         # Copy more attributes
+         req.logprob_start_len = len(req.origin_input_ids) - 1
          self.waiting_queue.append(req)

-     def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
+     def log_prefill_stats(
+         self,
+         adder: PrefillAdder,
+         can_run_list: List[Req],
+         running_bs: ScheduleBatch,
+         has_being_chunked: bool,
+     ):
          self.tree_cache_metrics["total"] += (
              adder.log_input_tokens + adder.log_hit_tokens
          ) / 10**9
@@ -1023,7 +1039,7 @@ class Scheduler:
              )

          # Check for jump-forward
-         if not self.disable_jump_forward:
+         if not self.disable_jump_forward and batch.has_grammar:
              jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
              self.waiting_queue.extend(jump_forward_reqs)
              if batch.is_empty():
@@ -1044,26 +1060,23 @@ class Scheduler:
          self.forward_ct += 1

          if self.is_generation:
-             if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
-                 if self.spec_algorithm.is_none():
-                     model_worker_batch = batch.get_model_worker_batch()
-                     logits_output, next_token_ids = (
-                         self.tp_worker.forward_batch_generation(model_worker_batch)
-                     )
-                 else:
-                     (
-                         logits_output,
-                         next_token_ids,
-                         model_worker_batch,
-                         num_accepted_tokens,
-                     ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                     self.spec_num_total_accepted_tokens += (
-                         num_accepted_tokens + batch.batch_size()
-                     )
-                     self.spec_num_total_forward_ct += batch.batch_size()
-                     self.num_generated_tokens += num_accepted_tokens
+             if self.spec_algorithm.is_none():
+                 model_worker_batch = batch.get_model_worker_batch()
+                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                     model_worker_batch
+                 )
              else:
-                 assert False, "batch.extend_num_tokens == 0, this is unexpected!"
+                 (
+                     logits_output,
+                     next_token_ids,
+                     model_worker_batch,
+                     num_accepted_tokens,
+                 ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                 self.spec_num_total_accepted_tokens += (
+                     num_accepted_tokens + batch.batch_size()
+                 )
+                 self.spec_num_total_forward_ct += batch.batch_size()
+                 self.num_generated_tokens += num_accepted_tokens
              batch.output_ids = next_token_ids

              ret = GenerationBatchResult(
@@ -1072,7 +1085,6 @@ class Scheduler:
                  bid=model_worker_batch.bid,
              )
          else:  # embedding or reward model
-             assert batch.extend_num_tokens != 0
              model_worker_batch = batch.get_model_worker_batch()
              embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
              ret = EmbeddingBatchResult(
@@ -1371,6 +1383,7 @@ class Scheduler:
              prompt_tokens = []
              completion_tokens = []
              cached_tokens = []
+             spec_verify_ct = []

              if return_logprob:
                  input_token_logprobs_val = []
@@ -1424,6 +1437,9 @@ class Scheduler:
                  completion_tokens.append(len(req.output_ids))
                  cached_tokens.append(req.cached_tokens)

+                 if not self.spec_algorithm.is_none():
+                     spec_verify_ct.append(req.spec_verify_ct)
+
                  if return_logprob:
                      input_token_logprobs_val.append(req.input_token_logprobs_val)
                      input_token_logprobs_idx.append(req.input_token_logprobs_idx)
@@ -1451,6 +1467,7 @@ class Scheduler:
                  prompt_tokens,
                  completion_tokens,
                  cached_tokens,
+                 spec_verify_ct,
                  input_token_logprobs_val,
                  input_token_logprobs_idx,
                  output_token_logprobs_val,
@@ -1564,6 +1581,15 @@ class Scheduler:
              self.grammar_backend.reset()
              self.req_to_token_pool.clear()
              self.token_to_kv_pool.clear()
+
+             if not self.spec_algorithm.is_none():
+                 self.draft_worker.model_runner.req_to_token_pool.clear()
+                 self.draft_worker.model_runner.token_to_kv_pool.clear()
+
+             self.num_generated_tokens = 0
+             self.forward_ct_decode = 0
+             self.spec_num_total_accepted_tokens = 0
+             self.spec_num_total_forward_ct = 0
              torch.cuda.empty_cache()
              logger.info("Cache flushed successfully!")
              if_success = True
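A hedged sketch of how the speculative-decoding counters updated in run_batch and reset in flush_cache combine into an average acceptance length; the SpecStats wrapper and the ratio are an interpretation, not code from the package:

class SpecStats:
    def __init__(self):
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0

    def on_speculative_step(self, num_accepted_tokens: int, batch_size: int) -> None:
        # Per verify step, every request contributes one forward pass and
        # (1 + its accepted draft tokens) emitted tokens, as in run_batch.
        self.spec_num_total_accepted_tokens += num_accepted_tokens + batch_size
        self.spec_num_total_forward_ct += batch_size

    def avg_accept_length(self) -> float:
        if self.spec_num_total_forward_ct == 0:
            return 0.0
        return self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct

    def reset(self) -> None:
        # flush_cache zeroes the same counters.
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0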
sglang/srt/managers/tokenizer_manager.py

@@ -785,6 +785,9 @@ class TokenizerManager:
                  i,
              )

+             if self.server_args.speculative_algorithm:
+                 meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
+
              if not isinstance(recv_obj, BatchEmbeddingOut):
                  meta_info.update(
                      {
@@ -809,6 +812,7 @@ class TokenizerManager:
                      "embedding": recv_obj.embeddings[i],
                      "meta_info": meta_info,
                  }
+
              state.out_list.append(out_dict)
              state.finished = recv_obj.finished_reasons[i] is not None
              state.event.set()
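With spec_verify_ct now surfaced per request, a client can estimate that request's acceptance length roughly as follows (a sketch assuming completion_tokens is present in the same meta_info, as the surrounding fields suggest):

def per_request_accept_length(meta_info: dict) -> float:
    # spec_verify_ct counts verification forward passes for this request;
    # dividing the generated tokens by it approximates tokens per verify pass.
    verify_ct = meta_info.get("spec_verify_ct", 0)
    if verify_ct == 0:
        return 0.0
    return meta_info["completion_tokens"] / verify_ct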