sglang 0.3.5.post2__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +474 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +71 -1
- sglang/check_env.py +3 -6
- sglang/srt/constrained/outlines_backend.py +15 -2
- sglang/srt/constrained/xgrammar_backend.py +22 -14
- sglang/srt/layers/activation.py +3 -0
- sglang/srt/layers/attention/flashinfer_backend.py +93 -48
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/custom_op_util.py +26 -0
- sglang/srt/layers/fused_moe/fused_moe.py +11 -4
- sglang/srt/layers/layernorm.py +4 -0
- sglang/srt/layers/logits_processor.py +10 -10
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +74 -9
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/schedule_batch.py +104 -38
- sglang/srt/managers/schedule_policy.py +5 -1
- sglang/srt/managers/scheduler.py +204 -54
- sglang/srt/managers/session_controller.py +62 -0
- sglang/srt/managers/tokenizer_manager.py +38 -0
- sglang/srt/managers/tp_worker.py +12 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
- sglang/srt/model_executor/cuda_graph_runner.py +43 -6
- sglang/srt/model_executor/forward_batch_info.py +109 -15
- sglang/srt/model_executor/model_runner.py +99 -43
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/deepseek_v2.py +147 -44
- sglang/srt/models/gemma2.py +9 -8
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/torch_native_llama.py +94 -78
- sglang/srt/openai_api/adapter.py +6 -2
- sglang/srt/openai_api/protocol.py +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +58 -57
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +27 -1
- sglang/srt/server_args.py +78 -62
- sglang/srt/utils.py +71 -52
- sglang/test/runners.py +25 -6
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +30 -19
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/RECORD +60 -55
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -17,6 +17,7 @@ limitations under the License.
 
 import logging
 import multiprocessing as mp
+import threading
 from enum import Enum, auto
 
 import zmq
@@ -28,6 +29,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    bind_port,
     configure_logger,
     get_zmq_socket,
     kill_parent_process,
@@ -80,20 +82,62 @@ class DataParallelController:
 
         # Start data parallel workers
         base_gpu_id = 0
-        self.workers = []
+        self.workers = [None] * server_args.dp_size
+
+        threads = []
+        sockets = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
+            tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
 
-
-
-
-
-
+            if server_args.enable_dp_attention:
+                # Data parallelism resues the tensor parallelism group,
+                # so all dp ranks should use the same nccl port.
+                tmp_port_args.nccl_port = port_args.nccl_port
+            else:
+                # This port is checked free in PortArgs.init_new.
+                # We hold it first so that the next dp worker gets a different port
+                sockets.append(bind_port(tmp_port_args.nccl_port))
+
+            # Create a thread for each worker
+            thread = threading.Thread(
+                target=self.launch_worker_func,
+                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
             )
+            threads.append(thread)
+            base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size
+
+        # Free all sockets before starting the threads to launch TP workers
+        for sock in sockets:
+            sock.close()
+
+        # Start all threads
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+
+    def launch_worker_func(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+    ):
+        logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
 
-
-
+        launch_func_ = (
+            self.launch_tensor_parallel_process
+            if server_args.enable_dp_attention
+            else self.launch_tensor_parallel_group
+        )
+        self.workers[dp_rank] = launch_func_(
+            server_args,
+            port_args,
+            base_gpu_id,
+            dp_rank,
+        )
 
     def launch_tensor_parallel_group(
         self,
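Note on the port handling in the hunk above: when DP attention is disabled, each rank's freshly allocated NCCL port is held open via `bind_port` purely so that the next `PortArgs.init_new` call cannot hand out the same port again, and every held socket is released right before the worker threads start. A minimal standalone sketch of that reserve-then-release pattern, using only the stdlib `socket` module and hypothetical port numbers rather than sglang's `bind_port` helper:

```python
import socket


def reserve_port(port: int) -> socket.socket:
    # Bind (but do not listen on) the port so other free-port probes in the
    # same process see it as taken. Mirrors the role bind_port plays in the
    # diff above; this is an illustrative sketch, not sglang code.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(("", port))
    return sock


# Reserve a distinct port per DP rank, then release them all right before the
# workers start and claim the ports for real. Port numbers here are examples.
reserved = [reserve_port(p) for p in (29500, 29501)]
for sock in reserved:
    sock.close()
```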
@@ -112,7 +156,7 @@ class DataParallelController:
         )
         for tp_rank in tp_rank_range:
             reader, writer = mp.Pipe(duplex=False)
-            gpu_id = base_gpu_id + tp_rank % tp_size_per_node
+            gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
             proc = mp.Process(
                 target=run_scheduler_process,
                 args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
@@ -131,6 +175,27 @@ class DataParallelController:
 
         return send_to
 
+    def launch_tensor_parallel_process(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+    ):
+        reader, writer = mp.Pipe(duplex=False)
+        gpu_id = base_gpu_id
+        tp_rank = dp_rank
+        proc = mp.Process(
+            target=run_scheduler_process,
+            args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
+        )
+        proc.start()
+        send_to = get_zmq_socket(
+            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
+        )
+        reader.recv()
+        return send_to
+
     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
         self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
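Taken together, the two launch paths above assign GPUs as follows: with `enable_dp_attention` each DP rank becomes a single TP rank and occupies one GPU, while in the default mode each DP rank launches a full TP group of `tp_size` GPUs (offset by `server_args.base_gpu_id`). A small illustrative helper, not part of sglang, that reproduces that numbering:

```python
from typing import List, Tuple


def gpu_layout(
    dp_size: int, tp_size: int, enable_dp_attention: bool, base_gpu_id: int = 0
) -> List[Tuple[int, List[int]]]:
    # Reproduces the bookkeeping in the diff: base_gpu_id advances by 1 per DP
    # rank with DP attention, otherwise by tp_size. Illustrative sketch only.
    layout, start = [], base_gpu_id
    for dp_rank in range(dp_size):
        step = 1 if enable_dp_attention else tp_size
        layout.append((dp_rank, list(range(start, start + step))))
        start += step
    return layout


print(gpu_layout(dp_size=2, tp_size=2, enable_dp_attention=False))
# [(0, [0, 1]), (1, [2, 3])]
print(gpu_layout(dp_size=2, tp_size=2, enable_dp_attention=True))
# [(0, [0]), (1, [1])]
```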
sglang/srt/managers/io_struct.py
CHANGED
@@ -56,6 +56,10 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
+    # Session id info for continual prompting
+    session_id: Optional[Union[List[str], str]] = None
+    session_rid: Optional[Union[List[str], str]] = None
+
     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -200,6 +204,10 @@ class TokenizedGenerateReqInput:
     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model
 
+    # Session id info for continual prompting
+    session_id: Optional[int] = None
+    session_rid: Optional[str] = None
+
 
 @dataclass
 class EmbeddingReqInput:
@@ -293,6 +301,8 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
     no_stop_trim: List[bool]
+    # The updated session unique id
+    session_ids: List[str]
 
 
 @dataclass
@@ -305,6 +315,8 @@ class BatchStrOut:
     meta_info: List[Dict]
     # The finish reason
     finished_reason: List[BaseFinishReason]
+    # The update session unique id
+    session_ids: List[str]
 
 
 @dataclass
@@ -357,3 +369,18 @@ class GetMemPoolSizeReq:
 @dataclass
 class GetMemPoolSizeReqOutput:
     size: int
+
+
+@dataclass
+class OpenSessionReqInput:
+    capacity_of_str_len: int
+
+
+@dataclass
+class CloseSessionReqInput:
+    session_id: str
+
+
+@dataclass
+class OpenSessionReqOutput:
+    session_id: str
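The new session structs above (`OpenSessionReqInput`, `OpenSessionReqOutput`, `CloseSessionReqInput`) plus the `session_id`/`session_rid` fields on `GenerateReqInput` are the plumbing for continual prompting: open a session, generate inside it, then continue from a previous request. A rough sketch of what a follow-up generate payload could look like; only the `session_id`/`session_rid` field names come from this diff, while the concrete values and the rest of the payload are assumptions for illustration:

```python
# Hypothetical follow-up request that continues an earlier turn of an open
# session. The session_id / session_rid keys mirror the new GenerateReqInput
# fields above; the other fields and all values are illustrative only.
followup_payload = {
    "text": "And how large is its population?",
    "session_id": "sess-1234",   # id returned when the session was opened
    "session_rid": "req-0001",   # rid of the earlier request to continue from
    "sampling_params": {"max_new_tokens": 64},
}
```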
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -34,6 +34,8 @@ import logging
 from typing import List, Optional, Tuple, Union
 
 import torch
+import triton
+import triton.language as tl
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
@@ -55,7 +57,8 @@ global_server_args_dict = {
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
     "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
-    "
+    "enable_nan_detection": ServerArgs.enable_nan_detection,
+    "enable_dp_attention": ServerArgs.enable_dp_attention,
 }
 
 
@@ -133,6 +136,7 @@ class ImageInputs:
     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
     # QWen2-VL related
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
@@ -176,6 +180,7 @@ class Req:
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
+        session_id: Optional[str] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -184,11 +189,12 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        self.session_id = session_id
 
         self.sampling_params = sampling_params
         self.lora_path = lora_path
 
-        # Memory info
+        # Memory pool info
         self.req_pool_idx = None
 
         # Check finish
@@ -425,7 +431,7 @@ bid = 0
 
 @dataclasses.dataclass
 class ScheduleBatch:
-    """Store all inforamtion of a batch."""
+    """Store all inforamtion of a batch on the scheduler."""
 
     # Request, memory pool, and cache
     reqs: List[Req]
@@ -435,9 +441,9 @@ class ScheduleBatch:
 
     # For utility
     model_config: ModelConfig = None
-
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
+    next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
@@ -450,6 +456,10 @@ class ScheduleBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int = None
 
+    # For DP attention
+    global_num_tokens: Optional[List[int]] = None
+    can_run_dp_cuda_graph: bool = False
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -502,7 +512,7 @@ class ScheduleBatch:
     def is_empty(self):
         return len(self.reqs) == 0
 
-    def alloc_req_slots(self, num_reqs):
+    def alloc_req_slots(self, num_reqs: int):
         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
@@ -588,14 +598,14 @@ class ScheduleBatch:
         )
 
         if not decoder_out_cache_loc:
-            self.out_cache_loc = torch.
+            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                 self.device, non_blocking=True
             )
         else:
             self.out_cache_loc = torch.cat(decoder_out_cache_loc)
 
         if not encoder_out_cache_loc:
-            self.encoder_out_cache_loc = torch.
+            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                 self.device, non_blocking=True
             )
         else:
@@ -603,7 +613,7 @@ class ScheduleBatch:
 
         assert len(self.out_cache_loc) == self.extend_num_tokens
 
-    def prepare_for_extend(self):
+    def prepare_for_extend(self, enable_overlap_schedule: bool = False):
         self.forward_mode = ForwardMode.EXTEND
 
         bs = len(self.reqs)
@@ -611,12 +621,12 @@ class ScheduleBatch:
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []
+        pre_lens = []
 
         # Allocate memory
         req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
-        pt = 0
         for i, req in enumerate(reqs):
             already_computed = (
                 req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -634,10 +644,6 @@ class ScheduleBatch:
             self.req_to_token_pool.write(
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
-            self.req_to_token_pool.write(
-                (req.req_pool_idx, slice(pre_len, seq_len)),
-                out_cache_loc[pt : pt + req.extend_input_len],
-            )
 
             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
@@ -648,8 +654,8 @@ class ScheduleBatch:
                 extend_logprob_start_len = req.extend_input_len - 1
 
             req.extend_logprob_start_len = extend_logprob_start_len
-            pt += req.extend_input_len
             req.is_retracted = False
+            pre_lens.append(pre_len)
 
         # Set fields
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -661,7 +667,6 @@ class ScheduleBatch:
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-
         self.out_cache_loc = out_cache_loc
 
         self.seq_lens_sum = sum(seq_lens)
@@ -672,13 +677,37 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
 
+        # Write to req_to_token_pool
+        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        write_req_to_token_pool_triton[(bs,)](
+            self.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            pre_lens,
+            self.seq_lens,
+            extend_lens,
+            self.out_cache_loc,
+            self.req_to_token_pool.req_to_token.shape[1],
+        )
+        # The triton kernel is equivalent to the following python code.
+        # self.req_to_token_pool.write(
+        #     (req.req_pool_idx, slice(pre_len, seq_len)),
+        #     out_cache_loc[pt : pt + req.extend_input_len],
+        # )
+        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)
 
+        # Build sampling info
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-
+            enable_overlap_schedule=enable_overlap_schedule,
         )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -720,6 +749,7 @@ class ScheduleBatch:
         return False
 
     def retract_decode(self):
+        """Retract the decoding requests when there is not enough memory."""
         sorted_indices = [i for i in range(len(self.reqs))]
 
         # TODO(lsyin): improve retraction policy for radix cache
@@ -858,15 +888,21 @@ class ScheduleBatch:
         # Reset the encoder cached status
         self.encoder_cached = [True] * len(self.reqs)
 
+    def prepare_for_idle(self):
+        self.forward_mode = ForwardMode.IDLE
+        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens_sum = 0
+        self.extend_num_tokens = 0
+
     def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE
 
         self.input_ids = self.output_ids
         self.output_ids = None
-
-        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-            self.input_ids
-        )
+        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)
 
         # Alloc mem
         bs = len(self.reqs)
@@ -969,17 +1005,18 @@ class ScheduleBatch:
         self.has_grammar = self.has_grammar or other.has_grammar
 
     def get_model_worker_batch(self):
-        if self.forward_mode.is_decode():
+        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
 
-        if self.
-
-
-
+        if self.sampling_info:
+            if self.has_grammar:
+                self.sampling_info.grammars = [req.grammar for req in self.reqs]
+            else:
+                self.sampling_info.grammars = None
 
         global bid
         bid += 1
@@ -995,6 +1032,8 @@ class ScheduleBatch:
             req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
+            global_num_tokens=self.global_num_tokens,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1051,6 +1090,10 @@ class ModelWorkerBatch:
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
 
+    # For DP attention
+    global_num_tokens: Optional[List[int]]
+    can_run_dp_cuda_graph: bool
+
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
@@ -1072,16 +1115,39 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+@triton.jit
+def write_req_to_token_pool_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices,
+    pre_lens,
+    seq_lens,
+    extend_lens,
+    out_cache_loc,
+    req_to_token_ptr_stride: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(0)
+
+    req_pool_index = tl.load(req_pool_indices + pid)
+    pre_len = tl.load(pre_lens + pid)
+    seq_len = tl.load(seq_lens + pid)
+
+    # TODO: optimize this?
+    cumsum_start = 0
+    for i in range(pid):
+        cumsum_start += tl.load(extend_lens + i)
+
+    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < (seq_len - pre_len)
+        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
+        tl.store(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + offset
+            + pre_len,
+            value,
+            mask=mask,
+        )
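The in-diff comment already notes the Python equivalent of `write_req_to_token_pool_triton`. Spelled out as a runnable pure-PyTorch reference (a readability sketch, not the shipped code), the kernel copies each request's newly allocated KV-cache slot indices into its row of `req_to_token` at positions `[pre_len, seq_len)`:

```python
import torch


def write_req_to_token_pool_reference(
    req_to_token: torch.Tensor,      # [max_batch, max_context_len]
    req_pool_indices: torch.Tensor,  # [bs]
    pre_lens: torch.Tensor,          # [bs]
    seq_lens: torch.Tensor,          # [bs]
    extend_lens: torch.Tensor,       # [bs], extend_lens[i] == seq_lens[i] - pre_lens[i]
    out_cache_loc: torch.Tensor,     # [sum(extend_lens)]
) -> None:
    # One "program" per request, as in the triton launch grid (bs,).
    cumsum_start = 0
    for i in range(req_pool_indices.numel()):
        row = int(req_pool_indices[i])
        pre_len, seq_len = int(pre_lens[i]), int(seq_lens[i])
        extend_len = int(extend_lens[i])
        req_to_token[row, pre_len:seq_len] = out_cache_loc[
            cumsum_start : cumsum_start + extend_len
        ]
        cumsum_start += extend_len


# Tiny example: two requests, the first with one cached prefix token.
req_to_token = torch.zeros(4, 8, dtype=torch.int64)
write_req_to_token_pool_reference(
    req_to_token,
    req_pool_indices=torch.tensor([2, 0]),
    pre_lens=torch.tensor([1, 0]),
    seq_lens=torch.tensor([4, 3]),
    extend_lens=torch.tensor([3, 3]),
    out_cache_loc=torch.tensor([10, 11, 12, 20, 21, 22]),
)
# req_to_token[2, 1:4] -> [10, 11, 12]; req_to_token[0, 0:3] -> [20, 21, 22]
```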
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -302,7 +302,11 @@ class PrefillAdder:
         if (
             self.rem_chunk_tokens is None
             or input_tokens <= self.rem_chunk_tokens
-            or (
+            or (
+                req.return_logprob
+                and req.normalized_prompt_logprob is None
+                and req.logprob_start_len != len(req.origin_input_ids) - 1
+            )
         ):
             # Non-chunked prefill
             self.can_run_list.append(req)
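For readability, the widened admission condition above can be restated as a standalone predicate (a hypothetical helper, not part of sglang): besides the existing cases of chunking being disabled or the prompt fitting in the remaining chunk budget, a request is also scheduled unchunked when it still needs prompt logprobs over more than just the final prompt token, presumably so the whole prompt's logits are produced in a single extend pass.

```python
def must_run_unchunked(req, rem_chunk_tokens, input_tokens) -> bool:
    # Hypothetical helper restating the predicate in the diff above: run the
    # prefill unchunked if chunking is off, the prompt fits the remaining
    # chunk budget, or prompt logprobs over a non-trivial range are still
    # outstanding for this request.
    return (
        rem_chunk_tokens is None
        or input_tokens <= rem_chunk_tokens
        or (
            req.return_logprob
            and req.normalized_prompt_logprob is None
            and req.logprob_start_len != len(req.origin_input_ids) - 1
        )
    )
```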