sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +9 -7
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +1 -0
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +48 -43
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +7 -2
  20. sglang/srt/disaggregation/fake/conn.py +1 -1
  21. sglang/srt/disaggregation/mooncake/conn.py +227 -120
  22. sglang/srt/disaggregation/nixl/conn.py +1 -0
  23. sglang/srt/disaggregation/prefill.py +7 -4
  24. sglang/srt/disaggregation/utils.py +7 -1
  25. sglang/srt/entrypoints/engine.py +17 -2
  26. sglang/srt/entrypoints/http_server.py +17 -2
  27. sglang/srt/function_call_parser.py +2 -2
  28. sglang/srt/layers/attention/flashattention_backend.py +1 -1
  29. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  30. sglang/srt/layers/attention/utils.py +4 -2
  31. sglang/srt/layers/dp_attention.py +71 -21
  32. sglang/srt/layers/layernorm.py +1 -1
  33. sglang/srt/layers/logits_processor.py +46 -11
  34. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  35. sglang/srt/layers/moe/ep_moe/layer.py +1 -1
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  37. sglang/srt/layers/moe/topk.py +1 -1
  38. sglang/srt/layers/quantization/__init__.py +1 -1
  39. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  40. sglang/srt/layers/quantization/deep_gemm.py +72 -71
  41. sglang/srt/layers/quantization/fp8.py +2 -2
  42. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  43. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  44. sglang/srt/layers/sampler.py +0 -4
  45. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  46. sglang/srt/lora/lora_manager.py +1 -1
  47. sglang/srt/lora/mem_pool.py +4 -4
  48. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  49. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  50. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  51. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  52. sglang/srt/lora/utils.py +1 -1
  53. sglang/srt/managers/data_parallel_controller.py +3 -3
  54. sglang/srt/managers/detokenizer_manager.py +21 -8
  55. sglang/srt/managers/io_struct.py +3 -1
  56. sglang/srt/managers/mm_utils.py +1 -1
  57. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  58. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  59. sglang/srt/managers/schedule_batch.py +76 -24
  60. sglang/srt/managers/schedule_policy.py +0 -3
  61. sglang/srt/managers/scheduler.py +113 -88
  62. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  63. sglang/srt/managers/tokenizer_manager.py +133 -34
  64. sglang/srt/managers/tp_worker.py +12 -9
  65. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  66. sglang/srt/mem_cache/memory_pool.py +2 -0
  67. sglang/srt/metrics/collector.py +312 -37
  68. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  69. sglang/srt/model_executor/forward_batch_info.py +1 -1
  70. sglang/srt/model_executor/model_runner.py +19 -14
  71. sglang/srt/models/deepseek_janus_pro.py +2 -2
  72. sglang/srt/models/deepseek_v2.py +23 -20
  73. sglang/srt/models/llama.py +2 -0
  74. sglang/srt/models/llama4.py +5 -6
  75. sglang/srt/models/llava.py +248 -5
  76. sglang/srt/models/mixtral.py +98 -34
  77. sglang/srt/models/pixtral.py +467 -0
  78. sglang/srt/models/roberta.py +1 -1
  79. sglang/srt/models/torch_native_llama.py +1 -1
  80. sglang/srt/openai_api/adapter.py +30 -4
  81. sglang/srt/openai_api/protocol.py +0 -8
  82. sglang/srt/reasoning_parser.py +3 -3
  83. sglang/srt/sampling/custom_logit_processor.py +18 -3
  84. sglang/srt/sampling/sampling_batch_info.py +4 -56
  85. sglang/srt/sampling/sampling_params.py +2 -2
  86. sglang/srt/server_args.py +34 -4
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +7 -7
  89. sglang/srt/speculative/eagle_worker.py +22 -19
  90. sglang/srt/utils.py +6 -5
  91. sglang/test/few_shot_gsm8k.py +2 -2
  92. sglang/test/few_shot_gsm8k_engine.py +2 -2
  93. sglang/test/run_eval.py +2 -2
  94. sglang/test/runners.py +8 -1
  95. sglang/test/send_one.py +13 -3
  96. sglang/test/simple_eval_common.py +1 -1
  97. sglang/test/simple_eval_humaneval.py +1 -1
  98. sglang/test/test_programs.py +5 -5
  99. sglang/test/test_utils.py +89 -14
  100. sglang/utils.py +1 -1
  101. sglang/version.py +1 -1
  102. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
  103. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
  104. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  105. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
  106. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/llava.py

@@ -1,14 +1,20 @@
 import asyncio
+import importlib
 from typing import List, Optional, Union
 
 import numpy as np
+from transformers.models.auto.processing_auto import (
+    PROCESSOR_MAPPING_NAMES as HF_MAPPING_NAMES,
+)
 
+import sglang.srt.managers.multimodal_processor as sgl_mm_processor_utils
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
 )
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.models.llava import (
+    LlavaForConditionalGeneration,
     LlavaLlamaForCausalLM,
     LlavaMistralForCausalLM,
     LlavaQwenForCausalLM,
@@ -133,6 +139,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
                         img_data, aspect_ratio, grid_pinpoints
                     )
                 )
+
         res = await asyncio.gather(*res)
         for pixel_v, image_h, image_s in res:
             pixel_values.append(pixel_v)
@@ -165,3 +172,42 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
                 )
             ],
         }
+
+
+class LlavaMultimodalProcessor(BaseMultimodalProcessor):
+    """
+    This is a wrapper class used to identify the multimodal processor for Llava architecture models.
+    """
+
+    models = [LlavaForConditionalGeneration]
+
+    def _get_sgl_processor_cls(self, model_type: str):
+        if hf_name := HF_MAPPING_NAMES.get(model_type):
+            sgl_mm_processor_set = sgl_mm_processor_utils.PROCESSOR_MAPPING.values()
+            sgl_processor_cls = list(
+                filter(lambda p: p.__name__ == hf_name, sgl_mm_processor_set)
+            )
+            if sgl_processor_cls:
+                return sgl_processor_cls[0]
+        raise ValueError(
+            f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
+        )
+
+    def __init__(self, hf_config, server_args, _processor):
+        assert hasattr(hf_config, "vision_config")
+        assert hasattr(hf_config, "text_config")
+        self.vision_config = hf_config.vision_config
+        self.text_config = hf_config.text_config
+        self.hf_config = hf_config
+
+        if vision_type := getattr(self.vision_config, "model_type"):
+            self.inner = self._get_sgl_processor_cls(vision_type)(
+                hf_config, server_args, _processor
+            )
+        else:
+            raise ValueError(
+                f"Required `vision_config.model_type` is not found in hf_config: `{hf_config}`"
+            )
+
+    async def process_mm_data_async(self, *args, **kwargs):
+        return await self.inner.process_mm_data_async(*args, **kwargs)
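
Note: the new `LlavaMultimodalProcessor` is a thin dispatcher. It looks up the HF processor class name for the vision tower's `model_type` in `PROCESSOR_MAPPING_NAMES`, finds the sglang processor class of the same name among the registered `PROCESSOR_MAPPING` values, and delegates `process_mm_data_async` to that inner instance. A minimal sketch of the same name-matching pattern, using hypothetical registries (`HF_NAMES`, `SGL_PROCESSORS`) rather than the real mappings:

    class CLIPImageProcessor:  # hypothetical stand-in for a registered sglang processor
        async def process_mm_data_async(self, *args, **kwargs):
            return {"processed_by": type(self).__name__}

    HF_NAMES = {"clip_vision_model": "CLIPImageProcessor"}  # model_type -> HF class name
    SGL_PROCESSORS = {"clip": CLIPImageProcessor}           # sglang-side registry

    def resolve(model_type: str):
        # Match the HF-declared class name against the registered processor classes.
        hf_name = HF_NAMES.get(model_type)
        matches = [cls for cls in SGL_PROCESSORS.values() if cls.__name__ == hf_name]
        if matches:
            return matches[0]
        raise ValueError(f"No sglang processor registered for `{model_type}`")

    assert resolve("clip_vision_model") is CLIPImageProcessor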
sglang/srt/managers/multimodal_processors/pixtral.py (new file)

@@ -0,0 +1,127 @@
+import asyncio
+import math
+from typing import List, Optional, Union
+
+import numpy as np
+from transformers import PretrainedConfig
+from transformers.models.pixtral.image_processing_pixtral import (
+    _num_image_tokens as _get_pixtral_hf_num_image_tokens,
+)
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
+from sglang.srt.models.pixtral import PixtralVisionModel
+
+
+class PixtralProcessor(BaseMultimodalProcessor):
+    models = [PixtralVisionModel]
+
+    PAD_TOKEN = "<pad>"
+    IMG_BREAK_TOKEN_ID = 12
+    IMG_END_TOKEN_ID = 13
+
+    def get_patch_grid_size(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> tuple[int, int]:
+        max_width = max_height = self.image_size
+        patch_width = patch_height = self.patch_size
+
+        ratio = max(image_width / max_width, image_height / max_height)
+
+        if ratio > 1:
+            image_width = int(math.floor(image_width / ratio))
+            image_height = int(math.floor(image_height / ratio))
+
+        nrows, ncols = _get_pixtral_hf_num_image_tokens(
+            (image_height, image_width),
+            (patch_height, patch_width),
+        )
+
+        return ncols, nrows
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.image_token_id = getattr(
+            hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
+        )
+        # Instantiate the patcher logic helper using the class defined above
+
+        self.vision_config = hf_config.vision_config
+        self.image_size = self.vision_config.image_size
+        self.patch_size = self.vision_config.patch_size
+        self.multimodal_tokens = MultimodalSpecialTokens(
+            image_token=_processor.image_token
+        )
+        _processor.tokenizer.add_special_tokens(
+            {
+                "pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN),
+            }
+        )
+
+    async def _resize(self, image):
+        num_w_tokens, num_h_tokens = self.get_patch_grid_size(
+            image_width=image.size[0],
+            image_height=image.size[1],
+        )
+        new_size = (num_w_tokens * self.patch_size, num_h_tokens * self.patch_size)
+        return image.resize(new_size)
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        mm_data = self.load_mm_data(
+            prompt=input_text,
+            multimodal_tokens=self.multimodal_tokens,
+            max_req_input_len=kwargs.get("max_req_input_len", 4096),
+            image_data=image_data,
+            return_text=True,
+        )
+
+        if mm_data.images:
+            resize_tasks = [self._resize(image) for image in mm_data.images]
+            mm_data.images = await asyncio.gather(*resize_tasks)
+
+        processor_output = self.process_mm_data(
+            input_text=mm_data.input_text,
+            images=mm_data.images,
+        )
+
+        if "pixel_values" in processor_output:
+            mm_items = [
+                MultimodalDataItem(
+                    pixel_values=processor_output["pixel_values"],
+                    image_sizes=processor_output["image_sizes"],
+                    modality=Modality.IMAGE,
+                )
+            ]
+
+            input_ids = processor_output["input_ids"].view(-1).tolist()
+            processor_output.update(
+                input_ids=input_ids,
+                mm_items=mm_items,
+                # there's no im_start_id for pixtral, only im_token and im_end_token
+                im_end_id=self.IMG_END_TOKEN_ID,
+                im_token_id=self.image_token_id,
+            )
+        return processor_output
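
Note: `get_patch_grid_size` only ever shrinks an image. If either side exceeds `image_size`, both sides are scaled down by the same ratio, and the token grid allots one token per `patch_size` patch with partial patches rounded up (the rounding is done by HF's `_num_image_tokens`). A worked sketch of the arithmetic, assuming the typical Pixtral values `image_size=1024` and `patch_size=16`:

    import math

    def patch_grid(w: int, h: int, image_size: int = 1024, patch: int = 16):
        # Shrink (never enlarge) so the image fits inside image_size x image_size.
        ratio = max(w / image_size, h / image_size)
        if ratio > 1:
            w = math.floor(w / ratio)
            h = math.floor(h / ratio)
        # One token per patch, rounding partial patches up.
        return math.ceil(w / patch), math.ceil(h / patch)

    ncols, nrows = patch_grid(2048, 1536)  # -> (64, 48)
    new_size = (ncols * 16, nrows * 16)    # _resize scales the image to (1024, 768)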
sglang/srt/managers/schedule_batch.py

@@ -1,8 +1,5 @@
 from __future__ import annotations
 
-import hashlib
-from enum import Enum, auto
-
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,12 +27,16 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
   It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
+
+TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
 """
 
 import copy
 import dataclasses
+import hashlib
 import logging
 import threading
+from enum import Enum, auto
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
 
 import numpy as np
@@ -51,6 +52,7 @@ from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -73,6 +75,7 @@ global_server_args_dict = {
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
@@ -134,9 +137,9 @@ class FINISH_LENGTH(BaseFinishReason):
 
 
 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message="Unknown error", status_code=None, err_type=None):
+    def __init__(self, message=None, status_code=None, err_type=None):
         super().__init__(is_error=True)
-        self.message = message
+        self.message = message or "Aborted"
         self.status_code = status_code
         self.err_type = err_type
 
@@ -434,6 +437,7 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
+        self.lora_path = lora_path
 
         # Memory pool info
         self.req_pool_idx: Optional[int] = None
@@ -441,11 +445,13 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # Whether this request has finished output
+        self.finished_output = None
         # If we want to abort the request in the middle of the event loop, set this to true
         # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
         self.to_abort = False
         # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
-        self.to_abort_message: str = "Unknown error"
+        self.to_abort_message: str = None
         self.stream = stream
         self.eos_token_ids = eos_token_ids
@@ -483,6 +489,13 @@ class Req:
         # For retraction
         self.is_retracted = False
 
+        # Incremental streaming
+        self.send_token_offset: int = 0
+        self.send_decode_id_offset: int = 0
+        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
+        # because the decode server does not have the first output token logprobs
+        self.send_output_token_logprobs_offset: int = 0
+
         # Logprobs (arguments)
         self.return_logprob = return_logprob
         # Start index to compute logprob from.
@@ -492,11 +505,9 @@ class Req:
         self.temp_scaled_logprobs = False
         self.top_p_normalized_logprobs = False
 
-        # Latency Breakdown
-        self.queue_time_start = None
-        self.queue_time_end = None
-
         # Logprobs (return values)
+        # True means the input logprob has already been sent to the detokenizer.
+        self.input_logprob_sent: bool = False
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
@@ -511,8 +522,10 @@ class Req:
         self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None
 
         if return_logprob:
+            # shape: (bs, 1)
             self.output_token_logprobs_val = []
             self.output_token_logprobs_idx = []
+            # shape: (bs, k)
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
             self.output_token_ids_logprobs_val = []
@@ -530,6 +543,7 @@ class Req:
 
         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None
+        self.grammar_wait_ct = 0
 
         # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
@@ -538,7 +552,12 @@ class Req:
         # The number of verification forward passes in the speculative decoding.
         # This is used to compute the average acceptance length per request.
         self.spec_verify_ct = 0
-        self.lora_path = lora_path
+
+        # For metrics
+        self.time_stats: TimeStats = TimeStats()
+        self.has_log_time_stats: bool = False
+        self.queue_time_start = None
+        self.queue_time_end = None
 
         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
@@ -546,8 +565,6 @@ class Req:
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None
 
-        # used for warmup because we don't have a pair yet when init
-        self.skip_kv_transfer: bool = False
         # the start index of the sent kv cache
         # We want to send it chunk by chunk for chunked prefill.
         # After every chunk forward, we do the following:
@@ -555,14 +572,14 @@ class Req:
         # start_send_idx = len(req.fill_ids)
         self.start_send_idx: int = 0
 
-        self.metadata_buffer_index: int = -1
-        # The first output_id transferred from prefill instance.
-        self.transferred_output_id: Optional[int] = None
-
         # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
         # This is because kv is not ready in `process_prefill_chunk`.
         # We use `tmp_end_idx` to store the end index of the kv cache to send.
         self.tmp_end_idx: int = -1
+        self.metadata_buffer_index: int = -1
+
+        # The first output_id transferred from prefill instance.
+        self.transferred_output_id: Optional[int] = None
 
     @property
     def seqlen(self):
@@ -653,6 +670,11 @@ class Req:
             )
             return
 
+        if self.grammar is not None:
+            if self.grammar.is_terminated():
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
+                return
+
         last_token_id = self.output_ids[-1]
 
         if not self.sampling_params.ignore_eos:
@@ -697,13 +719,41 @@ class Req:
         self.req_pool_idx = None
         self.already_computed = 0
 
+    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)
+
+    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
+        del self.kv_cache_cpu
+
+    def log_time_stats(self):
+        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
+        if self.has_log_time_stats is True:
+            return
+
+        if self.bootstrap_room is not None:
+            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+        else:
+            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+        logger.info(f"{prefix}: {self.time_stats}")
+        self.has_log_time_stats = True
+
     def __repr__(self):
         return (
             f"Req(rid={self.rid}, "
-            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
+            f"{self.grammar=}, "
+            f"{self.sampling_params=})"
         )
 
 
+# Batch id
 bid = 0
 
 
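
Note: the new `offload_kv_cache`/`load_kv_cache` helpers gather the request's token slots (the first `seqlen - 1` positions of its `req_to_token` row) and hand them to the allocator's `get_cpu_copy`/`load_cpu_copy`. A hedged usage sketch, assuming `req` owns a pool slot and the allocator exposes those two methods as shown in the hunk above:

    # Stage a paused request's KV entries on CPU to free GPU KV space...
    req.offload_kv_cache(req_to_token_pool, token_to_kv_pool_allocator)

    # ...then copy them back before the request resumes decoding. load_kv_cache
    # also drops the CPU staging buffer (del self.kv_cache_cpu) after the copy.
    req.load_kv_cache(req_to_token_pool, token_to_kv_pool_allocator)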
@@ -862,7 +912,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         error_msg = (
             f"{phase_str} out of memory. Try to lower your batch size.\n"
             f"Try to allocate {num_tokens} tokens.\n"
-            f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+            f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
         )
         logger.error(error_msg)
         if self.tree_cache is not None:
@@ -903,7 +953,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         error_msg = (
             f"Prefill out of memory. Try to lower your batch size.\n"
             f"Try to allocate {extend_num_tokens} tokens.\n"
-            f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+            f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
             f"{self.token_to_kv_pool_allocator.available_size()=}\n"
             f"{self.tree_cache.evictable_size()=}\n"
         )
@@ -938,7 +988,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         error_msg = (
             f"Decode out of memory. Try to lower your batch size.\n"
             f"Try to allocate {len(seq_lens)} tokens.\n"
-            f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+            f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
             f"{self.token_to_kv_pool_allocator.available_size()=}\n"
             f"{self.tree_cache.evictable_size()=}\n"
         )
@@ -1447,7 +1497,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             i
             for i in range(len(self.reqs))
             if not self.reqs[i].finished()
-            and not self.reqs[i] in chunked_req_to_exclude
+            and self.reqs[i] not in chunked_req_to_exclude
         ]
 
         if keep_indices is None or len(keep_indices) == 0:
@@ -1468,7 +1518,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
 
         self.reqs = [self.reqs[i] for i in keep_indices]
-        self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
+        if self.multimodal_inputs is not None:
+            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
         self.out_cache_loc = None
@@ -1517,7 +1568,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
         self.reqs.extend(other.reqs)
-        self.multimodal_inputs.extend(other.multimodal_inputs)
+        if self.multimodal_inputs is not None:
+            self.multimodal_inputs.extend(other.multimodal_inputs)
 
         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
sglang/srt/managers/schedule_policy.py

@@ -468,9 +468,6 @@ class PrefillAdder:
             return AddReqResult.OTHER
 
         with self._lock_node(req.last_node):
-            if total_tokens > self.rem_total_tokens:
-                return AddReqResult.NO_TOKEN
-
             if (
                 enable_hierarchical_cache
                 and req.last_node_global is not None