sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +337 -0
- sglang/bench_one_batch.py +474 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +115 -31
- sglang/check_env.py +3 -6
- sglang/srt/constrained/base_grammar_backend.py +4 -3
- sglang/srt/constrained/outlines_backend.py +39 -26
- sglang/srt/constrained/xgrammar_backend.py +58 -14
- sglang/srt/layers/activation.py +3 -0
- sglang/srt/layers/attention/flashinfer_backend.py +93 -48
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/custom_op_util.py +26 -0
- sglang/srt/layers/fused_moe/fused_moe.py +11 -4
- sglang/srt/layers/fused_moe/patch.py +4 -2
- sglang/srt/layers/layernorm.py +4 -0
- sglang/srt/layers/logits_processor.py +10 -10
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +74 -9
- sglang/srt/managers/detokenizer_manager.py +1 -14
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/schedule_batch.py +104 -38
- sglang/srt/managers/schedule_policy.py +5 -1
- sglang/srt/managers/scheduler.py +210 -56
- sglang/srt/managers/session_controller.py +62 -0
- sglang/srt/managers/tokenizer_manager.py +38 -0
- sglang/srt/managers/tp_worker.py +12 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
- sglang/srt/model_executor/cuda_graph_runner.py +43 -6
- sglang/srt/model_executor/forward_batch_info.py +109 -15
- sglang/srt/model_executor/model_runner.py +102 -43
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/deepseek_v2.py +147 -44
- sglang/srt/models/gemma2.py +9 -8
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/torch_native_llama.py +94 -78
- sglang/srt/openai_api/adapter.py +11 -4
- sglang/srt/openai_api/protocol.py +30 -27
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +58 -57
- sglang/srt/sampling/sampling_params.py +3 -3
- sglang/srt/server.py +29 -2
- sglang/srt/server_args.py +97 -60
- sglang/srt/utils.py +103 -51
- sglang/test/runners.py +25 -6
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +33 -22
- sglang/version.py +1 -1
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED

```diff
@@ -23,6 +23,7 @@ import os
 import signal
 import sys
 import time
+import uuid
 from typing import Dict, List, Optional, Tuple, Union
 
 import fastapi
@@ -42,11 +43,14 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    CloseSessionReqInput,
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
+    OpenSessionReqInput,
+    OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -146,6 +150,9 @@ class TokenizerManager:
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None
 
+        # For session info
+        self.session_futures = {}  # session_id -> asyncio event
+
         # Others
         self.gracefully_exit = False
 
@@ -211,6 +218,8 @@ class TokenizerManager:
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
+           session_id = obj.session_id
+           session_rid = obj.session_rid
 
            if len(input_ids) >= self.context_len:
                raise ValueError(
@@ -236,6 +245,8 @@ class TokenizerManager:
                top_logprobs_num,
                obj.stream,
                obj.lora_path,
+               session_id=session_id,
+               session_rid=session_rid,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
@@ -451,6 +462,26 @@ class TokenizerManager:
         else:
             return False, "Another update is in progress. Please try again later."
 
+    async def open_session(
+        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        session_id = uuid.uuid4().hex
+        obj.session_id = session_id
+        self.send_to_scheduler.send_pyobj(obj)
+        self.session_futures[session_id] = asyncio.Future()
+        session_id = await self.session_futures[session_id]
+        del self.session_futures[session_id]
+        return session_id
+
+    async def close_session(
+        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
+    ):
+        assert not self.to_create_loop, "close session should not be the first request"
+        await self.send_to_scheduler.send_pyobj(obj)
+
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
@@ -521,6 +552,11 @@ class TokenizerManager:
                     if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
                         self.mem_pool_size.set_result(self.mem_pool_size_tmp)
                 continue
+            elif isinstance(recv_obj, OpenSessionReqOutput):
+                self.session_futures[recv_obj.session_id].set_result(
+                    recv_obj.session_id
+                )
+                continue
 
             assert isinstance(
                 recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
@@ -536,11 +572,13 @@ class TokenizerManager:
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": recv_obj.meta_info[i],
+                       "session_id": recv_obj.session_ids[i],
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
                    out_dict = {
                        "token_ids": recv_obj.output_ids[i],
                        "meta_info": recv_obj.meta_info[i],
+                       "session_id": recv_obj.session_ids[i],
                    }
                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
```
sglang/srt/managers/tp_worker.py
CHANGED
```diff
@@ -16,6 +16,7 @@ limitations under the License.
 """A tensor parallel worker."""
 
 import logging
+import threading
 from typing import Optional
 
 from sglang.srt.configs.model_config import ModelConfig
@@ -134,9 +135,19 @@ class TpModelWorker:
             self.model_runner.token_to_kv_pool,
         )
 
-    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+    def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        self.model_runner.forward(forward_batch)
+
+    def forward_batch_generation(
+        self,
+        model_worker_batch: ModelWorkerBatch,
+        launch_done: Optional[threading.Event] = None,
+    ):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
+        if launch_done:
+            launch_done.set()
         next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
         return logits_output, next_token_ids
 
```
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED

