sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_latency.py +1 -553
  4. sglang/bench_offline_throughput.py +48 -20
  5. sglang/bench_one_batch.py +472 -0
  6. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  7. sglang/bench_serving.py +125 -6
  8. sglang/check_env.py +3 -6
  9. sglang/lang/backend/base_backend.py +1 -1
  10. sglang/lang/backend/runtime_endpoint.py +2 -2
  11. sglang/srt/configs/model_config.py +13 -14
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +28 -17
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +47 -58
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +16 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +106 -54
  21. sglang/srt/layers/attention/triton_backend.py +9 -7
  22. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  23. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  24. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  25. sglang/srt/layers/custom_op_util.py +25 -0
  26. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  27. sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
  28. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  29. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  30. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  31. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  32. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  33. sglang/srt/layers/layernorm.py +17 -15
  34. sglang/srt/layers/logits_processor.py +23 -25
  35. sglang/srt/layers/quantization/__init__.py +77 -17
  36. sglang/srt/layers/radix_attention.py +13 -15
  37. sglang/srt/layers/rotary_embedding.py +13 -13
  38. sglang/srt/layers/sampler.py +4 -8
  39. sglang/srt/layers/torchao_utils.py +2 -0
  40. sglang/srt/lora/lora.py +13 -14
  41. sglang/srt/lora/lora_config.py +13 -14
  42. sglang/srt/lora/lora_manager.py +22 -24
  43. sglang/srt/managers/data_parallel_controller.py +98 -27
  44. sglang/srt/managers/detokenizer_manager.py +13 -15
  45. sglang/srt/managers/io_struct.py +63 -21
  46. sglang/srt/managers/schedule_batch.py +154 -59
  47. sglang/srt/managers/schedule_policy.py +18 -16
  48. sglang/srt/managers/scheduler.py +278 -109
  49. sglang/srt/managers/session_controller.py +61 -0
  50. sglang/srt/managers/tokenizer_manager.py +63 -18
  51. sglang/srt/managers/tp_worker.py +25 -16
  52. sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
  53. sglang/srt/metrics/collector.py +13 -15
  54. sglang/srt/metrics/func_timer.py +13 -15
  55. sglang/srt/mm_utils.py +13 -14
  56. sglang/srt/model_executor/cuda_graph_runner.py +63 -25
  57. sglang/srt/model_executor/forward_batch_info.py +128 -32
  58. sglang/srt/model_executor/model_runner.py +132 -64
  59. sglang/srt/model_parallel.py +98 -0
  60. sglang/srt/models/chatglm.py +15 -16
  61. sglang/srt/models/commandr.py +15 -16
  62. sglang/srt/models/dbrx.py +15 -16
  63. sglang/srt/models/deepseek.py +15 -15
  64. sglang/srt/models/deepseek_v2.py +162 -59
  65. sglang/srt/models/exaone.py +14 -15
  66. sglang/srt/models/gemma.py +14 -14
  67. sglang/srt/models/gemma2.py +31 -25
  68. sglang/srt/models/gemma2_reward.py +13 -14
  69. sglang/srt/models/gpt_bigcode.py +14 -14
  70. sglang/srt/models/grok.py +15 -15
  71. sglang/srt/models/internlm2.py +13 -15
  72. sglang/srt/models/internlm2_reward.py +13 -14
  73. sglang/srt/models/llama.py +21 -21
  74. sglang/srt/models/llama_classification.py +13 -14
  75. sglang/srt/models/llama_reward.py +13 -14
  76. sglang/srt/models/llava.py +14 -16
  77. sglang/srt/models/llavavid.py +14 -16
  78. sglang/srt/models/minicpm.py +13 -15
  79. sglang/srt/models/minicpm3.py +13 -15
  80. sglang/srt/models/mistral.py +13 -15
  81. sglang/srt/models/mixtral.py +15 -15
  82. sglang/srt/models/mixtral_quant.py +14 -14
  83. sglang/srt/models/olmo.py +22 -20
  84. sglang/srt/models/olmoe.py +23 -20
  85. sglang/srt/models/phi3_small.py +447 -0
  86. sglang/srt/models/qwen.py +14 -14
  87. sglang/srt/models/qwen2.py +22 -19
  88. sglang/srt/models/qwen2_moe.py +17 -18
  89. sglang/srt/models/qwen2_vl.py +13 -6
  90. sglang/srt/models/stablelm.py +18 -16
  91. sglang/srt/models/torch_native_llama.py +107 -93
  92. sglang/srt/models/xverse.py +13 -14
  93. sglang/srt/models/xverse_moe.py +15 -16
  94. sglang/srt/models/yivl.py +13 -15
  95. sglang/srt/openai_api/adapter.py +19 -17
  96. sglang/srt/openai_api/protocol.py +14 -16
  97. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  98. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  99. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  100. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  101. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  102. sglang/srt/sampling/sampling_batch_info.py +61 -57
  103. sglang/srt/sampling/sampling_params.py +14 -16
  104. sglang/srt/server.py +86 -35
  105. sglang/srt/server_args.py +96 -80
  106. sglang/srt/utils.py +266 -68
  107. sglang/test/few_shot_gsm8k.py +8 -4
  108. sglang/test/runners.py +38 -20
  109. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  110. sglang/test/test_utils.py +31 -20
  111. sglang/version.py +1 -1
  112. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  113. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
  114. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  115. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
  116. sglang/srt/layers/fused_moe/__init__.py +0 -1
  117. sglang-0.3.5.post2.dist-info/RECORD +0 -156
  118. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py
@@ -1,22 +1,20 @@
- from __future__ import annotations
-
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """Run the model with cuda graph and torch.compile."""

+ from __future__ import annotations
+
  import bisect
  from contextlib import contextmanager
  from typing import TYPE_CHECKING, Callable
@@ -25,7 +23,7 @@ import torch
  from vllm.distributed.parallel_state import graph_capture
  from vllm.model_executor.custom_op import CustomOp

- from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
+ from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
  from sglang.srt.layers.logits_processor import (
      LogitsMetadata,
      LogitsProcessor,
@@ -67,7 +65,10 @@ def patch_model(
  _to_torch(model)
  monkey_patch_vllm_all_gather()
  backup_ca_comm = tp_group.ca_comm
- tp_group.ca_comm = None
+ # Use custom-allreduce here.
+ # We found the custom allreduce is much faster than the built-in allreduce in torch,
+ # even with ENABLE_INTRA_NODE_COMM=1.
+ # tp_group.ca_comm = None
  yield torch.compile(
      torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
  )
@@ -90,6 +91,8 @@ def set_torch_compile_config():

  # FIXME: tmp workaround
  torch._dynamo.config.accumulated_cache_size_limit = 1024
+ if hasattr(torch._dynamo.config, "cache_size_limit"):
+     torch._dynamo.config.cache_size_limit = 1024


  @maybe_torch_compile(dynamic=True)
@@ -111,6 +114,8 @@ class CudaGraphRunner:
  self.use_torch_compile = model_runner.server_args.enable_torch_compile
  self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
  self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
+ self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
+ self.tp_size = self.model_runner.tp_size

  # Batch sizes to capture
  if model_runner.server_args.disable_cuda_graph_padding:
@@ -165,6 +170,15 @@ class CudaGraphRunner:
  else:
      self.encoder_lens = None

+ if self.enable_dp_attention:
+     self.gathered_buffer = torch.zeros(
+         (
+             self.max_bs * self.tp_size,
+             self.model_runner.model_config.hidden_size,
+         ),
+         dtype=self.model_runner.dtype,
+     )
+
  # Capture
  try:
      with self.model_capture_mode():
@@ -190,11 +204,21 @@ class CudaGraphRunner:
  self.model_runner.model.capture_mode = False

  def can_run(self, forward_batch: ForwardBatch):
-     is_bs_supported = (
-         forward_batch.batch_size in self.graphs
-         if self.disable_padding
-         else forward_batch.batch_size <= self.max_bs
-     )
+     if self.enable_dp_attention:
+         min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
+             forward_batch.global_num_tokens
+         )
+         is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
+             (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+             if self.disable_padding
+             else max_num_tokens <= self.max_bs
+         )
+     else:
+         is_bs_supported = (
+             forward_batch.batch_size in self.graphs
+             if self.disable_padding
+             else forward_batch.batch_size <= self.max_bs
+         )

  # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
  # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
@@ -239,6 +263,13 @@ class CudaGraphRunner:
  seq_lens_sum = seq_lens.sum().item()
  mrope_positions = self.mrope_positions[:, :bs]

+ if self.enable_dp_attention:
+     global_num_tokens = [bs] * self.tp_size
+     gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+ else:
+     global_num_tokens = None
+     gathered_buffer = None
+
  # Attention backend
  self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
      bs,
@@ -265,6 +296,8 @@ class CudaGraphRunner:
      top_logprobs_nums=[0] * bs,
      positions=clamp_position(seq_lens),
      mrope_positions=mrope_positions,
+     global_num_tokens=global_num_tokens,
+     gathered_buffer=gathered_buffer,
  )
  logits_output = forward(input_ids, forward_batch.positions, forward_batch)
  return logits_output.next_token_logits
@@ -295,7 +328,12 @@ class CudaGraphRunner:
  raw_bs = forward_batch.batch_size

  # Pad
- index = bisect.bisect_left(self.capture_bs, raw_bs)
+ if self.enable_dp_attention:
+     index = bisect.bisect_left(
+         self.capture_bs, max(forward_batch.global_num_tokens)
+     )
+ else:
+     index = bisect.bisect_left(self.capture_bs, raw_bs)
  bs = self.capture_bs[index]
  if bs != raw_bs:
      self.seq_lens.fill_(1)
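The padding path above (both the new DP-attention branch and the original one) rounds the incoming batch up to the nearest captured graph size via a sorted lookup. A minimal, self-contained sketch of that lookup, using illustrative capture sizes rather than sglang's actual defaults:

import bisect

capture_bs = [1, 2, 4, 8, 16, 32]  # illustrative capture sizes, not sglang's real list
raw_bs = 5                         # batch size of the incoming forward batch
index = bisect.bisect_left(capture_bs, raw_bs)
bs = capture_bs[index]
print(bs)  # 8 -> the batch is padded from 5 to 8 before the captured graph is replayed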
sglang/srt/model_executor/forward_batch_info.py
@@ -1,20 +1,16 @@
- from __future__ import annotations
-
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """
  Store information about a forward batch.

@@ -31,11 +27,15 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
  It contains low-level tensor data. Most of the data consists of GPU tensors.
  """

+ from __future__ import annotations
+
  from dataclasses import dataclass
  from enum import IntEnum, auto
  from typing import TYPE_CHECKING, List, Optional

  import torch
+ import triton
+ import triton.language as tl

  from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

@@ -50,12 +50,18 @@ if TYPE_CHECKING:
  class ForwardMode(IntEnum):
      # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
      PREFILL = auto()
-     # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
+     # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
      EXTEND = auto()
      # Decode one token.
      DECODE = auto()
-     # Contains both EXTEND and DECODE.
+     # Contains both EXTEND and DECODE when doing chunked prefill.
      MIXED = auto()
+     # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequences are allocated.
+     IDLE = auto()
+
+     # A dummy first batch to start the pipeline for the overlap scheduler.
+     # It is now used for triggering the sampling_info_done event for the first prefill batch.
+     DUMMY_FIRST = auto()

      def is_prefill(self):
          return self == ForwardMode.PREFILL
@@ -69,6 +75,12 @@ class ForwardMode(IntEnum):
  def is_mixed(self):
      return self == ForwardMode.MIXED

+ def is_idle(self):
+     return self == ForwardMode.IDLE
+
+ def is_dummy_first(self):
+     return self == ForwardMode.DUMMY_FIRST
+

  @dataclass
  class ForwardBatch:
@@ -102,6 +114,7 @@ class ForwardBatch:
  extend_seq_lens: Optional[torch.Tensor] = None
  extend_prefix_lens: Optional[torch.Tensor] = None
  extend_start_loc: Optional[torch.Tensor] = None
+ extend_prefix_lens_cpu: Optional[List[int]] = None
  extend_seq_lens_cpu: Optional[List[int]] = None
  extend_logprob_start_lens_cpu: Optional[List[int]] = None

@@ -117,6 +130,9 @@ class ForwardBatch:
  # For LoRA
  lora_paths: Optional[List[str]] = None

+ # For input embeddings
+ input_embeds: Optional[torch.tensor] = None
+
  # Sampling info
  sampling_info: SamplingBatchInfo = None

@@ -128,6 +144,11 @@ class ForwardBatch:
  # For Qwen2-VL
  mrope_positions: torch.Tensor = None

+ # For DP attention
+ global_num_tokens: Optional[List[int]] = None
+ gathered_buffer: Optional[torch.Tensor] = None
+ can_run_dp_cuda_graph: bool = False
+
  def compute_mrope_positions(
      self, model_runner: ModelRunner, batch: ModelWorkerBatch
  ):
@@ -209,31 +230,37 @@ class ForwardBatch:
      seq_lens_sum=batch.seq_lens_sum,
      return_logprob=batch.return_logprob,
      top_logprobs_nums=batch.top_logprobs_nums,
+     global_num_tokens=batch.global_num_tokens,
+     can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
      lora_paths=batch.lora_paths,
      sampling_info=batch.sampling_info,
+     input_embeds=batch.input_embeds,
  )

+ if ret.global_num_tokens is not None:
+     max_len = max(ret.global_num_tokens)
+     ret.gathered_buffer = torch.zeros(
+         (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
+         dtype=model_runner.dtype,
+         device=device,
+     )
+
+ if ret.forward_mode.is_idle():
+     return ret
+
  # Init position information
  if not ret.forward_mode.is_decode():
-     ret.positions = torch.concat(
-         [
-             torch.arange(prefix_len, prefix_len + extend_len, device=device)
-             for prefix_len, extend_len in zip(
-                 batch.extend_prefix_lens, batch.extend_seq_lens
-             )
-         ],
-         axis=0,
-     )
-     ret.extend_num_tokens = batch.extend_num_tokens
      ret.extend_seq_lens = torch.tensor(
          batch.extend_seq_lens, dtype=torch.int32
      ).to(device, non_blocking=True)
-
      ret.extend_prefix_lens = torch.tensor(
          batch.extend_prefix_lens, dtype=torch.int32
      ).to(device, non_blocking=True)
-     ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
-     ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+     ret.extend_num_tokens = batch.extend_num_tokens
+     ret.positions, ret.extend_start_loc = compute_position_triton(
+         ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+     )
+     ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
      ret.extend_seq_lens_cpu = batch.extend_seq_lens
      ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens

@@ -250,3 +277,72 @@ class ForwardBatch:
      model_runner.lora_manager.prepare_lora_batch(ret)

      return ret
+
+
+ def compute_position_triton(
+     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
+ ):
+     """Compute positions. It is a fused version of `compute_position_torch`."""
+     batch_size = extend_seq_lens.shape[0]
+     positions = torch.empty(
+         extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
+     )
+     extend_start_loc = torch.empty(
+         batch_size, dtype=torch.int32, device=extend_seq_lens.device
+     )
+
+     # Launch kernel
+     compute_position_kernel[(batch_size,)](
+         positions,
+         extend_start_loc,
+         extend_prefix_lens,
+         extend_seq_lens,
+     )
+
+     return positions, extend_start_loc
+
+
+ @triton.jit
+ def compute_position_kernel(
+     positions,
+     extend_start_loc,
+     extend_prefix_lens,
+     extend_seq_lens,
+ ):
+     BLOCK_SIZE: tl.constexpr = 512
+     pid = tl.program_id(0)
+
+     prefix_len = tl.load(extend_prefix_lens + pid)
+     seq_len = tl.load(extend_seq_lens + pid)
+
+     # TODO: optimize this?
+     cumsum_start = 0
+     for i in range(pid):
+         cumsum_start += tl.load(extend_seq_lens + i)
+
+     num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+     for i in range(num_loop):
+         offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+         tl.store(
+             positions + cumsum_start + offset,
+             prefix_len + offset,
+             mask=offset < seq_len,
+         )
+     tl.store(extend_start_loc + pid, cumsum_start)
+
+
+ def compute_position_torch(
+     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
+ ):
+     positions = torch.concat(
+         [
+             torch.arange(
+                 prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
+             )
+             for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
+         ],
+         axis=0,
+     )
+     extend_start_loc = torch.zeros_like(extend_seq_lens)
+     extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+     return positions.to(torch.int64), extend_start_loc
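For reference, `compute_position_torch` above (and its fused Triton counterpart) maps per-request prefix/extend lengths to flattened position ids plus a start offset per request. A minimal standalone sketch of the same computation in plain PyTorch, with made-up lengths rather than values from a real batch:

import torch

extend_prefix_lens = torch.tensor([2, 0])  # tokens already cached per request
extend_seq_lens = torch.tensor([3, 2])     # new tokens to extend per request

# Position ids continue from each request's cached prefix.
positions = torch.concat(
    [
        torch.arange(p, p + e)
        for p, e in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
    ]
)
# Offset of each request's first new token in the flattened token dimension.
extend_start_loc = torch.zeros_like(extend_seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)

print(positions)         # tensor([2, 3, 4, 0, 1])
print(extend_start_loc)  # tensor([0, 3])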