sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +472 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +125 -6
- sglang/check_env.py +3 -6
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +28 -17
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +47 -58
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +16 -13
- sglang/srt/layers/attention/flashinfer_backend.py +106 -54
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +25 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +17 -15
- sglang/srt/layers/logits_processor.py +23 -25
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +98 -27
- sglang/srt/managers/detokenizer_manager.py +13 -15
- sglang/srt/managers/io_struct.py +63 -21
- sglang/srt/managers/schedule_batch.py +154 -59
- sglang/srt/managers/schedule_policy.py +18 -16
- sglang/srt/managers/scheduler.py +278 -109
- sglang/srt/managers/session_controller.py +61 -0
- sglang/srt/managers/tokenizer_manager.py +63 -18
- sglang/srt/managers/tp_worker.py +25 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +63 -25
- sglang/srt/model_executor/forward_batch_info.py +128 -32
- sglang/srt/model_executor/model_runner.py +132 -64
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +162 -59
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +31 -25
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +14 -16
- sglang/srt/models/llavavid.py +14 -16
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +22 -20
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +107 -93
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +19 -17
- sglang/srt/openai_api/protocol.py +14 -16
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +61 -57
- sglang/srt/sampling/sampling_params.py +14 -16
- sglang/srt/server.py +86 -35
- sglang/srt/server_args.py +96 -80
- sglang/srt/utils.py +266 -68
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +38 -20
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +31 -20
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.5.post2.dist-info/RECORD +0 -156
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -1,18 +1,16 @@
- (15 lines removed: the old docstring-form license header)
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """DetokenizerManager is a process that detokenizes the token ids."""

 import dataclasses
sglang/srt/managers/io_struct.py
CHANGED
@@ -1,18 +1,16 @@
- (15 lines removed: the old docstring-form license header)
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
@@ -21,7 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 import uuid
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -31,8 +29,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The token ids for text; one can either
+    # The token ids for text; one can specify either text or input_ids
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
+    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
@@ -56,11 +56,22 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

+    # Session id info for continual prompting
+    session: Optional[
+        Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
+    ] = None
+
     def normalize_batch_and_arguments(self):