```diff
@@ -15,9 +15,9 @@ limitations under the License.
 
 """A tensor parallel worker."""
 
+import dataclasses
 import logging
 import threading
-import time
 from queue import Queue
 from typing import Optional
 
@@ -26,7 +26,6 @@ import torch
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
 
 logger = logging.getLogger(__name__)
@@ -56,6 +55,7 @@ class TpModelWorkerClient:
         self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
+        self.gpu_id = gpu_id
 
         # Init future mappings
         self.future_token_ids_ct = 0
@@ -73,12 +73,6 @@ class TpModelWorkerClient:
         )
         self.forward_thread.start()
 
-        self.copy_queue = Queue()
-        self.copy_thread = threading.Thread(
-            target=self.copy_thread_func,
-        )
-        self.copy_thread.start()
-
     def get_worker_info(self):
         return self.worker.get_worker_info()
 
@@ -98,15 +92,25 @@ class TpModelWorkerClient:
         with torch.cuda.stream(self.forward_stream):
             self.forward_thread_func_()
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def forward_thread_func_(self):
+        batch_pt = 0
+        batch_lists = [None] * 2
+
         while True:
-            self.has_inflight_batch = False
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
             if not model_worker_batch:
                 break
-            self.has_inflight_batch = True
-            self.launch_event = threading.Event()
+
+            # Keep a reference of model_worker_batch by storing it into a list.
+            # Otherwise, the tensor members of model_worker_batch will be released
+            # by pytorch and cause CUDA illegal memory access errors.
+            batch_lists[batch_pt % 2] = model_worker_batch
+            batch_pt += 1
+
+            # Create event
+            self.launch_done = threading.Event()
+            copy_done = torch.cuda.Event()
 
             # Resolve future tokens in the input
             input_ids = model_worker_batch.input_ids
@@ -114,7 +118,7 @@ class TpModelWorkerClient:
 
             # Run forward
             logits_output, next_token_ids = self.worker.forward_batch_generation(
-                model_worker_batch
+                model_worker_batch, self.launch_done
             )
 
             # Update the future token ids map
@@ -139,44 +143,45 @@ class TpModelWorkerClient:
                        )
                    )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
-            copy_event = torch.cuda.Event(blocking=True)
-            copy_event.record()
+            copy_done.record()
 
-            self.launch_event.set()
-            self.copy_queue.put((copy_event, logits_output, next_token_ids))
+            self.output_queue.put((copy_done, logits_output, next_token_ids))
 
-    def copy_thread_func(self):
-        while True:
-            copy_event, logits_output, next_token_ids = self.copy_queue.get()
-            if not copy_event:
-                break
-            while not copy_event.query():
-                time.sleep(1e-5)
+    def resolve_batch_result(self, bid: int):
+        copy_done, logits_output, next_token_ids = self.output_queue.get()
+        copy_done.synchronize()
+        self.launch_done.wait()
 
-            if logits_output.next_token_logprobs is not None:
-                logits_output.next_token_logprobs = (
-                    logits_output.next_token_logprobs.tolist()
+        if logits_output.next_token_logprobs is not None:
+            logits_output.next_token_logprobs = (
+                logits_output.next_token_logprobs.tolist()
+            )
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = (
+                    logits_output.input_token_logprobs.tolist()
                 )
-                if logits_output.input_token_logprobs is not None:
-                    logits_output.input_token_logprobs = (
-                        logits_output.input_token_logprobs.tolist()
-                    )
-                    logits_output.normalized_prompt_logprobs = (
-                        logits_output.normalized_prompt_logprobs.tolist()
-                    )
-
-            self.output_queue.put((logits_output, next_token_ids.tolist()))
-
-    def resulve_batch_result(self, bid: int):
-        logits_output, next_token_ids = self.output_queue.get()
-        if self.has_inflight_batch:
-            # Wait until the batch is launched
-            self.launch_event.wait()
+                logits_output.normalized_prompt_logprobs = (
+                    logits_output.normalized_prompt_logprobs.tolist()
+                )
+        next_token_ids = next_token_ids.tolist()
         return logits_output, next_token_ids
 
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+        # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
+        sampling_info = model_worker_batch.sampling_info
+        sampling_info.update_penalties()
+        model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
+            sampling_info,
+            sampling_info_done=threading.Event(),
+            scaling_penalties=sampling_info.scaling_penalties,
+            linear_penalties=sampling_info.linear_penalties,
+        )
+
+        # A cuda stream sync here to avoid the cuda illegal memory access error.
+        torch.cuda.current_stream().synchronize()
+
         # Push a new batch to the queue
-        self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
+        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
 
         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
@@ -192,16 +197,8 @@ class TpModelWorkerClient:
             ) % self.future_token_ids_limit
         return None, future_next_token_ids
 
-    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        embeddings = logits_output.embeddings
-        return embeddings
-
     def update_weights(self, recv_req: UpdateWeightReqInput):
-        success, message = self.model_runner.update_weights(
-            recv_req.model_path, recv_req.load_format
-        )
+        success, message = self.worker.update_weights(recv_req)
         return success, message
 
     def __delete__(self):
```
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED

