sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/llava.py

@@ -1,14 +1,20 @@
 import asyncio
+import importlib
 from typing import List, Optional, Union

 import numpy as np
+from transformers.models.auto.processing_auto import (
+    PROCESSOR_MAPPING_NAMES as HF_MAPPING_NAMES,
+)

+import sglang.srt.managers.multimodal_processor as sgl_mm_processor_utils
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
 )
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.models.llava import (
+    LlavaForConditionalGeneration,
     LlavaLlamaForCausalLM,
     LlavaMistralForCausalLM,
     LlavaQwenForCausalLM,
@@ -133,6 +139,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
                        img_data, aspect_ratio, grid_pinpoints
                    )
                )
+
            res = await asyncio.gather(*res)
            for pixel_v, image_h, image_s in res:
                pixel_values.append(pixel_v)
@@ -165,3 +172,42 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
                )
            ],
        }
+
+
+class LlavaMultimodalProcessor(BaseMultimodalProcessor):
+    """
+    This is a wrapper class used to identify the multimodal processor for Llava architecture models.
+    """
+
+    models = [LlavaForConditionalGeneration]
+
+    def _get_sgl_processor_cls(self, model_type: str):
+        if hf_name := HF_MAPPING_NAMES.get(model_type):
+            sgl_mm_processor_set = sgl_mm_processor_utils.PROCESSOR_MAPPING.values()
+            sgl_processor_cls = list(
+                filter(lambda p: p.__name__ == hf_name, sgl_mm_processor_set)
+            )
+            if sgl_processor_cls:
+                return sgl_processor_cls[0]
+        raise ValueError(
+            f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
+        )
+
+    def __init__(self, hf_config, server_args, _processor):
+        assert hasattr(hf_config, "vision_config")
+        assert hasattr(hf_config, "text_config")
+        self.vision_config = hf_config.vision_config
+        self.text_config = hf_config.text_config
+        self.hf_config = hf_config
+
+        if vision_type := getattr(self.vision_config, "model_type"):
+            self.inner = self._get_sgl_processor_cls(vision_type)(
+                hf_config, server_args, _processor
+            )
+        else:
+            raise ValueError(
+                f"Required `vision_config.model_type` is not found in hf_config: `{hf_config}`"
+            )
+
+    async def process_mm_data_async(self, *args, **kwargs):
+        return await self.inner.process_mm_data_async(*args, **kwargs)
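The wrapper above dispatches on the vision tower's `model_type`: Hugging Face's `PROCESSOR_MAPPING_NAMES` maps a model type to a processor class name, and the wrapper then picks the sglang processor class with that name. A minimal standalone sketch of that lookup, where `registered_processors` is a stand-in for sglang's internal `PROCESSOR_MAPPING` registry:

```python
from transformers.models.auto.processing_auto import PROCESSOR_MAPPING_NAMES

def resolve_processor_cls(model_type: str, registered_processors):
    # HF maps a model type to a processor class name, e.g. "clip" -> "CLIPProcessor".
    hf_name = PROCESSOR_MAPPING_NAMES.get(model_type)
    if hf_name is None:
        raise ValueError(f"Unknown vision model type: {model_type}")
    # Pick the registered processor class whose name matches the HF name.
    matches = [p for p in registered_processors if p.__name__ == hf_name]
    if not matches:
        raise ValueError(f"No processor registered for {hf_name}")
    return matches[0]
```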
sglang/srt/managers/multimodal_processors/pixtral.py (new file)

@@ -0,0 +1,127 @@
+import asyncio
+import math
+from typing import List, Optional, Union
+
+import numpy as np
+from transformers import PretrainedConfig
+from transformers.models.pixtral.image_processing_pixtral import (
+    _num_image_tokens as _get_pixtral_hf_num_image_tokens,
+)
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
+from sglang.srt.models.pixtral import PixtralVisionModel
+
+
+class PixtralProcessor(BaseMultimodalProcessor):
+    models = [PixtralVisionModel]
+
+    PAD_TOKEN = "<pad>"
+    IMG_BREAK_TOKEN_ID = 12
+    IMG_END_TOKEN_ID = 13
+
+    def get_patch_grid_size(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> tuple[int, int]:
+        max_width = max_height = self.image_size
+        patch_width = patch_height = self.patch_size
+
+        ratio = max(image_width / max_width, image_height / max_height)
+
+        if ratio > 1:
+            image_width = int(math.floor(image_width / ratio))
+            image_height = int(math.floor(image_height / ratio))
+
+        nrows, ncols = _get_pixtral_hf_num_image_tokens(
+            (image_height, image_width),
+            (patch_height, patch_width),
+        )
+
+        return ncols, nrows
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.image_token_id = getattr(
+            hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
+        )
+        # Instantiate the patcher logic helper using the class defined above
+
+        self.vision_config = hf_config.vision_config
+        self.image_size = self.vision_config.image_size
+        self.patch_size = self.vision_config.patch_size
+        self.multimodal_tokens = MultimodalSpecialTokens(
+            image_token=_processor.image_token
+        )
+        _processor.tokenizer.add_special_tokens(
+            {
+                "pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN),
+            }
+        )
+
+    async def _resize(self, image):
+        num_w_tokens, num_h_tokens = self.get_patch_grid_size(
+            image_width=image.size[0],
+            image_height=image.size[1],
+        )
+        new_size = (num_w_tokens * self.patch_size, num_h_tokens * self.patch_size)
+        return image.resize(new_size)
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        mm_data = self.load_mm_data(
+            prompt=input_text,
+            multimodal_tokens=self.multimodal_tokens,
+            max_req_input_len=kwargs.get("max_req_input_len", 4096),
+            image_data=image_data,
+            return_text=True,
+        )
+
+        if mm_data.images:
+            resize_tasks = [self._resize(image) for image in mm_data.images]
+            mm_data.images = await asyncio.gather(*resize_tasks)
+
+        processor_output = self.process_mm_data(
+            input_text=mm_data.input_text,
+            images=mm_data.images,
+        )
+
+        if "pixel_values" in processor_output:
+            mm_items = [
+                MultimodalDataItem(
+                    pixel_values=processor_output["pixel_values"],
+                    image_sizes=processor_output["image_sizes"],
+                    modality=Modality.IMAGE,
+                )
+            ]
+
+            input_ids = processor_output["input_ids"].view(-1).tolist()
+            processor_output.update(
+                input_ids=input_ids,
+                mm_items=mm_items,
+                # there's no im_start_id for pixtral, only im_token and im_end_token
+                im_end_id=self.IMG_END_TOKEN_ID,
+                im_token_id=self.image_token_id,
+            )
+        return processor_output
sglang/srt/managers/schedule_batch.py

@@ -1,8 +1,5 @@
 from __future__ import annotations

-import hashlib
-from enum import Enum, auto
-
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,12 +27,16 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
   It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
+
+TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
 """

 import copy
 import dataclasses
+import hashlib
 import logging
 import threading
+from enum import Enum, auto
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

 import numpy as np
@@ -51,6 +52,7 @@ from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMi
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -73,6 +75,7 @@ global_server_args_dict = {
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
@@ -134,9 +137,9 @@ class FINISH_LENGTH(BaseFinishReason):


 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message=
+    def __init__(self, message=None, status_code=None, err_type=None):
         super().__init__(is_error=True)
-        self.message = message
+        self.message = message or "Aborted"
         self.status_code = status_code
         self.err_type = err_type

@@ -434,6 +437,7 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
+        self.lora_path = lora_path

         # Memory pool info
         self.req_pool_idx: Optional[int] = None
@@ -441,11 +445,13 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # Whether this request has finished output
+        self.finished_output = None
         # If we want to abort the request in the middle of the event loop, set this to true
         # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
         self.to_abort = False
         # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
-        self.to_abort_message: str =
+        self.to_abort_message: str = None
         self.stream = stream
         self.eos_token_ids = eos_token_ids

@@ -483,6 +489,13 @@ class Req:
         # For retraction
         self.is_retracted = False

+        # Incremental streamining
+        self.send_token_offset: int = 0
+        self.send_decode_id_offset: int = 0
+        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
+        # because the decode server does not have the first output token logprobs
+        self.send_output_token_logprobs_offset: int = 0
+
         # Logprobs (arguments)
         self.return_logprob = return_logprob
         # Start index to compute logprob from.
@@ -492,11 +505,9 @@ class Req:
         self.temp_scaled_logprobs = False
         self.top_p_normalized_logprobs = False

-        # Latency Breakdown
-        self.queue_time_start = None
-        self.queue_time_end = None
-
         # Logprobs (return values)
+        # True means the input logprob has been already sent to detokenizer.
+        self.input_logprob_sent: bool = False
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
@@ -511,8 +522,10 @@ class Req:
         self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

         if return_logprob:
+            # shape: (bs, 1)
             self.output_token_logprobs_val = []
             self.output_token_logprobs_idx = []
+            # shape: (bs, k)
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
             self.output_token_ids_logprobs_val = []
@@ -530,6 +543,7 @@ class Req:

         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None
+        self.grammar_wait_ct = 0

         # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
@@ -538,7 +552,12 @@ class Req:
         # The number of verification forward passes in the speculative decoding.
         # This is used to compute the average acceptance length per request.
         self.spec_verify_ct = 0
-
+
+        # For metrics
+        self.time_stats: TimeStats = TimeStats()
+        self.has_log_time_stats: bool = False
+        self.queue_time_start = None
+        self.queue_time_end = None

         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
@@ -546,8 +565,6 @@ class Req:
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None

-        # used for warmup because we don't have a pair yet when init
-        self.skip_kv_transfer: bool = False
         # the start index of the sent kv cache
         # We want to send it chunk by chunk for chunked prefill.
         # After every chunk forward, we do the following:
@@ -555,14 +572,14 @@ class Req:
         # start_send_idx = len(req.fill_ids)
         self.start_send_idx: int = 0

-        self.metadata_buffer_index: int = -1
-        # The first output_id transferred from prefill instance.
-        self.transferred_output_id: Optional[int] = None
-
         # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
         # This is because kv is not ready in `process_prefill_chunk`.
         # We use `tmp_end_idx` to store the end index of the kv cache to send.
         self.tmp_end_idx: int = -1
+        self.metadata_buffer_index: int = -1
+
+        # The first output_id transferred from prefill instance.
+        self.transferred_output_id: Optional[int] = None

     @property
     def seqlen(self):
@@ -653,6 +670,11 @@ class Req:
             )
             return

+        if self.grammar is not None:
+            if self.grammar.is_terminated():
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
+                return
+
         last_token_id = self.output_ids[-1]

         if not self.sampling_params.ignore_eos:
@@ -697,13 +719,41 @@ class Req:
         self.req_pool_idx = None
         self.already_computed = 0

+    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)
+
+    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
+        del self.kv_cache_cpu
+
+    def log_time_stats(self):
+        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
+        if self.has_log_time_stats is True:
+            return
+
+        if self.bootstrap_room is not None:
+            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+        else:
+            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+        logger.info(f"{prefix}: {self.time_stats}")
+        self.has_log_time_stats = True
+
     def __repr__(self):
         return (
             f"Req(rid={self.rid}, "
-            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
+            f"{self.grammar=}, "
+            f"{self.sampling_params=})"
         )


+# Batch id
 bid = 0


@@ -745,6 +795,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     out_cache_loc: torch.Tensor = None  # shape: [b], int64
     output_ids: torch.Tensor = None  # shape: [b], int64

+    # For multimodal inputs
+    multimodal_inputs: Optional[List] = None
+
     # The sum of all sequence lengths
     seq_lens_sum: int = None

@@ -859,7 +912,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         error_msg = (
             f"{phase_str} out of memory. Try to lower your batch size.\n"
             f"Try to allocate {num_tokens} tokens.\n"
-            f"
+            f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
         )
         logger.error(error_msg)
         if self.tree_cache is not None:
@@ -900,7 +953,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         error_msg = (
             f"Prefill out of memory. Try to lower your batch size.\n"
             f"Try to allocate {extend_num_tokens} tokens.\n"
-            f"
+            f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
             f"{self.token_to_kv_pool_allocator.available_size()=}\n"
             f"{self.tree_cache.evictable_size()=}\n"
         )
@@ -935,7 +988,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         error_msg = (
             f"Decode out of memory. Try to lower your batch size.\n"
             f"Try to allocate {len(seq_lens)} tokens.\n"
-            f"
+            f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
             f"{self.token_to_kv_pool_allocator.available_size()=}\n"
             f"{self.tree_cache.evictable_size()=}\n"
         )
@@ -1050,6 +1103,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Copy prefix and do some basic check
         input_embeds = []
         extend_input_logprob_token_ids = []
+        multimodal_inputs = []

         for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
             req.req_pool_idx = req_pool_indices[i]
@@ -1065,6 +1119,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

+            multimodal_inputs.append(req.multimodal_inputs)
+
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
@@ -1147,6 +1203,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             if input_embeds
             else None
         )
+        for mm_input in multimodal_inputs:
+            if mm_input is None:
+                continue
+            for mm_item in mm_input.mm_items:
+                pixel_values = getattr(mm_item, "pixel_values", None)
+                if isinstance(pixel_values, torch.Tensor):
+                    mm_item.pixel_values = pixel_values.to(
+                        self.device, non_blocking=True
+                    )
+        self.multimodal_inputs = multimodal_inputs
         self.seq_lens_sum = sum(seq_lens)

         if self.return_logprob:
@@ -1431,7 +1497,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and
+                and self.reqs[i] not in chunked_req_to_exclude
             ]

         if keep_indices is None or len(keep_indices) == 0:
@@ -1452,6 +1518,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

         self.reqs = [self.reqs[i] for i in keep_indices]
+        if self.multimodal_inputs is not None:
+            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
         self.out_cache_loc = None
@@ -1500,6 +1568,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
         self.reqs.extend(other.reqs)
+        if self.multimodal_inputs is not None:
+            self.multimodal_inputs.extend(other.multimodal_inputs)

         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
@@ -1558,7 +1628,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            multimodal_inputs=
+            multimodal_inputs=self.multimodal_inputs,
             encoder_cached=self.encoder_cached,
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,
sglang/srt/managers/schedule_policy.py

@@ -455,7 +455,10 @@ class PrefillAdder:
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
-        input_tokens =
+        input_tokens = (
+            -(-req.extend_input_len // self.tree_cache.page_size)
+            * self.tree_cache.page_size
+        )
         prefix_len = len(req.prefix_indices)

         if total_tokens >= self.rem_total_tokens:
@@ -465,9 +468,6 @@ class PrefillAdder:
             return AddReqResult.OTHER

         with self._lock_node(req.last_node):
-            if total_tokens > self.rem_total_tokens:
-                return AddReqResult.NO_TOKEN
-
             if (
                 enable_hierarchical_cache
                 and req.last_node_global is not None
@@ -477,7 +477,10 @@ class PrefillAdder:
                    req.last_node_global, req.prefix_indices
                )
                req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
-                input_tokens =
+                input_tokens = (
+                    -(-req.extend_input_len // self.tree_cache.page_size)
+                    * self.tree_cache.page_size
+                )
                prefix_len = len(req.prefix_indices)

            if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
@@ -493,12 +496,12 @@ class PrefillAdder:
                    ),
                )
            else:
-
+                # Make sure at least one page is available
+                trunc_len = self.rem_chunk_tokens - self.tree_cache.page_size + 1
+                if trunc_len <= 0:
                    return AddReqResult.OTHER

                # Chunked prefill
-                trunc_len = self.rem_chunk_tokens
-
                req.extend_input_len = trunc_len
                req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]

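The `PrefillAdder` changes round the extend length up to a whole number of KV-cache pages; `-(-n // page_size) * page_size` is ceiling division expressed with integer floor division. A minimal sketch of the identity with a few sanity checks:

```python
def round_up_to_page(extend_input_len: int, page_size: int) -> int:
    # -(-n // p) equals ceil(n / p) using only integer floor division
    return -(-extend_input_len // page_size) * page_size

assert round_up_to_page(100, 16) == 112  # 100 tokens occupy 7 pages of 16
assert round_up_to_page(112, 16) == 112  # already page-aligned
assert round_up_to_page(1, 16) == 16     # always at least one full page
```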