sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +4 -3
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/launch_server.py +3 -2
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -4
- sglang/srt/managers/image_processor.py +1 -1
- sglang/srt/managers/io_struct.py +48 -12
- sglang/srt/managers/schedule_batch.py +42 -36
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +111 -46
- sglang/srt/managers/session_controller.py +0 -3
- sglang/srt/managers/tokenizer_manager.py +169 -100
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +14 -51
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +10 -12
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +391 -0
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +12 -9
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +10 -6
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +303 -204
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +253 -48
- sglang/test/test_utils.py +27 -7
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- sglang-0.3.6.post2.dist-info/RECORD +0 -164
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
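Selected hunks from the changed files follow.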
sglang/srt/layers/logits_processor.py
CHANGED
@@ -23,6 +23,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )

+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


@@ -163,7 +164,7 @@ class LogitsProcessor(nn.Module):
         self,
         input_ids,
         hidden_states,
-
+        lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
     ):
         if isinstance(logits_metadata, ForwardBatch):
@@ -178,7 +179,7 @@ class LogitsProcessor(nn.Module):
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]

-        last_logits =
+        last_logits = self._get_logits(last_hidden, lm_head)
         if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
@@ -229,7 +230,7 @@ class LogitsProcessor(nn.Module):

             # Compute the logits and logprobs for all required tokens
             states = torch.cat(states, dim=0)
-            all_logits =
+            all_logits = self._get_logits(states, lm_head)
             if self.do_tensor_parallel_all_gather:
                 all_logits = tensor_model_parallel_all_gather(all_logits)
             all_logits = all_logits[:, : self.config.vocab_size].float()
@@ -276,6 +277,19 @@ class LogitsProcessor(nn.Module):
             output_top_logprobs=output_top_logprobs,
         )

+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if hasattr(lm_head, "weight"):
+            logits = torch.matmul(hidden_states, lm_head.weight.T)
+        else:
+            # GGUF models
+            logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
+        return logits
+

 def test():
     all_logprobs = torch.tensor(
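For orientation, here is a minimal, self-contained sketch of the dispatch the new _get_logits helper performs (DenseLMHead and get_logits below are made-up stand-ins, not sglang types): a head that exposes a weight tensor gets a plain matmul against the hidden states, while a quantized head (e.g. GGUF) is expected to provide a linear_method instead.

import torch

class DenseLMHead:
    # Stand-in for a vocab-parallel head with a materialized weight.
    def __init__(self, vocab_size: int, hidden_size: int):
        self.weight = torch.randn(vocab_size, hidden_size)

def get_logits(hidden_states: torch.Tensor, lm_head) -> torch.Tensor:
    if hasattr(lm_head, "weight"):
        # Dense path: project hidden states onto the vocabulary.
        return torch.matmul(hidden_states, lm_head.weight.T)
    # Quantized (e.g. GGUF) path: the head applies its own linear method.
    return lm_head.linear_method.apply(lm_head, hidden_states, None)

hidden = torch.randn(4, 16)  # [num_tokens, hidden_size]
print(get_logits(hidden, DenseLMHead(vocab_size=32, hidden_size=16)).shape)
# torch.Size([4, 32])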
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -117,10 +117,44 @@ def fp8_get_quant_method(self, layer, prefix):
     return None


+def gptq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+
+    if isinstance(layer, LinearBase):
+        return GPTQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return GPTQMarlinMoEMethod(self)
+    return None
+
+
+def awq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinLinearMethod,
+        AWQMoEMethod,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+
+    if isinstance(layer, LinearBase):
+        return AWQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return AWQMoEMethod(self)
+    return None
+
+
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
     setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+    setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)


 # Apply patches when module is imported
sglang/srt/lora/lora.py
CHANGED
@@ -31,7 +31,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.loader import DefaultModelLoader

 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -40,6 +39,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_loader.loader import DefaultModelLoader


 class BaseLayerWithLoRA(nn.Module):
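This pairs with the new vendored loader package in the file list above (sglang/srt/model_loader/loader.py, utils.py, weight_utils.py): DefaultModelLoader is now imported from sglang itself rather than from vllm.model_executor.model_loader.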
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -15,9 +15,11 @@

 import logging
 import multiprocessing as mp
+import signal
 import threading
 from enum import Enum, auto

+import psutil
 import zmq

 from sglang.srt.managers.io_struct import (
@@ -26,13 +28,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    bind_port,
-    configure_logger,
-    get_zmq_socket,
-    kill_parent_process,
-    suppress_other_loggers,
-)
+from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -235,7 +231,7 @@ def run_data_parallel_controller_process(
     pipe_writer,
 ):
     configure_logger(server_args)
-
+    parent_process = psutil.Process().parent()

     try:
         controller = DataParallelController(server_args, port_args)
@@ -244,6 +240,6 @@ def run_data_parallel_controller_process(
         )
         controller.event_loop()
     except Exception:
-
-        logger.error(
-
+        traceback = get_exception_traceback()
+        logger.error(f"DataParallelController hit an exception: {traceback}")
+        parent_process.send_signal(signal.SIGQUIT)
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -15,9 +15,11 @@

 import dataclasses
 import logging
+import signal
 from collections import OrderedDict
 from typing import List, Union

+import psutil
 import zmq

 from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -28,7 +30,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import configure_logger, get_zmq_socket
+from sglang.srt.utils import configure_logger, get_zmq_socket
 from sglang.utils import find_printable_text, get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -193,11 +195,12 @@ def run_detokenizer_process(
     port_args: PortArgs,
 ):
     configure_logger(server_args)
+    parent_process = psutil.Process().parent()

     try:
         manager = DetokenizerManager(server_args, port_args)
         manager.event_loop()
     except Exception:
-
-        logger.error(
-
+        traceback = get_exception_traceback()
+        logger.error(f"DetokenizerManager hit an exception: {traceback}")
+        parent_process.send_signal(signal.SIGQUIT)
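run_detokenizer_process adopts the same psutil/SIGQUIT failure handling shown for the data parallel controller above.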
sglang/srt/managers/image_processor.py
CHANGED
@@ -338,7 +338,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
             "pixel_values": pixel_values,
             "image_hashes": image_hashes,
             "image_sizes": image_sizes,
-            "modalities": request_obj.modalities,
+            "modalities": request_obj.modalities or ["image"],
             "image_grid_thws": image_grid_thws,
         }

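Since `or` tests truthiness, the new fallback applies when request_obj.modalities is None or an empty list, defaulting Qwen2-VL requests to the "image" modality.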
sglang/srt/managers/io_struct.py
CHANGED
@@ -352,7 +352,7 @@ class FlushCacheReq:


 @dataclass
-class UpdateWeightReqInput:
+class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
     model_path: str
     # The format to load the weights
@@ -360,30 +360,66 @@ class UpdateWeightReqInput:


 @dataclass
-class UpdateWeightReqOutput:
+class UpdateWeightFromDiskReqOutput:
     success: bool
     message: str


 @dataclass
-class
-
-
+class UpdateWeightsFromDistributedReqInput:
+    name: str
+    dtype: str
+    shape: List[int]


-
-
-
+@dataclass
+class UpdateWeightsFromDistributedReqOutput:
+    success: bool
+    message: str


 @dataclass
-class
-
+class InitWeightsUpdateGroupReqInput:
+    # The master address
+    master_address: str
+    # The master port
+    master_port: int
+    # The rank offset
+    rank_offset: int
+    # The world size
+    world_size: int
+    # The group name
+    group_name: str = "weight_update_group"
+    # The backend
+    backend: str = "nccl"
+
+
+@dataclass
+class InitWeightsUpdateGroupReqOutput:
+    success: bool
+    message: str


 @dataclass
-class
-
+class GetWeightsByNameReqInput:
+    name: str
+    truncate_size: int = 100
+
+
+@dataclass
+class GetWeightsByNameReqOutput:
+    parameter: list
+
+
+@dataclass
+class AbortReq:
+    # The request id
+    rid: str
+
+
+class ProfileReq(Enum):
+    START_PROFILE = 1
+    STOP_PROFILE = 2


 @dataclass
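A hypothetical usage sketch of the new weight-sync request types; every field value below is illustrative, not a default taken from the source:

from sglang.srt.managers.io_struct import (
    InitWeightsUpdateGroupReqInput,
    UpdateWeightsFromDistributedReqInput,
)

# 1. Describe the process group that will carry the weight broadcasts,
#    e.g. 8 trainer ranks plus one inference rank appended at the end.
init_req = InitWeightsUpdateGroupReqInput(
    master_address="10.0.0.1",
    master_port=29500,
    rank_offset=8,
    world_size=9,
    group_name="weight_update_group",
    backend="nccl",
)

# 2. Announce one tensor to receive over that group.
update_req = UpdateWeightsFromDistributedReqInput(
    name="model.layers.0.self_attn.q_proj.weight",
    dtype="bfloat16",
    shape=[4096, 4096],
)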
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -124,7 +124,7 @@ class FINISH_ABORT(BaseFinishReason):
 class ImageInputs:
     """The image related inputs."""

-    pixel_values: torch.Tensor
+    pixel_values: Union[torch.Tensor, np.array]
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
@@ -132,7 +132,7 @@ class ImageInputs:
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None

-
+    # Llava related
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None

@@ -141,19 +141,17 @@ class ImageInputs:
     mrope_position_delta: Optional[torch.Tensor] = None

     @staticmethod
-    def from_dict(obj
-        # Use image hash as fake token_ids, which is then used for prefix matching
+    def from_dict(obj: dict):
         ret = ImageInputs(
             pixel_values=obj["pixel_values"],
-            image_hashes=
+            image_hashes=obj["image_hashes"],
         )
-
-
-
-
-
-
-        ]
+
+        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
+        # Please note that if the `input_ids` is later used in the model forward,
+        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
+        # errors in cuda kernels. See also llava.py for example.
+        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]

         optional_args = [
             "image_sizes",
@@ -168,17 +166,16 @@ class ImageInputs:

         return ret

-    def merge(self, other
+    def merge(self, other):
         assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
         self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
-        self.image_hashes += other.image_hashes

-
-
-
-
-        ]
+        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
+        # Please note that if the `input_ids` is later used in the model forward,
+        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
+        # errors in cuda kernels. See also llava.py for example.
+        self.image_hashes += other.image_hashes
+        self.pad_values = [x % (1 << 30) for x in self.image_hashes]

         optional_args = [
             "image_sizes",
@@ -231,6 +228,7 @@ class Req:
         self.tokenizer = None
         self.finished_reason = None
         self.stream = False
+        self.to_abort = False

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -290,11 +288,11 @@ class Req:
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0

-    def extend_image_inputs(self, image_inputs
+    def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
             self.image_inputs = image_inputs
         else:
-            self.image_inputs.merge(image_inputs
+            self.image_inputs.merge(image_inputs)

     # whether request reached finished condition
     def finished(self) -> bool:
@@ -368,6 +366,10 @@ class Req:
         if self.finished():
             return

+        if self.to_abort:
+            self.finished_reason = FINISH_ABORT()
+            return
+
         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
             self.finished_reason = FINISH_LENGTH(
                 length=self.sampling_params.max_new_tokens
@@ -741,20 +743,24 @@ class ScheduleBatch:
         extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if global_server_args_dict["attention_backend"] != "torch_native":
+            write_req_to_token_pool_triton[(bs,)](
+                self.req_to_token_pool.req_to_token,
+                self.req_pool_indices,
+                pre_lens,
+                self.seq_lens,
+                extend_lens,
+                self.out_cache_loc,
+                self.req_to_token_pool.req_to_token.shape[1],
+            )
+        else:
+            pt = 0
+            for i in range(bs):
+                self.req_to_token_pool.write(
+                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
+                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
+                )
+                pt += self.extend_lens[i]
         # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

         if self.model_config.is_encoder_decoder:
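A small sketch of the padding scheme the new comments describe, using made-up hash values: each image hash is folded into [0, 2**30) so it can serve as a fake token id for radix-cache prefix matching, and must still be clamped to [0, vocab_size) before any real embedding lookup (the comments point to llava.py for this):

image_hashes = [0xDEADBEEFCAFEBABE, 0x1234567890ABCDEF]
pad_values = [h % (1 << 30) for h in image_hashes]  # deterministic per image, < 2**30

vocab_size = 32000
input_ids = [min(v, vocab_size - 1) for v in pad_values]  # clamp before embedding
print(pad_values, input_ids)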
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -142,7 +142,7 @@ class PrefillAdder:

         self.req_states = None
         self.can_run_list = []
-        self.
+        self.new_being_chunked_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0

@@ -182,7 +182,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len

-    def
+    def add_being_chunked_req(self, req: Req):
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -269,10 +269,13 @@ class PrefillAdder:
         else:
             # Chunked prefill
             trunc_len = self.rem_chunk_tokens
+            if trunc_len == 0:
+                return AddReqResult.OTHER
+
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[:trunc_len]
             self.can_run_list.append(req)
-            self.
+            self.new_being_chunked_req = req
             self._prefill_one_req(0, trunc_len, 0)

         return self.budget_state()
@@ -326,7 +329,7 @@ class PrefillAdder:
         req.extend_input_len = trunc_len
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
         self.can_run_list.append(req)
-        self.
+        self.new_being_chunked_req = req
         self.tree_cache.inc_lock_ref(req.last_node)
         self._prefill_one_req(prefix_len, trunc_len, 0)
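A toy illustration of why the new trunc_len == 0 guard matters (hypothetical helper, heavily simplified from PrefillAdder): once the chunk budget is exhausted, scheduling a zero-length extend would make no progress, so the adder now answers AddReqResult.OTHER and leaves the request for a later round.

def plan_chunk(extend_input_len: int, rem_chunk_tokens: int):
    trunc_len = rem_chunk_tokens
    if trunc_len == 0:
        return None  # maps to AddReqResult.OTHER in the real code
    return min(extend_input_len, trunc_len)

print(plan_chunk(4096, 2048))  # 2048: schedule the first chunk of a long prefill
print(plan_chunk(4096, 0))     # None: budget exhausted, try again next round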