sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/llava.py

@@ -1,14 +1,20 @@
 import asyncio
+import importlib
 from typing import List, Optional, Union

 import numpy as np
+from transformers.models.auto.processing_auto import (
+    PROCESSOR_MAPPING_NAMES as HF_MAPPING_NAMES,
+)

+import sglang.srt.managers.multimodal_processor as sgl_mm_processor_utils
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
 )
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.models.llava import (
+    LlavaForConditionalGeneration,
     LlavaLlamaForCausalLM,
     LlavaMistralForCausalLM,
     LlavaQwenForCausalLM,
@@ -133,6 +139,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
                         img_data, aspect_ratio, grid_pinpoints
                     )
                 )
+
             res = await asyncio.gather(*res)
             for pixel_v, image_h, image_s in res:
                 pixel_values.append(pixel_v)
@@ -165,3 +172,42 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
                 )
             ],
         }
+
+
+class LlavaMultimodalProcessor(BaseMultimodalProcessor):
+    """
+    This is a wrapper class used to identify the multimodal processor for Llava architecture models.
+    """
+
+    models = [LlavaForConditionalGeneration]
+
+    def _get_sgl_processor_cls(self, model_type: str):
+        if hf_name := HF_MAPPING_NAMES.get(model_type):
+            sgl_mm_processor_set = sgl_mm_processor_utils.PROCESSOR_MAPPING.values()
+            sgl_processor_cls = list(
+                filter(lambda p: p.__name__ == hf_name, sgl_mm_processor_set)
+            )
+            if sgl_processor_cls:
+                return sgl_processor_cls[0]
+        raise ValueError(
+            f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
+        )
+
+    def __init__(self, hf_config, server_args, _processor):
+        assert hasattr(hf_config, "vision_config")
+        assert hasattr(hf_config, "text_config")
+        self.vision_config = hf_config.vision_config
+        self.text_config = hf_config.text_config
+        self.hf_config = hf_config
+
+        if vision_type := getattr(self.vision_config, "model_type"):
+            self.inner = self._get_sgl_processor_cls(vision_type)(
+                hf_config, server_args, _processor
+            )
+        else:
+            raise ValueError(
+                f"Required `vision_config.model_type` is not found in hf_config: `{hf_config}`"
+            )
+
+    async def process_mm_data_async(self, *args, **kwargs):
+        return await self.inner.process_mm_data_async(*args, **kwargs)
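Note on the new LlavaMultimodalProcessor: it resolves the concrete processor by mapping vision_config.model_type to the Hugging Face processor class name (via PROCESSOR_MAPPING_NAMES) and then picking the sglang processor class registered under that same name. A minimal standalone sketch of that lookup follows; the mapping contents and the CLIPImageProcessor placeholder are illustrative assumptions, not the real registries.

# Toy illustration of the name-matching lookup (not sglang code).
HF_MAPPING_NAMES = {"clip_vision_model": "CLIPImageProcessor"}  # hypothetical entry

class CLIPImageProcessor:  # stands in for a processor class registered with sglang
    pass

SGL_REGISTERED_PROCESSORS = [CLIPImageProcessor]

def resolve_processor(model_type: str):
    # Look up the HF processor name, then find the sglang class with the same name.
    hf_name = HF_MAPPING_NAMES.get(model_type)
    for cls in SGL_REGISTERED_PROCESSORS:
        if hf_name and cls.__name__ == hf_name:
            return cls
    raise ValueError(f"No multimodal processor registered for `{model_type}`")

print(resolve_processor("clip_vision_model").__name__)  # CLIPImageProcessor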
sglang/srt/managers/multimodal_processors/pixtral.py (new file)

@@ -0,0 +1,127 @@
+import asyncio
+import math
+from typing import List, Optional, Union
+
+import numpy as np
+from transformers import PretrainedConfig
+from transformers.models.pixtral.image_processing_pixtral import (
+    _num_image_tokens as _get_pixtral_hf_num_image_tokens,
+)
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
+from sglang.srt.models.pixtral import PixtralVisionModel
+
+
+class PixtralProcessor(BaseMultimodalProcessor):
+    models = [PixtralVisionModel]
+
+    PAD_TOKEN = "<pad>"
+    IMG_BREAK_TOKEN_ID = 12
+    IMG_END_TOKEN_ID = 13
+
+    def get_patch_grid_size(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> tuple[int, int]:
+        max_width = max_height = self.image_size
+        patch_width = patch_height = self.patch_size
+
+        ratio = max(image_width / max_width, image_height / max_height)
+
+        if ratio > 1:
+            image_width = int(math.floor(image_width / ratio))
+            image_height = int(math.floor(image_height / ratio))
+
+        nrows, ncols = _get_pixtral_hf_num_image_tokens(
+            (image_height, image_width),
+            (patch_height, patch_width),
+        )
+
+        return ncols, nrows
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.image_token_id = getattr(
+            hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
+        )
+        # Instantiate the patcher logic helper using the class defined above
+
+        self.vision_config = hf_config.vision_config
+        self.image_size = self.vision_config.image_size
+        self.patch_size = self.vision_config.patch_size
+        self.multimodal_tokens = MultimodalSpecialTokens(
+            image_token=_processor.image_token
+        )
+        _processor.tokenizer.add_special_tokens(
+            {
+                "pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN),
+            }
+        )
+
+    async def _resize(self, image):
+        num_w_tokens, num_h_tokens = self.get_patch_grid_size(
+            image_width=image.size[0],
+            image_height=image.size[1],
+        )
+        new_size = (num_w_tokens * self.patch_size, num_h_tokens * self.patch_size)
+        return image.resize(new_size)
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        mm_data = self.load_mm_data(
+            prompt=input_text,
+            multimodal_tokens=self.multimodal_tokens,
+            max_req_input_len=kwargs.get("max_req_input_len", 4096),
+            image_data=image_data,
+            return_text=True,
+        )
+
+        if mm_data.images:
+            resize_tasks = [self._resize(image) for image in mm_data.images]
+            mm_data.images = await asyncio.gather(*resize_tasks)
+
+        processor_output = self.process_mm_data(
+            input_text=mm_data.input_text,
+            images=mm_data.images,
+        )
+
+        if "pixel_values" in processor_output:
+            mm_items = [
+                MultimodalDataItem(
+                    pixel_values=processor_output["pixel_values"],
+                    image_sizes=processor_output["image_sizes"],
+                    modality=Modality.IMAGE,
+                )
+            ]
+
+            input_ids = processor_output["input_ids"].view(-1).tolist()
+            processor_output.update(
+                input_ids=input_ids,
+                mm_items=mm_items,
+                # there's no im_start_id for pixtral, only im_token and im_end_token
+                im_end_id=self.IMG_END_TOKEN_ID,
+                im_token_id=self.image_token_id,
+            )
+        return processor_output
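The resize logic above scales an image down (never up) so its longer side fits within vision_config.image_size, then snaps each dimension to a whole number of patch_size patches. Below is a standalone sketch of that math, assuming the HF helper _num_image_tokens returns ceil(height / patch) and ceil(width / patch) after the downscale step; 1024 and 16 are used only as example config values.

import math

def pixtral_target_size(width: int, height: int, image_size: int = 1024, patch_size: int = 16):
    # Scale down so the longer side fits within image_size (no upscaling).
    ratio = max(width / image_size, height / image_size)
    if ratio > 1:
        width = math.floor(width / ratio)
        height = math.floor(height / ratio)
    # Round each dimension up to a whole number of patches.
    ncols = math.ceil(width / patch_size)
    nrows = math.ceil(height / patch_size)
    return ncols * patch_size, nrows * patch_size

# A 3000x2000 image: ratio ~2.93 -> 1024x682 -> 64x43 patches -> resized to (1024, 688).
print(pixtral_target_size(3000, 2000))  # (1024, 688)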
sglang/srt/managers/schedule_batch.py

@@ -1,8 +1,5 @@
 from __future__ import annotations

-import hashlib
-from enum import Enum, auto
-
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,12 +27,16 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
+
+TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
 """

 import copy
 import dataclasses
+import hashlib
 import logging
 import threading
+from enum import Enum, auto
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

 import numpy as np
@@ -51,6 +52,7 @@ from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMi
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -73,6 +75,7 @@ global_server_args_dict = {
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
@@ -134,9 +137,9 @@ class FINISH_LENGTH(BaseFinishReason):


 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message="Unknown error", status_code=None, err_type=None):
+    def __init__(self, message=None, status_code=None, err_type=None):
         super().__init__(is_error=True)
-        self.message = message
+        self.message = message or "Aborted"
         self.status_code = status_code
         self.err_type = err_type

@@ -434,6 +437,7 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
+        self.lora_path = lora_path

         # Memory pool info
         self.req_pool_idx: Optional[int] = None
@@ -441,11 +445,13 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # Whether this request has finished output
+        self.finished_output = None
         # If we want to abort the request in the middle of the event loop, set this to true
         # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
         self.to_abort = False
         # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
-        self.to_abort_message: str = "Unknown error"
+        self.to_abort_message: str = None
         self.stream = stream
         self.eos_token_ids = eos_token_ids

@@ -483,6 +489,13 @@ class Req:
         # For retraction
         self.is_retracted = False

+        # Incremental streamining
+        self.send_token_offset: int = 0
+        self.send_decode_id_offset: int = 0
+        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
+        # because the decode server does not have the first output token logprobs
+        self.send_output_token_logprobs_offset: int = 0
+
         # Logprobs (arguments)
         self.return_logprob = return_logprob
         # Start index to compute logprob from.
@@ -492,11 +505,9 @@ class Req:
         self.temp_scaled_logprobs = False
         self.top_p_normalized_logprobs = False

-        # Latency Breakdown
-        self.queue_time_start = None
-        self.queue_time_end = None
-
         # Logprobs (return values)
+        # True means the input logprob has been already sent to detokenizer.
+        self.input_logprob_sent: bool = False
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
@@ -511,8 +522,10 @@ class Req:
         self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

         if return_logprob:
+            # shape: (bs, 1)
             self.output_token_logprobs_val = []
             self.output_token_logprobs_idx = []
+            # shape: (bs, k)
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
             self.output_token_ids_logprobs_val = []
@@ -530,6 +543,7 @@

         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None
+        self.grammar_wait_ct = 0

         # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
@@ -538,7 +552,12 @@
         # The number of verification forward passes in the speculative decoding.
         # This is used to compute the average acceptance length per request.
         self.spec_verify_ct = 0
-        self.lora_path = lora_path
+
+        # For metrics
+        self.time_stats: TimeStats = TimeStats()
+        self.has_log_time_stats: bool = False
+        self.queue_time_start = None
+        self.queue_time_end = None

         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
@@ -546,8 +565,6 @@
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None

-        # used for warmup because we don't have a pair yet when init
-        self.skip_kv_transfer: bool = False
         # the start index of the sent kv cache
         # We want to send it chunk by chunk for chunked prefill.
         # After every chunk forward, we do the following:
@@ -555,14 +572,14 @@
         #   start_send_idx = len(req.fill_ids)
         self.start_send_idx: int = 0

-        self.metadata_buffer_index: int = -1
-        # The first output_id transferred from prefill instance.
-        self.transferred_output_id: Optional[int] = None
-
         # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
         # This is because kv is not ready in `process_prefill_chunk`.
         # We use `tmp_end_idx` to store the end index of the kv cache to send.
         self.tmp_end_idx: int = -1
+        self.metadata_buffer_index: int = -1
+
+        # The first output_id transferred from prefill instance.
+        self.transferred_output_id: Optional[int] = None

     @property
     def seqlen(self):
@@ -653,6 +670,11 @@
             )
             return

+        if self.grammar is not None:
+            if self.grammar.is_terminated():
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
+                return
+
         last_token_id = self.output_ids[-1]

         if not self.sampling_params.ignore_eos:
@@ -697,13 +719,41 @@
         self.req_pool_idx = None
         self.already_computed = 0

+    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)
+
+    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
+        del self.kv_cache_cpu
+
+    def log_time_stats(self):
+        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
+        if self.has_log_time_stats is True:
+            return
+
+        if self.bootstrap_room is not None:
+            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+        else:
+            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+        logger.info(f"{prefix}: {self.time_stats}")
+        self.has_log_time_stats = True
+
     def __repr__(self):
         return (
             f"Req(rid={self.rid}, "
-            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
+            f"{self.grammar=}, "
+            f"{self.sampling_params=})"
         )


+# Batch id
 bid = 0

@@ -745,6 +795,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     out_cache_loc: torch.Tensor = None  # shape: [b], int64
     output_ids: torch.Tensor = None  # shape: [b], int64

+    # For multimodal inputs
+    multimodal_inputs: Optional[List] = None
+
     # The sum of all sequence lengths
     seq_lens_sum: int = None

@@ -859,7 +912,7 @@
         error_msg = (
             f"{phase_str} out of memory. Try to lower your batch size.\n"
             f"Try to allocate {num_tokens} tokens.\n"
-            f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+            f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
         )
         logger.error(error_msg)
         if self.tree_cache is not None:
@@ -900,7 +953,7 @@
         error_msg = (
             f"Prefill out of memory. Try to lower your batch size.\n"
             f"Try to allocate {extend_num_tokens} tokens.\n"
-            f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+            f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
             f"{self.token_to_kv_pool_allocator.available_size()=}\n"
             f"{self.tree_cache.evictable_size()=}\n"
         )
@@ -935,7 +988,7 @@
         error_msg = (
             f"Decode out of memory. Try to lower your batch size.\n"
             f"Try to allocate {len(seq_lens)} tokens.\n"
-            f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+            f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
             f"{self.token_to_kv_pool_allocator.available_size()=}\n"
             f"{self.tree_cache.evictable_size()=}\n"
         )
@@ -1050,6 +1103,7 @@
         # Copy prefix and do some basic check
         input_embeds = []
         extend_input_logprob_token_ids = []
+        multimodal_inputs = []

         for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
             req.req_pool_idx = req_pool_indices[i]
@@ -1065,6 +1119,8 @@
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

+            multimodal_inputs.append(req.multimodal_inputs)
+
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
@@ -1147,6 +1203,16 @@
             if input_embeds
             else None
         )
+        for mm_input in multimodal_inputs:
+            if mm_input is None:
+                continue
+            for mm_item in mm_input.mm_items:
+                pixel_values = getattr(mm_item, "pixel_values", None)
+                if isinstance(pixel_values, torch.Tensor):
+                    mm_item.pixel_values = pixel_values.to(
+                        self.device, non_blocking=True
+                    )
+        self.multimodal_inputs = multimodal_inputs
         self.seq_lens_sum = sum(seq_lens)

         if self.return_logprob:
@@ -1431,7 +1497,7 @@
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and not self.reqs[i] in chunked_req_to_exclude
+                and self.reqs[i] not in chunked_req_to_exclude
             ]

         if keep_indices is None or len(keep_indices) == 0:
@@ -1452,6 +1518,8 @@
             self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

         self.reqs = [self.reqs[i] for i in keep_indices]
+        if self.multimodal_inputs is not None:
+            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
         self.out_cache_loc = None
@@ -1500,6 +1568,8 @@
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
         self.reqs.extend(other.reqs)
+        if self.multimodal_inputs is not None:
+            self.multimodal_inputs.extend(other.multimodal_inputs)

         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
@@ -1558,7 +1628,7 @@
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
+            multimodal_inputs=self.multimodal_inputs,
             encoder_cached=self.encoder_cached,
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,
sglang/srt/managers/schedule_policy.py

@@ -455,7 +455,10 @@ class PrefillAdder:
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
-        input_tokens = req.extend_input_len
+        input_tokens = (
+            -(-req.extend_input_len // self.tree_cache.page_size)
+            * self.tree_cache.page_size
+        )
         prefix_len = len(req.prefix_indices)

         if total_tokens >= self.rem_total_tokens:
@@ -465,9 +468,6 @@
             return AddReqResult.OTHER

         with self._lock_node(req.last_node):
-            if total_tokens > self.rem_total_tokens:
-                return AddReqResult.NO_TOKEN
-
             if (
                 enable_hierarchical_cache
                 and req.last_node_global is not None
@@ -477,7 +477,10 @@
                     req.last_node_global, req.prefix_indices
                 )
                 req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
-                input_tokens = req.extend_input_len
+                input_tokens = (
+                    -(-req.extend_input_len // self.tree_cache.page_size)
+                    * self.tree_cache.page_size
+                )
                 prefix_len = len(req.prefix_indices)

             if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
@@ -493,12 +496,12 @@
                     ),
                 )
             else:
-                if self.rem_chunk_tokens == 0:
+                # Make sure at least one page is available
+                trunc_len = self.rem_chunk_tokens - self.tree_cache.page_size + 1
+                if trunc_len <= 0:
                     return AddReqResult.OTHER

                 # Chunked prefill
-                trunc_len = self.rem_chunk_tokens
-
                 req.extend_input_len = trunc_len
                 req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
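The PrefillAdder changes above account for paged allocation: -(-x // page_size) * page_size uses Python's integer ceiling division, so the extend length is rounded up to a whole number of pages, and the chunked-prefill branch now requires at least one full page of remaining chunk budget (trunc_len = rem_chunk_tokens - page_size + 1 must be positive). A quick arithmetic check of the rounding idiom, outside of sglang:

def round_up_to_page(extend_input_len: int, page_size: int) -> int:
    # -(-a // b) == ceil(a / b) for positive ints, so this rounds up to a page multiple.
    return -(-extend_input_len // page_size) * page_size

assert round_up_to_page(100, 64) == 128  # 100 tokens still occupy two 64-token pages
assert round_up_to_page(128, 64) == 128  # exact multiples are unchanged
assert round_up_to_page(1, 1) == 1       # page_size=1 reduces to the old behavior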