```diff
@@ -90,6 +90,8 @@ def set_torch_compile_config():
 
     # FIXME: tmp workaround
     torch._dynamo.config.accumulated_cache_size_limit = 1024
+    if hasattr(torch._dynamo.config, "cache_size_limit"):
+        torch._dynamo.config.cache_size_limit = 1024
 
 
 @maybe_torch_compile(dynamic=True)
@@ -111,6 +113,8 @@ class CudaGraphRunner:
         self.use_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
+        self.tp_size = self.model_runner.tp_size
 
         # Batch sizes to capture
         if model_runner.server_args.disable_cuda_graph_padding:
@@ -165,6 +169,15 @@ class CudaGraphRunner:
         else:
             self.encoder_lens = None
 
+        if self.enable_dp_attention:
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_bs * self.tp_size,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+
         # Capture
         try:
             with self.model_capture_mode():
@@ -190,11 +203,21 @@ class CudaGraphRunner:
         self.model_runner.model.capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
-        is_bs_supported = (
-            forward_batch.batch_size in self.graphs
-            if self.disable_padding
-            else forward_batch.batch_size <= self.max_bs
-        )
+        if self.enable_dp_attention:
+            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
+                forward_batch.global_num_tokens
+            )
+            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
+                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+                if self.disable_padding
+                else max_num_tokens <= self.max_bs
+            )
+        else:
+            is_bs_supported = (
+                forward_batch.batch_size in self.graphs
+                if self.disable_padding
+                else forward_batch.batch_size <= self.max_bs
+            )
 
         # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
         # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
@@ -239,6 +262,13 @@ class CudaGraphRunner:
         seq_lens_sum = seq_lens.sum().item()
         mrope_positions = self.mrope_positions[:, :bs]
 
+        if self.enable_dp_attention:
+            global_num_tokens = [bs] * self.tp_size
+            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+        else:
+            global_num_tokens = None
+            gathered_buffer = None
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
@@ -265,6 +295,8 @@ class CudaGraphRunner:
             top_logprobs_nums=[0] * bs,
             positions=clamp_position(seq_lens),
             mrope_positions=mrope_positions,
+            global_num_tokens=global_num_tokens,
+            gathered_buffer=gathered_buffer,
         )
         logits_output = forward(input_ids, forward_batch.positions, forward_batch)
         return logits_output.next_token_logits
@@ -295,7 +327,12 @@ class CudaGraphRunner:
         raw_bs = forward_batch.batch_size
 
         # Pad
-        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        if self.enable_dp_attention:
+            index = bisect.bisect_left(
+                self.capture_bs, max(forward_batch.global_num_tokens)
+            )
+        else:
+            index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
             self.seq_lens.fill_(1)
```
sglang/srt/model_executor/forward_batch_info.py
CHANGED

