sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +472 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +125 -6
- sglang/check_env.py +3 -6
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +28 -17
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +47 -58
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +16 -13
- sglang/srt/layers/attention/flashinfer_backend.py +106 -54
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +25 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +17 -15
- sglang/srt/layers/logits_processor.py +23 -25
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +98 -27
- sglang/srt/managers/detokenizer_manager.py +13 -15
- sglang/srt/managers/io_struct.py +63 -21
- sglang/srt/managers/schedule_batch.py +154 -59
- sglang/srt/managers/schedule_policy.py +18 -16
- sglang/srt/managers/scheduler.py +278 -109
- sglang/srt/managers/session_controller.py +61 -0
- sglang/srt/managers/tokenizer_manager.py +63 -18
- sglang/srt/managers/tp_worker.py +25 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +63 -25
- sglang/srt/model_executor/forward_batch_info.py +128 -32
- sglang/srt/model_executor/model_runner.py +132 -64
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +162 -59
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +31 -25
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +14 -16
- sglang/srt/models/llavavid.py +14 -16
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +22 -20
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +107 -93
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +19 -17
- sglang/srt/openai_api/protocol.py +14 -16
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +61 -57
- sglang/srt/sampling/sampling_params.py +14 -16
- sglang/srt/server.py +86 -35
- sglang/srt/server_args.py +96 -80
- sglang/srt/utils.py +266 -68
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +38 -20
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +31 -20
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.5.post2.dist-info/RECORD +0 -156
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
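The most visible restructuring in the list above is the split of `sglang/srt/layers/fused_moe` into `fused_moe_grok/`, `fused_moe_triton/`, and a top-level `fused_moe_patch.py`. A minimal, hypothetical sketch of the import-path migration this implies for downstream code; only the module paths come from the file list, the exported symbol names are assumptions:

```python
# Hypothetical import-path migration implied by the file moves above.
# Module paths are taken from the diff; the exported names are assumed.
try:
    # 0.3.6.post1 layout
    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe
    from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
except ImportError:
    # 0.3.5.post2 layout
    from sglang.srt.layers.fused_moe.fused_moe import fused_moe
    from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
```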
sglang/srt/model_executor/cuda_graph_runner.py

@@ -1,22 +1,20 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Run the model with cuda graph and torch.compile."""
 
+from __future__ import annotations
+
 import bisect
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
@@ -25,7 +23,7 @@ import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.
+from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
@@ -67,7 +65,10 @@ def patch_model(
             _to_torch(model)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
-
+            # Use custom-allreduce here.
+            # We found the custom allreduce is much faster than the built-in allreduce in torch,
+            # even with ENABLE_INTRA_NODE_COMM=1.
+            # tp_group.ca_comm = None
             yield torch.compile(
                 torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
             )
@@ -90,6 +91,8 @@ def set_torch_compile_config():
 
     # FIXME: tmp workaround
    torch._dynamo.config.accumulated_cache_size_limit = 1024
+    if hasattr(torch._dynamo.config, "cache_size_limit"):
+        torch._dynamo.config.cache_size_limit = 1024
 
 
 @maybe_torch_compile(dynamic=True)
@@ -111,6 +114,8 @@ class CudaGraphRunner:
         self.use_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
+        self.tp_size = self.model_runner.tp_size
 
         # Batch sizes to capture
         if model_runner.server_args.disable_cuda_graph_padding:
@@ -165,6 +170,15 @@ class CudaGraphRunner:
         else:
             self.encoder_lens = None
 
+        if self.enable_dp_attention:
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_bs * self.tp_size,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+
         # Capture
         try:
             with self.model_capture_mode():
@@ -190,11 +204,21 @@ class CudaGraphRunner:
         self.model_runner.model.capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
-
-        forward_batch.
-
-
-
+        if self.enable_dp_attention:
+            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
+                forward_batch.global_num_tokens
+            )
+            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
+                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+                if self.disable_padding
+                else max_num_tokens <= self.max_bs
+            )
+        else:
+            is_bs_supported = (
+                forward_batch.batch_size in self.graphs
+                if self.disable_padding
+                else forward_batch.batch_size <= self.max_bs
+            )
 
         # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
         # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
@@ -239,6 +263,13 @@ class CudaGraphRunner:
         seq_lens_sum = seq_lens.sum().item()
         mrope_positions = self.mrope_positions[:, :bs]
 
+        if self.enable_dp_attention:
+            global_num_tokens = [bs] * self.tp_size
+            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+        else:
+            global_num_tokens = None
+            gathered_buffer = None
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
@@ -265,6 +296,8 @@
             top_logprobs_nums=[0] * bs,
             positions=clamp_position(seq_lens),
             mrope_positions=mrope_positions,
+            global_num_tokens=global_num_tokens,
+            gathered_buffer=gathered_buffer,
         )
         logits_output = forward(input_ids, forward_batch.positions, forward_batch)
         return logits_output.next_token_logits
@@ -295,7 +328,12 @@
         raw_bs = forward_batch.batch_size
 
         # Pad
-
+        if self.enable_dp_attention:
+            index = bisect.bisect_left(
+                self.capture_bs, max(forward_batch.global_num_tokens)
+            )
+        else:
+            index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
             self.seq_lens.fill_(1)
sglang/srt/model_executor/forward_batch_info.py

@@ -1,20 +1,16 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Store information about a forward batch.
 
@@ -31,11 +27,15 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
 
+from __future__ import annotations
+
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 
@@ -50,12 +50,18 @@ if TYPE_CHECKING:
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
-    # Extend a sequence. The KV cache of the
+    # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
-    # Contains both EXTEND and DECODE.
+    # Contains both EXTEND and DECODE when doing chunked prefill.
     MIXED = auto()
+    # No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated.
+    IDLE = auto()
+
+    # A dummy first batch to start the pipeline for overlap scheduler.
+    # It is now used for triggering the sampling_info_done event for the first prefill batch.
+    DUMMY_FIRST = auto()
 
     def is_prefill(self):
         return self == ForwardMode.PREFILL
@@ -69,6 +75,12 @@ class ForwardMode(IntEnum):
     def is_mixed(self):
         return self == ForwardMode.MIXED
 
+    def is_idle(self):
+        return self == ForwardMode.IDLE
+
+    def is_dummy_first(self):
+        return self == ForwardMode.DUMMY_FIRST
+
 
 @dataclass
 class ForwardBatch:
@@ -102,6 +114,7 @@ class ForwardBatch:
     extend_seq_lens: Optional[torch.Tensor] = None
     extend_prefix_lens: Optional[torch.Tensor] = None
     extend_start_loc: Optional[torch.Tensor] = None
+    extend_prefix_lens_cpu: Optional[List[int]] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
 
@@ -117,6 +130,9 @@ class ForwardBatch:
     # For LoRA
     lora_paths: Optional[List[str]] = None
 
+    # For input embeddings
+    input_embeds: Optional[torch.tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
 
@@ -128,6 +144,11 @@ class ForwardBatch:
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
 
+    # For DP attention
+    global_num_tokens: Optional[List[int]] = None
+    gathered_buffer: Optional[torch.Tensor] = None
+    can_run_dp_cuda_graph: bool = False
+
     def compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
@@ -209,31 +230,37 @@ class ForwardBatch:
             seq_lens_sum=batch.seq_lens_sum,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
+            global_num_tokens=batch.global_num_tokens,
+            can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            input_embeds=batch.input_embeds,
         )
 
+        if ret.global_num_tokens is not None:
+            max_len = max(ret.global_num_tokens)
+            ret.gathered_buffer = torch.zeros(
+                (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
+                dtype=model_runner.dtype,
+                device=device,
+            )
+
+        if ret.forward_mode.is_idle():
+            return ret
+
         # Init position information
         if not ret.forward_mode.is_decode():
-            ret.positions = torch.concat(
-                [
-                    torch.arange(prefix_len, prefix_len + extend_len, device=device)
-                    for prefix_len, extend_len in zip(
-                        batch.extend_prefix_lens, batch.extend_seq_lens
-                    )
-                ],
-                axis=0,
-            )
-            ret.extend_num_tokens = batch.extend_num_tokens
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
            ).to(device, non_blocking=True)
-            ret.
-            ret.extend_start_loc
+            ret.extend_num_tokens = batch.extend_num_tokens
+            ret.positions, ret.extend_start_loc = compute_position_triton(
+                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+            )
+            ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
 
@@ -250,3 +277,72 @@ class ForwardBatch:
             model_runner.lora_manager.prepare_lora_batch(ret)
 
         return ret
+
+
+def compute_position_triton(
+    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
+):
+    """Compute positions. It is a fused version of `compute_position_torch`."""
+    batch_size = extend_seq_lens.shape[0]
+    positions = torch.empty(
+        extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
+    )
+    extend_start_loc = torch.empty(
+        batch_size, dtype=torch.int32, device=extend_seq_lens.device
+    )
+
+    # Launch kernel
+    compute_position_kernel[(batch_size,)](
+        positions,
+        extend_start_loc,
+        extend_prefix_lens,
+        extend_seq_lens,
+    )
+
+    return positions, extend_start_loc
+
+
+@triton.jit
+def compute_position_kernel(
+    positions,
+    extend_start_loc,
+    extend_prefix_lens,
+    extend_seq_lens,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(0)
+
+    prefix_len = tl.load(extend_prefix_lens + pid)
+    seq_len = tl.load(extend_seq_lens + pid)
+
+    # TODO: optimize this?
+    cumsum_start = 0
+    for i in range(pid):
+        cumsum_start += tl.load(extend_seq_lens + i)
+
+    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        tl.store(
+            positions + cumsum_start + offset,
+            prefix_len + offset,
+            mask=offset < seq_len,
+        )
+    tl.store(extend_start_loc + pid, cumsum_start)
+
+
+def compute_position_torch(
+    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
+):
+    positions = torch.concat(
+        [
+            torch.arange(
+                prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
+            )
+            for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
+        ],
+        axis=0,
+    )
+    extend_start_loc = torch.zeros_like(extend_seq_lens)
+    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+    return positions.to(torch.int64), extend_start_loc