sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +2 -2
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +9 -7
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +48 -43
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +7 -2
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +227 -120
- sglang/srt/disaggregation/nixl/conn.py +1 -0
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +7 -1
- sglang/srt/entrypoints/engine.py +17 -2
- sglang/srt/entrypoints/http_server.py +17 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +1 -1
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +72 -71
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +76 -24
- sglang/srt/managers/schedule_policy.py +0 -3
- sglang/srt/managers/scheduler.py +113 -88
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +133 -34
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/memory_pool.py +2 -0
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +19 -14
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +23 -20
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +5 -6
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +30 -4
- sglang/srt/openai_api/protocol.py +0 -8
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +34 -4
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +6 -5
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +89 -14
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/llava.py

@@ -1,14 +1,20 @@
 import asyncio
+import importlib
 from typing import List, Optional, Union

 import numpy as np
+from transformers.models.auto.processing_auto import (
+    PROCESSOR_MAPPING_NAMES as HF_MAPPING_NAMES,
+)

+import sglang.srt.managers.multimodal_processor as sgl_mm_processor_utils
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
 )
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.models.llava import (
+    LlavaForConditionalGeneration,
     LlavaLlamaForCausalLM,
     LlavaMistralForCausalLM,
     LlavaQwenForCausalLM,
@@ -133,6 +139,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
                         img_data, aspect_ratio, grid_pinpoints
                     )
                 )
+
             res = await asyncio.gather(*res)
             for pixel_v, image_h, image_s in res:
                 pixel_values.append(pixel_v)
@@ -165,3 +172,42 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
                 )
             ],
         }
+
+
+class LlavaMultimodalProcessor(BaseMultimodalProcessor):
+    """
+    This is a wrapper class used to identify the multimodal processor for Llava architecture models.
+    """
+
+    models = [LlavaForConditionalGeneration]
+
+    def _get_sgl_processor_cls(self, model_type: str):
+        if hf_name := HF_MAPPING_NAMES.get(model_type):
+            sgl_mm_processor_set = sgl_mm_processor_utils.PROCESSOR_MAPPING.values()
+            sgl_processor_cls = list(
+                filter(lambda p: p.__name__ == hf_name, sgl_mm_processor_set)
+            )
+            if sgl_processor_cls:
+                return sgl_processor_cls[0]
+        raise ValueError(
+            f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
+        )
+
+    def __init__(self, hf_config, server_args, _processor):
+        assert hasattr(hf_config, "vision_config")
+        assert hasattr(hf_config, "text_config")
+        self.vision_config = hf_config.vision_config
+        self.text_config = hf_config.text_config
+        self.hf_config = hf_config
+
+        if vision_type := getattr(self.vision_config, "model_type"):
+            self.inner = self._get_sgl_processor_cls(vision_type)(
+                hf_config, server_args, _processor
+            )
+        else:
+            raise ValueError(
+                f"Required `vision_config.model_type` is not found in hf_config: `{hf_config}`"
+            )
+
+    async def process_mm_data_async(self, *args, **kwargs):
+        return await self.inner.process_mm_data_async(*args, **kwargs)
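The wrapper added here never processes images itself: it reads `vision_config.model_type` from the HF config, maps it to a Hugging Face processor class name through `PROCESSOR_MAPPING_NAMES`, and picks the sglang processor whose class name matches, delegating `process_mm_data_async` to it. A minimal standalone sketch of that name-based dispatch follows; the registries and model types in it are toy stand-ins, not the actual contents of the transformers or sglang mappings.

# Toy sketch of LlavaMultimodalProcessor's dispatch-by-name. The dicts below
# are illustrative stand-ins for transformers' PROCESSOR_MAPPING_NAMES and
# sglang's PROCESSOR_MAPPING; only the lookup pattern mirrors the diff.
class CLIPImageProcessor:
    pass


class SiglipImageProcessor:
    pass


TOY_HF_MAPPING_NAMES = {
    "clip_vision_model": "CLIPImageProcessor",
    "siglip_vision_model": "SiglipImageProcessor",
}
TOY_SGL_PROCESSORS = [CLIPImageProcessor, SiglipImageProcessor]


def resolve_processor_cls(model_type: str):
    # Map the vision tower's model_type to an HF processor class name,
    # then find the registered class with the same __name__.
    if hf_name := TOY_HF_MAPPING_NAMES.get(model_type):
        matches = [p for p in TOY_SGL_PROCESSORS if p.__name__ == hf_name]
        if matches:
            return matches[0]
    raise ValueError(f"No multimodal processor registered for `{model_type}`")


print(resolve_processor_cls("clip_vision_model").__name__)  # CLIPImageProcessor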
sglang/srt/managers/multimodal_processors/pixtral.py

@@ -0,0 +1,127 @@
+import asyncio
+import math
+from typing import List, Optional, Union
+
+import numpy as np
+from transformers import PretrainedConfig
+from transformers.models.pixtral.image_processing_pixtral import (
+    _num_image_tokens as _get_pixtral_hf_num_image_tokens,
+)
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
+from sglang.srt.models.pixtral import PixtralVisionModel
+
+
+class PixtralProcessor(BaseMultimodalProcessor):
+    models = [PixtralVisionModel]
+
+    PAD_TOKEN = "<pad>"
+    IMG_BREAK_TOKEN_ID = 12
+    IMG_END_TOKEN_ID = 13
+
+    def get_patch_grid_size(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> tuple[int, int]:
+        max_width = max_height = self.image_size
+        patch_width = patch_height = self.patch_size
+
+        ratio = max(image_width / max_width, image_height / max_height)
+
+        if ratio > 1:
+            image_width = int(math.floor(image_width / ratio))
+            image_height = int(math.floor(image_height / ratio))
+
+        nrows, ncols = _get_pixtral_hf_num_image_tokens(
+            (image_height, image_width),
+            (patch_height, patch_width),
+        )
+
+        return ncols, nrows
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.image_token_id = getattr(
+            hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
+        )
+        # Instantiate the patcher logic helper using the class defined above
+
+        self.vision_config = hf_config.vision_config
+        self.image_size = self.vision_config.image_size
+        self.patch_size = self.vision_config.patch_size
+        self.multimodal_tokens = MultimodalSpecialTokens(
+            image_token=_processor.image_token
+        )
+        _processor.tokenizer.add_special_tokens(
+            {
+                "pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN),
+            }
+        )
+
+    async def _resize(self, image):
+        num_w_tokens, num_h_tokens = self.get_patch_grid_size(
+            image_width=image.size[0],
+            image_height=image.size[1],
+        )
+        new_size = (num_w_tokens * self.patch_size, num_h_tokens * self.patch_size)
+        return image.resize(new_size)
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        mm_data = self.load_mm_data(
+            prompt=input_text,
+            multimodal_tokens=self.multimodal_tokens,
+            max_req_input_len=kwargs.get("max_req_input_len", 4096),
+            image_data=image_data,
+            return_text=True,
+        )
+
+        if mm_data.images:
+            resize_tasks = [self._resize(image) for image in mm_data.images]
+            mm_data.images = await asyncio.gather(*resize_tasks)
+
+        processor_output = self.process_mm_data(
+            input_text=mm_data.input_text,
+            images=mm_data.images,
+        )
+
+        if "pixel_values" in processor_output:
+            mm_items = [
+                MultimodalDataItem(
+                    pixel_values=processor_output["pixel_values"],
+                    image_sizes=processor_output["image_sizes"],
+                    modality=Modality.IMAGE,
+                )
+            ]
+
+            input_ids = processor_output["input_ids"].view(-1).tolist()
+            processor_output.update(
+                input_ids=input_ids,
+                mm_items=mm_items,
+                # there's no im_start_id for pixtral, only im_token and im_end_token
+                im_end_id=self.IMG_END_TOKEN_ID,
+                im_token_id=self.image_token_id,
+            )
+        return processor_output
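The resize path above keeps each image within `image_size` on its longer side, then snaps it to the token grid: the patch column/row counts come from the transformers pixtral helper (effectively a ceil division by `patch_size`), and `_resize` scales the image to exactly columns*patch_size by rows*patch_size. A self-contained sketch of that arithmetic, assuming example values image_size=1024 and patch_size=16 (the real values come from `hf_config.vision_config`):

import math


# Standalone sketch of PixtralProcessor.get_patch_grid_size / _resize math.
# The ceil division mirrors what the transformers pixtral helper computes;
# image_size=1024 and patch_size=16 are assumed example values here.
def patch_grid(image_width, image_height, image_size=1024, patch_size=16):
    ratio = max(image_width / image_size, image_height / image_size)
    if ratio > 1:  # downscale so the longer side fits within image_size
        image_width = math.floor(image_width / ratio)
        image_height = math.floor(image_height / ratio)
    ncols = math.ceil(image_width / patch_size)   # image tokens per row
    nrows = math.ceil(image_height / patch_size)  # rows of image tokens
    return ncols, nrows


ncols, nrows = patch_grid(1300, 800)
print(ncols, nrows)              # 64 40
print((ncols * 16, nrows * 16))  # (1024, 640), the target passed to image.resize()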
sglang/srt/managers/schedule_batch.py

@@ -1,8 +1,5 @@
 from __future__ import annotations

-import hashlib
-from enum import Enum, auto
-
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,12 +27,16 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
   It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
+
+TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
 """

 import copy
 import dataclasses
+import hashlib
 import logging
 import threading
+from enum import Enum, auto
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

 import numpy as np
@@ -51,6 +52,7 @@ from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMi
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -73,6 +75,7 @@ global_server_args_dict = {
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
@@ -134,9 +137,9 @@ class FINISH_LENGTH(BaseFinishReason):


 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message=
+    def __init__(self, message=None, status_code=None, err_type=None):
         super().__init__(is_error=True)
-        self.message = message
+        self.message = message or "Aborted"
         self.status_code = status_code
         self.err_type = err_type

@@ -434,6 +437,7 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
+        self.lora_path = lora_path

         # Memory pool info
         self.req_pool_idx: Optional[int] = None
@@ -441,11 +445,13 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # Whether this request has finished output
+        self.finished_output = None
         # If we want to abort the request in the middle of the event loop, set this to true
         # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
         self.to_abort = False
         # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
-        self.to_abort_message: str =
+        self.to_abort_message: str = None
         self.stream = stream
         self.eos_token_ids = eos_token_ids

@@ -483,6 +489,13 @@ class Req:
         # For retraction
         self.is_retracted = False

+        # Incremental streamining
+        self.send_token_offset: int = 0
+        self.send_decode_id_offset: int = 0
+        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
+        # because the decode server does not have the first output token logprobs
+        self.send_output_token_logprobs_offset: int = 0
+
         # Logprobs (arguments)
         self.return_logprob = return_logprob
         # Start index to compute logprob from.
@@ -492,11 +505,9 @@ class Req:
         self.temp_scaled_logprobs = False
         self.top_p_normalized_logprobs = False

-        # Latency Breakdown
-        self.queue_time_start = None
-        self.queue_time_end = None
-
         # Logprobs (return values)
+        # True means the input logprob has been already sent to detokenizer.
+        self.input_logprob_sent: bool = False
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
@@ -511,8 +522,10 @@ class Req:
         self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

         if return_logprob:
+            # shape: (bs, 1)
             self.output_token_logprobs_val = []
             self.output_token_logprobs_idx = []
+            # shape: (bs, k)
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
             self.output_token_ids_logprobs_val = []
@@ -530,6 +543,7 @@ class Req:

         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None
+        self.grammar_wait_ct = 0

         # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
@@ -538,7 +552,12 @@ class Req:
         # The number of verification forward passes in the speculative decoding.
         # This is used to compute the average acceptance length per request.
         self.spec_verify_ct = 0
-
+
+        # For metrics
+        self.time_stats: TimeStats = TimeStats()
+        self.has_log_time_stats: bool = False
+        self.queue_time_start = None
+        self.queue_time_end = None

         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
@@ -546,8 +565,6 @@ class Req:
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None

-        # used for warmup because we don't have a pair yet when init
-        self.skip_kv_transfer: bool = False
         # the start index of the sent kv cache
         # We want to send it chunk by chunk for chunked prefill.
         # After every chunk forward, we do the following:
@@ -555,14 +572,14 @@ class Req:
         # start_send_idx = len(req.fill_ids)
         self.start_send_idx: int = 0

-        self.metadata_buffer_index: int = -1
-        # The first output_id transferred from prefill instance.
-        self.transferred_output_id: Optional[int] = None
-
         # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
         # This is because kv is not ready in `process_prefill_chunk`.
         # We use `tmp_end_idx` to store the end index of the kv cache to send.
         self.tmp_end_idx: int = -1
+        self.metadata_buffer_index: int = -1
+
+        # The first output_id transferred from prefill instance.
+        self.transferred_output_id: Optional[int] = None

     @property
     def seqlen(self):
@@ -653,6 +670,11 @@ class Req:
             )
             return

+        if self.grammar is not None:
+            if self.grammar.is_terminated():
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
+                return
+
         last_token_id = self.output_ids[-1]

         if not self.sampling_params.ignore_eos:
@@ -697,13 +719,41 @@ class Req:
         self.req_pool_idx = None
         self.already_computed = 0

+    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)
+
+    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
+        del self.kv_cache_cpu
+
+    def log_time_stats(self):
+        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
+        if self.has_log_time_stats is True:
+            return
+
+        if self.bootstrap_room is not None:
+            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+        else:
+            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+        logger.info(f"{prefix}: {self.time_stats}")
+        self.has_log_time_stats = True
+
     def __repr__(self):
         return (
             f"Req(rid={self.rid}, "
-            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
+            f"{self.grammar=}, "
+            f"{self.sampling_params=})"
         )


+# Batch id
 bid = 0


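The new `offload_kv_cache` / `load_kv_cache` pair above only shuffles a request's KV entries between the allocator and a CPU-side copy via `get_cpu_copy` and `load_cpu_copy`. A toy, self-contained sketch of that round trip, with a stand-in allocator rather than the real TokenToKVPoolAllocator:

import torch


# Toy stand-in for TokenToKVPoolAllocator; only the two calls the new Req
# methods rely on (get_cpu_copy / load_cpu_copy) are mocked here.
class ToyKVAllocator:
    def __init__(self, num_slots, head_dim=8):
        self.kv = torch.randn(num_slots, head_dim)  # pretend on-GPU KV entries

    def get_cpu_copy(self, indices):
        return self.kv[indices].clone()             # copy selected entries out

    def load_cpu_copy(self, cpu_copy, indices):
        self.kv[indices] = cpu_copy                 # restore them in place


alloc = ToyKVAllocator(num_slots=16)
token_indices = torch.tensor([3, 4, 5])

saved = alloc.get_cpu_copy(token_indices)   # what offload_kv_cache stores in req.kv_cache_cpu
alloc.kv[token_indices] = 0.0               # slots can be reused while the request is parked
alloc.load_cpu_copy(saved, token_indices)   # what load_kv_cache does before resuming
assert torch.equal(alloc.kv[token_indices], saved)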
@@ -862,7 +912,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             error_msg = (
                 f"{phase_str} out of memory. Try to lower your batch size.\n"
                 f"Try to allocate {num_tokens} tokens.\n"
-                f"
+                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
             )
             logger.error(error_msg)
             if self.tree_cache is not None:
@@ -903,7 +953,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             error_msg = (
                 f"Prefill out of memory. Try to lower your batch size.\n"
                 f"Try to allocate {extend_num_tokens} tokens.\n"
-                f"
+                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
             )
@@ -938,7 +988,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             error_msg = (
                 f"Decode out of memory. Try to lower your batch size.\n"
                 f"Try to allocate {len(seq_lens)} tokens.\n"
-                f"
+                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
             )
@@ -1447,7 +1497,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and
+                and self.reqs[i] not in chunked_req_to_exclude
             ]

         if keep_indices is None or len(keep_indices) == 0:
@@ -1468,7 +1518,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

         self.reqs = [self.reqs[i] for i in keep_indices]
-
+        if self.multimodal_inputs is not None:
+            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
         self.out_cache_loc = None
@@ -1517,7 +1568,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
         self.reqs.extend(other.reqs)
-        self.multimodal_inputs
+        if self.multimodal_inputs is not None:
+            self.multimodal_inputs.extend(other.multimodal_inputs)

         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
sglang/srt/managers/schedule_policy.py

@@ -468,9 +468,6 @@ class PrefillAdder:
                 return AddReqResult.OTHER

         with self._lock_node(req.last_node):
-            if total_tokens > self.rem_total_tokens:
-                return AddReqResult.NO_TOKEN
-
             if (
                 enable_hierarchical_cache
                 and req.last_node_global is not None