```diff
@@ -36,6 +36,8 @@ from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 
@@ -50,12 +52,18 @@ if TYPE_CHECKING:
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
-    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
+    # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
-    # Contains both EXTEND and DECODE.
+    # Contains both EXTEND and DECODE when doing chunked prefill.
     MIXED = auto()
+    # No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated.
+    IDLE = auto()
+
+    # A dummy first batch to start the pipeline for overlap scheduler.
+    # It is now used for triggering the sampling_info_done event for the first prefill batch.
+    DUMMY_FIRST = auto()
 
     def is_prefill(self):
         return self == ForwardMode.PREFILL
@@ -69,6 +77,12 @@ class ForwardMode(IntEnum):
     def is_mixed(self):
         return self == ForwardMode.MIXED
 
+    def is_idle(self):
+        return self == ForwardMode.IDLE
+
+    def is_dummy_first(self):
+        return self == ForwardMode.DUMMY_FIRST
+
 
 @dataclass
 class ForwardBatch:
@@ -102,6 +116,7 @@ class ForwardBatch:
     extend_seq_lens: Optional[torch.Tensor] = None
     extend_prefix_lens: Optional[torch.Tensor] = None
     extend_start_loc: Optional[torch.Tensor] = None
+    extend_prefix_lens_cpu: Optional[List[int]] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
 
@@ -128,6 +143,11 @@ class ForwardBatch:
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
 
+    # For DP attention
+    global_num_tokens: Optional[List[int]] = None
+    gathered_buffer: Optional[torch.Tensor] = None
+    can_run_dp_cuda_graph: bool = False
+
     def compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
@@ -209,31 +229,36 @@ class ForwardBatch:
             seq_lens_sum=batch.seq_lens_sum,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
+            global_num_tokens=batch.global_num_tokens,
+            can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
         )
 
+        if ret.global_num_tokens is not None:
+            max_len = max(ret.global_num_tokens)
+            ret.gathered_buffer = torch.zeros(
+                (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
+                dtype=model_runner.dtype,
+                device=device,
+            )
+
+        if ret.forward_mode.is_idle():
+            return ret
+
         # Init position information
         if not ret.forward_mode.is_decode():
-            ret.positions = torch.concat(
-                [
-                    torch.arange(prefix_len, prefix_len + extend_len, device=device)
-                    for prefix_len, extend_len in zip(
-                        batch.extend_prefix_lens, batch.extend_seq_lens
-                    )
-                ],
-                axis=0,
-            )
-            ret.extend_num_tokens = batch.extend_num_tokens
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
            ).to(device, non_blocking=True)
-
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
-            ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+            ret.extend_num_tokens = batch.extend_num_tokens
+            ret.positions, ret.extend_start_loc = compute_position_triton(
+                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+            )
+            ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
 
@@ -250,3 +275,72 @@ class ForwardBatch:
             model_runner.lora_manager.prepare_lora_batch(ret)
 
         return ret
+
+
+def compute_position_triton(
+    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
+):
+    """Compute positions. It is a fused version of `compute_position_torch`."""
+    batch_size = extend_seq_lens.shape[0]
+    positions = torch.empty(
+        extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
+    )
+    extend_start_loc = torch.empty(
+        batch_size, dtype=torch.int32, device=extend_seq_lens.device
+    )
+
+    # Launch kernel
+    compute_position_kernel[(batch_size,)](
+        positions,
+        extend_start_loc,
+        extend_prefix_lens,
+        extend_seq_lens,
+    )
+
+    return positions, extend_start_loc
+
+
+@triton.jit
+def compute_position_kernel(
+    positions,
+    extend_start_loc,
+    extend_prefix_lens,
+    extend_seq_lens,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(0)
+
+    prefix_len = tl.load(extend_prefix_lens + pid)
+    seq_len = tl.load(extend_seq_lens + pid)
+
+    # TODO: optimize this?
+    cumsum_start = 0
+    for i in range(pid):
+        cumsum_start += tl.load(extend_seq_lens + i)
+
+    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        tl.store(
+            positions + cumsum_start + offset,
+            prefix_len + offset,
+            mask=offset < seq_len,
+        )
+    tl.store(extend_start_loc + pid, cumsum_start)
+
+
+def compute_position_torch(
+    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
+):
+    positions = torch.concat(
+        [
+            torch.arange(
+                prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
+            )
+            for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
+        ],
+        axis=0,
+    )
+    extend_start_loc = torch.zeros_like(extend_seq_lens)
+    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+    return positions.to(torch.int64), extend_start_loc
```