-        if (
-            self.text is
+        if (
+            self.text is None and self.input_ids is None and self.input_embeds is None
+        ) or (
+            self.text is not None
+            and self.input_ids is not None
+            and self.input_embeds is not None
         ):
-            raise ValueError(
+            raise ValueError(
+                "Either text, input_ids or input_embeds should be provided."
+            )

         # Derive the batch size
         if self.text is not None:
@@ -70,13 +81,21 @@ class GenerateReqInput:
             else:
                 self.is_single = False
                 self.batch_size = len(self.text)
-
+            self.input_embeds = None
+        elif self.input_ids is not None:
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
                 self.batch_size = 1
             else:
                 self.is_single = False
                 self.batch_size = len(self.input_ids)
+            self.input_embeds = None
+        else:
+            if isinstance(self.input_embeds[0][0], float):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.batch_size = len(self.input_embeds)

         # Handle parallel sampling
         # When parallel sampling is used, we always treat the input as a batch.
@@ -199,6 +218,12 @@ class TokenizedGenerateReqInput:

     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model
+    # The input embeds
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
+
+    # Session id info for continual prompting
+    session_id: Optional[str] = None
+    session_rid: Optional[str] = None


 @dataclass
@@ -211,6 +236,8 @@ class EmbeddingReqInput:
     rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
+    # Dummy input embeds for compatibility
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
@@ -357,3 +384,18 @@ class GetMemPoolSizeReq:
 @dataclass
 class GetMemPoolSizeReqOutput:
     size: int
+
+
+@dataclass
+class OpenSessionReqInput:
+    capacity_of_str_len: int
+
+
+@dataclass
+class CloseSessionReqInput:
+    session_id: str
+
+
+@dataclass
+class OpenSessionReqOutput:
+    session_id: str
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -1,18 +1,16 @@
- (15 lines removed: the old docstring-form license header)
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Store information about requests and batches.

@@ -34,6 +32,8 @@ import logging
 from typing import List, Optional, Tuple, Union

 import torch
+import triton
+import triton.language as tl

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
@@ -55,7 +55,8 @@ global_server_args_dict = {
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
     "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
-    "
+    "enable_nan_detection": ServerArgs.enable_nan_detection,
+    "enable_dp_attention": ServerArgs.enable_dp_attention,
 }


@@ -133,6 +134,7 @@ class ImageInputs:
     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
     # QWen2-VL related
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
@@ -176,6 +178,8 @@ class Req:
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
+        input_embeds: Optional[List[List[float]]] = None,
+        session_id: Optional[str] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -184,11 +188,13 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        self.session_id = session_id

         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.input_embeds = input_embeds

-        # Memory info
+        # Memory pool info
         self.req_pool_idx = None

         # Check finish
@@ -425,7 +431,7 @@ bid = 0

 @dataclasses.dataclass
 class ScheduleBatch:
-    """Store all inforamtion of a batch."""
+    """Store all inforamtion of a batch on the scheduler."""

     # Request, memory pool, and cache
     reqs: List[Req]
@@ -433,14 +439,18 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     tree_cache: BasePrefixCache = None

-    #
+    # Batch configs
     model_config: ModelConfig = None
-
     forward_mode: ForwardMode = None
+    enable_overlap: bool = False
+
+    # Sampling info
     sampling_info: SamplingBatchInfo = None
+    next_batch_sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
     input_ids: torch.Tensor = None
+    input_embeds: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     # The output locations of the KV cache
@@ -450,6 +460,10 @@ class ScheduleBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int = None

+    # For DP attention
+    global_num_tokens: Optional[List[int]] = None
+    can_run_dp_cuda_graph: bool = False
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -459,6 +473,7 @@ class ScheduleBatch:
     extend_lens: List[int] = None
     extend_num_tokens: int = None
     decoding_reqs: List[Req] = None
+    extend_logprob_start_lens: List[int] = None

     # For encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -479,10 +494,11 @@ class ScheduleBatch:
     def init_new(
         cls,
         reqs: List[Req],
-        req_to_token_pool,
-        token_to_kv_pool,
-        tree_cache,
-        model_config,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: ReqToTokenPool,
+        tree_cache: BasePrefixCache,
+        model_config: ModelConfig,
+        enable_overlap: bool,
     ):
         return cls(
             reqs=reqs,
@@ -490,6 +506,7 @@ class ScheduleBatch:
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
             model_config=model_config,
+            enable_overlap=enable_overlap,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
@@ -502,7 +519,7 @@ class ScheduleBatch:
     def is_empty(self):
         return len(self.reqs) == 0

-    def alloc_req_slots(self, num_reqs):
+    def alloc_req_slots(self, num_reqs: int):
         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
@@ -588,14 +605,14 @@ class ScheduleBatch:
         )

         if not decoder_out_cache_loc:
-            self.out_cache_loc = torch.
+            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                 self.device, non_blocking=True
             )
         else:
             self.out_cache_loc = torch.cat(decoder_out_cache_loc)

         if not encoder_out_cache_loc:
-            self.encoder_out_cache_loc = torch.
+            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                 self.device, non_blocking=True
             )
         else:
@@ -611,11 +628,14 @@ class ScheduleBatch:
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []
+        pre_lens = []

         # Allocate memory
         req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)

+        input_embeds = []
+
         pt = 0
         for i, req in enumerate(reqs):
             already_computed = (
@@ -634,10 +654,11 @@ class ScheduleBatch:
             self.req_to_token_pool.write(
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
-
-
-
-
+
+            # If input_embeds are available, store them
+            if req.input_embeds is not None:
+                # If req.input_embeds is already a list, append its content directly
+                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
@@ -648,8 +669,8 @@ class ScheduleBatch:
                 extend_logprob_start_len = req.extend_input_len - 1

             req.extend_logprob_start_len = extend_logprob_start_len
-            pt += req.extend_input_len
             req.is_retracted = False
+            pre_lens.append(pre_len)

         # Set fields
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -661,6 +682,11 @@ class ScheduleBatch:
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
+        self.input_embeds = (
+            torch.tensor(input_embeds).to(self.device, non_blocking=True)
+            if input_embeds
+            else None
+        )

         self.out_cache_loc = out_cache_loc

@@ -672,13 +698,37 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

+        # Write to req_to_token_pool
+        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        write_req_to_token_pool_triton[(bs,)](
+            self.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            pre_lens,
+            self.seq_lens,
+            extend_lens,
+            self.out_cache_loc,
+            self.req_to_token_pool.req_to_token.shape[1],
+        )
+        # The triton kernel is equivalent to the following python code.
+        # self.req_to_token_pool.write(
+        #     (req.req_pool_idx, slice(pre_len, seq_len)),
+        #     out_cache_loc[pt : pt + req.extend_input_len],
+        # )
+        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)

+        # Build sampling info
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-
+            enable_overlap_schedule=self.enable_overlap,
         )

     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -695,16 +745,20 @@ class ScheduleBatch:
         self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
-
+
+        # For overlap scheduler, the output_ids has one step delay
+        delta = 0 if self.enable_overlap else -1

         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
         self.prefix_lens.extend(
             [
-                len(r.origin_input_ids) + len(r.output_ids)
+                len(r.origin_input_ids) + len(r.output_ids) + delta
                 for r in running_batch.reqs
             ]
         )
         self.extend_lens.extend([1] * running_bs)
+        self.extend_num_tokens += running_bs
+        # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)

     def check_decode_mem(self):
@@ -720,6 +774,7 @@ class ScheduleBatch:
             return False

     def retract_decode(self):
+        """Retract the decoding requests when there is not enough memory."""
         sorted_indices = [i for i in range(len(self.reqs))]

         # TODO(lsyin): improve retraction policy for radix cache
@@ -858,15 +913,21 @@ class ScheduleBatch:
         # Reset the encoder cached status
         self.encoder_cached = [True] * len(self.reqs)

-    def
+    def prepare_for_idle(self):
+        self.forward_mode = ForwardMode.IDLE
+        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens_sum = 0
+        self.extend_num_tokens = 0
+
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE

         self.input_ids = self.output_ids
         self.output_ids = None
-
-        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-            self.input_ids
-        )
+        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

         # Alloc mem
         bs = len(self.reqs)
@@ -878,7 +939,7 @@ class ScheduleBatch:
         else:
             locs = self.seq_lens

-        if enable_overlap:
+        if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.req_to_token_pool.write(
                 (self.req_pool_indices, locs), self.out_cache_loc
@@ -969,17 +1030,18 @@ class ScheduleBatch:
         self.has_grammar = self.has_grammar or other.has_grammar

     def get_model_worker_batch(self):
-        if self.forward_mode.is_decode():
+        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens

-        if self.
-
-
-
+        if self.sampling_info:
+            if self.has_grammar:
+                self.sampling_info.grammars = [req.grammar for req in self.reqs]
+            else:
+                self.sampling_info.grammars = None

         global bid
         bid += 1
@@ -995,6 +1057,8 @@ class ScheduleBatch:
             req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
+            global_num_tokens=self.global_num_tokens,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1006,6 +1070,7 @@ class ScheduleBatch:
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
+            input_embeds=self.input_embeds,
         )

     def copy(self):
@@ -1051,6 +1116,10 @@ class ModelWorkerBatch:
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]

+    # For DP attention
+    global_num_tokens: Optional[List[int]]
+    can_run_dp_cuda_graph: bool
+
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
@@ -1072,16 +1141,42 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo

- (13 lines removed)
+    # The input Embeds
+    input_embeds: Optional[torch.tensor] = None
+
+
+@triton.jit
+def write_req_to_token_pool_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices,
+    pre_lens,
+    seq_lens,
+    extend_lens,
+    out_cache_loc,
+    req_to_token_ptr_stride: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(0)
+
+    req_pool_index = tl.load(req_pool_indices + pid)
+    pre_len = tl.load(pre_lens + pid)
+    seq_len = tl.load(seq_lens + pid)
+
+    # TODO: optimize this?
+    cumsum_start = 0
+    for i in range(pid):
+        cumsum_start += tl.load(extend_lens + i)
+
+    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < (seq_len - pre_len)
+        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
+        tl.store(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + offset
+            + pre_len,
+            value,
+            mask=mask,
+        )
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -1,18 +1,16 @@
- (15 lines removed: the old docstring-form license header)
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Request scheduler policy"""

 import os
@@ -302,7 +300,11 @@ class PrefillAdder:
         if (
             self.rem_chunk_tokens is None
             or input_tokens <= self.rem_chunk_tokens
-            or (
+            or (
+                req.return_logprob
+                and req.normalized_prompt_logprob is None
+                and req.logprob_start_len != len(req.origin_input_ids) - 1
+            )
         ):
             # Non-chunked prefill
             self.can_run_list.append(req)